diff --git a/cacheproxy/cmd/cacheproxy/main.go b/cacheproxy/cmd/cacheproxy/main.go new file mode 100644 index 0000000..4705023 --- /dev/null +++ b/cacheproxy/cmd/cacheproxy/main.go @@ -0,0 +1,41 @@ +package main + +import ( + "flag" + "log" + "net/http" + "net/url" + "os" + + "github.com/gorilla/handlers" + "github.com/uovobw/httpcache/cacheproxy" +) + +var ( + bindTo = flag.String("bind", "0.0.0.0:8080", "address to bind to") + debug = flag.Bool("debug", false, "enable debugging") + target = flag.String("target", "", "base url to cache") +) + +func init() { + flag.Parse() + flag.VisitAll(func(f *flag.Flag) { + log.Printf("%s=%v", f.Name, f.Value) + }) + if *target == "" { + log.Fatalln("you must specify a target url") + } + if *debug { + logger := log.New(os.Stdout, "", 0) + cacheproxy.SetLogger(logger) + } +} + +func main() { + URL, err := url.Parse(*target) + if err != nil { + log.Fatal(err) + } + proxy := cacheproxy.NewSingleHostReverseProxy(URL) + log.Fatal(http.ListenAndServe(*bindTo, handlers.CombinedLoggingHandler(os.Stdout, proxy))) +} diff --git a/cacheproxy/proxy.go b/cacheproxy/proxy.go new file mode 100644 index 0000000..7cc1234 --- /dev/null +++ b/cacheproxy/proxy.go @@ -0,0 +1,32 @@ +package cacheproxy + +import ( + "github.com/uovobw/httpcache" + "log" + "net/http" + "net/http/httputil" + "net/url" +) + +var ( + memoryCache = httpcache.NewMemoryCache() + transport = httpcache.NewTransport(memoryCache) +) + +// NewSingleHostReverseProxy wraps net/http/httputil.NewSingleHostReverseProxy +// and sets the Host header based on the target URL. +func NewSingleHostReverseProxy(url *url.URL) *httputil.ReverseProxy { + proxy := httputil.NewSingleHostReverseProxy(url) + oldDirector := proxy.Director + proxy.Director = func(r *http.Request) { + oldDirector(r) + r.Host = url.Host + } + proxy.Transport = transport + return proxy +} + +// SetLogger wraps httpcache.SetLogger +func SetLogger(l *log.Logger) { + transport.SetLogger(l) +} diff --git a/diskcache/diskcache_test.go b/diskcache/diskcache_test.go index 896a341..dd70f95 100644 --- a/diskcache/diskcache_test.go +++ b/diskcache/diskcache_test.go @@ -5,20 +5,12 @@ import ( "io/ioutil" "os" "testing" - - . "gopkg.in/check.v1" ) -func Test(t *testing.T) { TestingT(t) } - -type S struct{} - -var _ = Suite(&S{}) - -func (s *S) Test(c *C) { +func TestDiskCache(t *testing.T) { tempDir, err := ioutil.TempDir("", "httpcache") if err != nil { - c.Fatalf("TempDir,: %v", err) + t.Fatalf("TempDir,: %v", err) } defer os.RemoveAll(tempDir) @@ -27,17 +19,25 @@ func (s *S) Test(c *C) { key := "testKey" _, ok := cache.Get(key) - c.Assert(ok, Equals, false) + if ok != false { + t.Fatal("Get() without Add()") + } val := []byte("some bytes") cache.Set(key, val) retVal, ok := cache.Get(key) - c.Assert(ok, Equals, true) - c.Assert(bytes.Equal(retVal, val), Equals, true) + if ok != true { + t.Fatal("did not retrieve the key i just set") + } + if bytes.Equal(retVal, val) != true { + t.Fatal("retrieved value not equal to the stored one") + } cache.Delete(key) _, ok = cache.Get(key) - c.Assert(ok, Equals, false) + if ok != false { + t.Fatal("Delete() key still present") + } } diff --git a/httpcache.go b/httpcache.go index c563cf6..96c05fb 100644 --- a/httpcache.go +++ b/httpcache.go @@ -11,8 +11,12 @@ import ( "bytes" "errors" "fmt" + "io/ioutil" + "log" "net/http" "net/http/httputil" + "regexp" + "strconv" "strings" "sync" "time" @@ -23,9 +27,21 @@ const ( fresh transparent // XFromCache is the header added to responses that are returned from the cache - XFromCache = "X-From-Cache" + XFromCache = "X-From-Cache" + rangeSeparator = "-" + rangeTypeSeparator = "=" ) +var ( + logger *log.Logger + bytesRangeRegexp *regexp.Regexp +) + +func init() { + logger = log.New(ioutil.Discard, "httpcache", 0) + bytesRangeRegexp = regexp.MustCompile("bytes=([0-9]*)-([0-9]*)") +} + // A Cache interface is used by the Transport to store and retrieve responses. type Cache interface { // Get returns the []byte representation of a cached response and a bool @@ -51,7 +67,142 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) } b := bytes.NewBuffer(cachedVal) - return http.ReadResponse(bufio.NewReader(b), req) + returnResponse, err := http.ReadResponse(bufio.NewReader(b), req) + if err != nil { + return nil, fmt.Errorf("error loading response from cache: %s\n", err.Error()) + } + + if req.Header.Get("range") != "" { + strContentLength := returnResponse.Header.Get("content-length") + contentLength, err := strconv.ParseInt(strContentLength, 10, 64) + if err != nil { + return nil, fmt.Errorf("response loaded from cache has null or malformed content-length: %d", contentLength) + } + rangeRequestStart, rangeRequestEnd, err := findRanges(req, contentLength) + if err != nil { + return nil, err + } + if !validateRanges(rangeRequestStart, rangeRequestEnd, returnResponse) { + return nil, nil + } + + body, err := ioutil.ReadAll(returnResponse.Body) + if err != nil { + logger.Printf("error reading cached response body: %s", err.Error()) + return returnResponse, nil + } + returnResponse.Body.Close() + + returnResponse.Body = ioutil.NopCloser(bytes.NewReader(body[rangeRequestStart:rangeRequestEnd])) + returnResponse.Header.Set("content-range", fmt.Sprintf("bytes %d-%d/%d", rangeRequestStart, rangeRequestEnd, contentLength)) + } + return returnResponse, nil +} + +// findRanges parses the range header value and the content length and returns (start, end, error) +// it is not exported since it does not implement multiple ranges and only accepts bytes= ranges +func findRanges(r *http.Request, totalLength int64) (start, end int64, err error) { + rawRange := r.Header.Get("range") + if rawRange == "" { + return -1, -1, fmt.Errorf("not a ranged request") + } + if !strings.HasPrefix(rawRange, "bytes=") { + return -1, -1, fmt.Errorf("non-bytes request %s range type unsupported", rawRange) + } + if strings.Contains(rawRange, ",") { + return -1, -1, fmt.Errorf("unsupported multiple ranges: %s", rawRange) + } + matchedValues := bytesRangeRegexp.FindStringSubmatch(rawRange)[1:] + strStart := matchedValues[0] + strEnd := matchedValues[1] + // range in the form STRSTART- + if strEnd == "" { + end = totalLength + start, err = strconv.ParseInt(strStart, 10, 64) + if err != nil { + return -1, -1, err + } + // range in the form -STREND + } else if strStart == "" { + end = totalLength + start, err = strconv.ParseInt(strEnd, 10, 64) + if err != nil { + return -1, -1, err + } + start = totalLength - start + // range in the form STRSTART-STREND + } else { + start, err = strconv.ParseInt(strStart, 10, 64) + if err != nil { + return -1, -1, err + } + end, err = strconv.ParseInt(strEnd, 10, 64) + if err != nil { + return -1, -1, err + } + } + if start >= end { + return -1, -1, fmt.Errorf("invalid start %d >= end %d", start, end) + } + return start, end, nil +} + +// validateRanges checks that a cached request for a given response is within data that has been already loaded +func validateRanges(start, end int64, resp *http.Response) (ok bool) { + // if the response cites partial content we need to compare the partial content we have stored + // with the ranges we require and, if not compatbile, fetch it again + // http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html section 14.16 + if resp.StatusCode == http.StatusPartialContent { + rawContentRange := resp.Header.Get("content-range") + if rawContentRange == "" { + logger.Printf("request for %s is stored as 206-partial-content but has no content-range header!", resp.Request.URL.String()) + return false + } + if !strings.Contains(rawContentRange, "bytes") { + logger.Printf("non-satisfiable range type in %s", rawContentRange) + return false + } + // the format is START-END/TOTAL or START-END/* if TOTAL is unknown + // if we find * we re-fetch the request as most probably the content was ephemereal + // or is highly probable iot has changed + if strings.Contains(rawContentRange, "*") { + return false + } + re := regexp.MustCompile("bytes ([0-9]+)-([0-9]+)/([0-9]+)") + // the first element is always the full match, skip it + matchedValues := re.FindStringSubmatch(rawContentRange)[1:] + currentStart, err := strconv.ParseInt(matchedValues[0], 10, 64) + if err != nil { + logger.Printf("cached response has malformed content-range header %s", rawContentRange) + return false + } + currentEnd, err := strconv.ParseInt(matchedValues[1], 10, 64) + if err != nil { + logger.Printf("cached response has malformed content-range header %s", rawContentRange) + return false + } + total, err := strconv.ParseInt(matchedValues[2], 10, 64) + if err != nil { + logger.Printf("cached response has malformed content-range header %s", rawContentRange) + return false + } + // validate the request ranges against the response headers + if start < currentStart || end > currentEnd || end > total { + logger.Printf("start: %d, currentStart: %d, end: %d, currentEnd: %d, total: %d", start, currentStart, end, currentEnd, total) + return false + } + return true + // the response is full content, use the content-length header to verify ranges + } + contentLength, err := strconv.ParseInt(resp.Header.Get("content-length"), 10, 64) + if err != nil { + logger.Printf("stored response has malformed or invalid content length %d", contentLength) + return false + } + if end > contentLength { + return false + } + return true } // MemoryCache is an implemtation of Cache that stores responses in an in-memory map. @@ -106,6 +257,12 @@ func NewTransport(c Cache) *Transport { return &Transport{Cache: c, MarkCachedResponses: true} } +// SetLogger takes a *log.Logger and replaces the current one that discards all messages +// this method is not thread safe +func (t *Transport) SetLogger(l *log.Logger) { + logger = l +} + // Client returns an *http.Client that caches responses. func (t *Transport) Client() *http.Client { return &http.Client{Transport: t} @@ -138,6 +295,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error var cachedResp *http.Response if cacheableMethod { cachedResp, err = CachedResponse(t.Cache, req) + if err != nil { + fmt.Print(err) + } } else { // Need to invalidate an existing value t.Cache.Delete(cacheKey) @@ -148,7 +308,11 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error transport = http.DefaultTransport } - if cachedResp != nil && err == nil && cacheableMethod && req.Header.Get("range") == "" { + if !cacheableMethod { + return transport.RoundTrip(req) + } + + if cachedResp != nil && err == nil && cacheableMethod { if t.MarkCachedResponses { cachedResp.Header.Set(XFromCache, "1") } @@ -362,7 +526,7 @@ func getEndToEndHeaders(respHeaders http.Header) []string { } } endToEndHeaders := []string{} - for respHeader, _ := range respHeaders { + for respHeader := range respHeaders { if _, ok := hopByHopHeaders[respHeader]; !ok { endToEndHeaders = append(endToEndHeaders, respHeader) } diff --git a/httpcache_test.go b/httpcache_test.go index f78acb0..54c6631 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -2,25 +2,22 @@ package httpcache import ( "fmt" + "io/ioutil" "net/http" "net/http/httptest" "strconv" "testing" "time" - - . "gopkg.in/check.v1" ) -var _ = fmt.Print - -func Test(t *testing.T) { TestingT(t) } - type S struct { server *httptest.Server client http.Client transport *Transport } +var s S + type fakeClock struct { elapsed time.Duration } @@ -29,12 +26,11 @@ func (c *fakeClock) since(t time.Time) time.Duration { return c.elapsed } -var _ = Suite(&S{}) - -func (s *S) SetUpSuite(c *C) { - t := NewMemoryCacheTransport() - client := http.Client{Transport: t} - s.transport = t +func setup() { + s = S{} + tp := NewMemoryCacheTransport() + client := http.Client{Transport: tp} + s.transport = tp s.client = client mux := http.NewServeMux() @@ -90,6 +86,18 @@ func (s *S) SetUpSuite(c *C) { w.Header().Set("Vary", "X-Madeup-Header") w.Write([]byte("Some text content")) })) + mux.HandleFunc("/ranged", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testData := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + start, end, err := findRanges(r, int64(len(testData))) + if err == nil { + w.Header().Set("content-range", fmt.Sprintf("bytes %d-%d/%d", start, end, len(testData))) + w.Write([]byte(testData)[start:end]) + } else { + w.Write([]byte(testData)) + } + })) updateFieldsCounter := 0 mux.HandleFunc("/updatefields", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -104,153 +112,397 @@ func (s *S) SetUpSuite(c *C) { })) } -func (s *S) TearDownSuite(c *C) { +func tearDownTest() { + s.transport.Cache = NewMemoryCache() + clock = &realClock{} s.server.Close() } -func (s *S) TearDownTest(c *C) { - s.transport.Cache = NewMemoryCache() - clock = &realClock{} +func TestSuffixRangedQuery(t *testing.T) { + setup() + defer tearDownTest() + req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + if err != nil { + t.FailNow() + } + resp, err := s.client.Do(req) + defer resp.Body.Close() + if err != nil { + t.FailNow() + } + data, err := ioutil.ReadAll(resp.Body) + if len(data) != 52 || string(data) != "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" || resp.Header.Get(XFromCache) != "" { + t.FailNow() + } + + req.Header.Add("Range", "bytes=10-") + resp2, err := s.client.Do(req) + defer resp2.Body.Close() + if err != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } + data2, err := ioutil.ReadAll(resp2.Body) + if len(data2) != 42 || string(data2) != "KLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" || resp2.Header.Get("content-range") != "bytes 10-52/52" { + t.FailNow() + } } -func (s *S) TestGetOnlyIfCachedHit(c *C) { +func TestPrefixRangedQuery(t *testing.T) { + setup() + defer tearDownTest() + req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + if err != nil { + t.FailNow() + } + resp, err := s.client.Do(req) + defer resp.Body.Close() + if err != nil { + t.FailNow() + } + data, err := ioutil.ReadAll(resp.Body) + failedTest := err != nil || + len(data) != 52 || + string(data) != "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" || + resp.Header.Get("content-range") != "" || + resp.Header.Get(XFromCache) != "" + if failedTest { + t.FailNow() + } + + req.Header.Add("Range", "bytes=-10") + resp2, err := s.client.Do(req) + defer resp2.Body.Close() + if err != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } + data2, err := ioutil.ReadAll(resp2.Body) + failedTest = err != nil || + len(data2) != 10 || + string(data2) != "qrstuvwxyz" || + resp2.Header.Get("content-range") != "bytes 42-52/52" + if failedTest { + t.FailNow() + } +} + +func TestCompleteRangedQuery(t *testing.T) { + setup() + defer tearDownTest() + req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + if err != nil { + t.FailNow() + } + req.Header.Add("Range", "bytes=0-10") + resp, err := s.client.Do(req) + defer resp.Body.Close() + if err != nil { + t.FailNow() + } + data, err := ioutil.ReadAll(resp.Body) + if err != nil || len(data) != 10 || string(data) != "ABCDEFGHIJ" { + t.FailNow() + } + resp2, err := s.client.Do(req) + defer resp2.Body.Close() + if err != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } + data2, err := ioutil.ReadAll(resp2.Body) + if len(data2) != 10 || string(data2) != "ABCDEFGHIJ" || resp2.Header.Get("content-range") != "bytes 0-10/10" { + t.FailNow() + } +} + +func TestPartialSubrangeRangedQuery(t *testing.T) { + setup() + defer tearDownTest() + req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + if err != nil { + t.FailNow() + } + req.Header.Add("Range", "bytes=0-10") + resp, err := s.client.Do(req) + if err != nil { + t.FailNow() + } + defer resp.Body.Close() + data, err := ioutil.ReadAll(resp.Body) + if len(data) != 10 || string(data) != "ABCDEFGHIJ" || resp.Header.Get("content-range") != "bytes 0-10/52" { + t.FailNow() + } + + req2, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + if err != nil { + t.FailNow() + } + req2.Header.Add("Range", "bytes=4-6") + resp2, err := s.client.Do(req2) + if err != nil { + t.FailNow() + } + defer resp2.Body.Close() + if resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } + data2, err := ioutil.ReadAll(resp2.Body) + failedTest := err != nil || + len(data2) != 2 || + string(data2) != "EF" || + resp2.Header.Get("content-range") != "bytes 4-6/10" + if failedTest { + t.FailNow() + } + + // test failing subrange outside previously held one + req3, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + if err != nil { + t.FailNow() + } + req3.Header.Add("Range", "bytes=8-15") + resp3, err := s.client.Do(req3) + defer resp3.Body.Close() + if err != nil || resp3.Header.Get(XFromCache) != "" { + t.FailNow() + } + data3, err := ioutil.ReadAll(resp3.Body) + failedTest = err != nil || + len(data3) != 7 || + string(data3) != "IJKLMNO" || + resp3.Header.Get("content-range") != "bytes 8-15/52" + if failedTest { + t.FailNow() + } +} + +func TestMultipleSubrangeRangedQuery(t *testing.T) { + setup() + defer tearDownTest() + req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + if err != nil { + t.FailNow() + } + req.Header.Add("Range", "bytes=0-10,15-40") + resp, err := s.client.Do(req) + if err != nil { + t.FailNow() + } + defer resp.Body.Close() + data, err := ioutil.ReadAll(resp.Body) + failedTest := err != nil || + len(data) != 52 || + string(data) != "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + if failedTest { + t.FailNow() + } +} + +func TestGetOnlyIfCachedHit(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL, nil) - c.Assert(err, IsNil) + if err != nil { + t.FailNow() + } resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err.Error()) + } defer resp.Body.Close() - c.Assert(resp.Header.Get(XFromCache), Equals, "") + if resp.Header.Get(XFromCache) != "" { + t.FailNow() + } req2, err2 := http.NewRequest("GET", s.server.URL, nil) req2.Header.Add("cache-control", "only-if-cached") + if err2 != nil { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") - c.Assert(resp2.StatusCode, Equals, 200) + if err2 != nil || resp2.Header.Get(XFromCache) != "1" || resp2.StatusCode != http.StatusOK { + t.FailNow() + } } -func (s *S) TestGetOnlyIfCachedMiss(c *C) { +func TestGetOnlyIfCachedMiss(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.FailNow() + } req.Header.Add("cache-control", "only-if-cached") resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get(XFromCache), Equals, "") - c.Assert(resp.StatusCode, Equals, 504) + if err != nil || resp.Header.Get(XFromCache) != "" || resp.StatusCode != 504 { + t.FailNow() + } } -func (s *S) TestGetNoStoreRequest(c *C) { +func TestGetNoStoreRequest(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.FailNow() + } req.Header.Add("Cache-Control", "no-store") resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get(XFromCache), Equals, "") + if err != nil || resp.Header.Get(XFromCache) != "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "") + if err2 != nil || resp2.Header.Get(XFromCache) != "" { + t.FailNow() + } } -func (s *S) TestGetNoStoreResponse(c *C) { +func TestGetNoStoreResponse(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) + if err != nil { + t.FailNow() + } resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get(XFromCache), Equals, "") + if err != nil || resp.Header.Get(XFromCache) != "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "") + if err2 != nil || resp2.Header.Get(XFromCache) != "" { + t.FailNow() + } } -func (s *S) TestGetWithEtag(c *C) { +func TestGetWithEtag(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) + if err != nil { + t.FailNow() + } resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get(XFromCache), Equals, "") + if err != nil || resp.Header.Get(XFromCache) != "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") - + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } // additional assertions to verify that 304 response is converted properly - c.Assert(resp2.Status, Equals, "200 OK") + if resp2.Status != "200 OK" { + t.FailNow() + } + _, ok := resp2.Header["Connection"] - c.Assert(ok, Equals, false) + if ok { + t.FailNow() + } } -func (s *S) TestGetWithLastModified(c *C) { +func TestGetWithLastModified(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) + if err != nil { + t.FailNow() + } resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get(XFromCache), Equals, "") + if err != nil || resp.Header.Get(XFromCache) != "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } } -func (s *S) TestGetWithVary(c *C) { +func TestGetWithVary(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) + if err != nil { + t.FailNow() + } req.Header.Set("Accept", "text/plain") resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get("Vary"), Equals, "Accept") + if err != nil || resp.Header.Get("Vary") != "Accept" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } req.Header.Set("Accept", "text/html") resp3, err3 := s.client.Do(req) defer resp3.Body.Close() - c.Assert(err3, IsNil) - c.Assert(resp3.Header.Get(XFromCache), Equals, "") + if err3 != nil || resp3.Header.Get(XFromCache) != "" { + t.FailNow() + } req.Header.Set("Accept", "") resp4, err4 := s.client.Do(req) defer resp4.Body.Close() - c.Assert(err4, IsNil) - c.Assert(resp4.Header.Get(XFromCache), Equals, "") + if err4 != nil || resp4.Header.Get(XFromCache) != "" { + t.FailNow() + } } -func (s *S) TestGetWithDoubleVary(c *C) { +func TestGetWithDoubleVary(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) + if err != nil { + t.FailNow() + } req.Header.Set("Accept", "text/plain") req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get("Vary"), Not(Equals), "") + if err != nil || resp.Header.Get("Vary") == "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } req.Header.Set("Accept-Language", "") resp3, err3 := s.client.Do(req) defer resp3.Body.Close() - c.Assert(err3, IsNil) - c.Assert(resp3.Header.Get(XFromCache), Equals, "") + if err3 != nil || resp3.Header.Get(XFromCache) != "" { + t.FailNow() + } req.Header.Set("Accept-Language", "da") resp4, err4 := s.client.Do(req) defer resp4.Body.Close() - c.Assert(err4, IsNil) - c.Assert(resp4.Header.Get(XFromCache), Equals, "") + if err4 != nil || resp4.Header.Get(XFromCache) != "" { + t.FailNow() + } } -func (s *S) TestGetWith2VaryHeaders(c *C) { +func TestGetWith2VaryHeaders(t *testing.T) { + setup() + defer tearDownTest() // Tests that multiple Vary headers' comma-separated lists are // merged. See https://github.com/gregjones/httpcache/issues/27. const ( @@ -258,173 +510,239 @@ func (s *S) TestGetWith2VaryHeaders(c *C) { acceptLanguage = "da, en-gb;q=0.8, en;q=0.7" ) req, err := http.NewRequest("GET", s.server.URL+"/2varyheaders", nil) + if err != nil { + t.FailNow() + } req.Header.Set("Accept", accept) req.Header.Set("Accept-Language", acceptLanguage) resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get("Vary"), Not(Equals), "") + if err != nil || resp.Header.Get("Vary") == "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } req.Header.Set("Accept-Language", "") resp3, err3 := s.client.Do(req) defer resp3.Body.Close() - c.Assert(err3, IsNil) - c.Assert(resp3.Header.Get(XFromCache), Equals, "") + if err3 != nil || resp3.Header.Get(XFromCache) != "" { + t.FailNow() + } req.Header.Set("Accept-Language", "da") resp4, err4 := s.client.Do(req) defer resp4.Body.Close() - c.Assert(err4, IsNil) - c.Assert(resp4.Header.Get(XFromCache), Equals, "") + if err4 != nil || resp4.Header.Get(XFromCache) != "" { + t.FailNow() + } req.Header.Set("Accept-Language", acceptLanguage) req.Header.Set("Accept", "") resp5, err5 := s.client.Do(req) defer resp5.Body.Close() - c.Assert(err5, IsNil) - c.Assert(resp5.Header.Get(XFromCache), Equals, "") + if err5 != nil || resp5.Header.Get(XFromCache) != "" { + t.FailNow() + } req.Header.Set("Accept", "image/png") resp6, err6 := s.client.Do(req) defer resp6.Body.Close() - c.Assert(err6, IsNil) - c.Assert(resp6.Header.Get(XFromCache), Equals, "") + if err6 != nil || resp6.Header.Get(XFromCache) != "" { + t.FailNow() + } resp7, err7 := s.client.Do(req) defer resp7.Body.Close() - c.Assert(err7, IsNil) - c.Assert(resp7.Header.Get(XFromCache), Equals, "1") + if err7 != nil || resp7.Header.Get(XFromCache) != "1" { + t.FailNow() + } } -func (s *S) TestGetVaryUnused(c *C) { +func TestGetVaryUnused(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) + if err != nil { + t.FailNow() + } req.Header.Set("Accept", "text/plain") resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get("Vary"), Not(Equals), "") + if err != nil || resp.Header.Get("Vary") == "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } } -func (s *S) TestUpdateFields(c *C) { +func TestUpdateFields(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) + if err != nil { + t.FailNow() + } resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) + if err != nil { + t.FailNow() + } counter := resp.Header.Get("x-counter") resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } counter2 := resp2.Header.Get("x-counter") - c.Assert(counter, Not(Equals), counter2) + if counter == counter2 { + t.FailNow() + } } -func (s *S) TestParseCacheControl(c *C) { +func TestParseCacheControl(t *testing.T) { + setup() + defer tearDownTest() h := http.Header{} for _ = range parseCacheControl(h) { - c.Fatal("cacheControl should be empty") + t.Fatal("cacheControl should be empty") } h.Set("cache-control", "no-cache") cc := parseCacheControl(h) if _, ok := cc["foo"]; ok { - c.Error("Value shouldn't exist") + t.Error("Value shouldn't exist") } if nocache, ok := cc["no-cache"]; ok { - c.Assert(nocache, Equals, "") + if nocache != "" { + t.FailNow() + } } h.Set("cache-control", "no-cache, max-age=3600") cc = parseCacheControl(h) - c.Assert(cc["no-cache"], Equals, "") - c.Assert(cc["max-age"], Equals, "3600") + if cc["no-cache"] != "" || cc["max-age"] != "3600" { + t.FailNow() + } } -func (s *S) TestNoCacheRequestExpiration(c *C) { +func TestNoCacheRequestExpiration(t *testing.T) { + setup() + defer tearDownTest() respHeaders := http.Header{} respHeaders.Set("Cache-Control", "max-age=7200") reqHeaders := http.Header{} reqHeaders.Set("Cache-Control", "no-cache") - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, transparent) + if getFreshness(respHeaders, reqHeaders) != transparent { + t.FailNow() + } } -func (s *S) TestNoCacheResponseExpiration(c *C) { +func TestNoCacheResponseExpiration(t *testing.T) { + setup() + defer tearDownTest() respHeaders := http.Header{} respHeaders.Set("Cache-Control", "no-cache") respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") reqHeaders := http.Header{} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestReqMustRevalidate(c *C) { +func TestReqMustRevalidate(t *testing.T) { + setup() + defer tearDownTest() // not paying attention to request setting max-stale means never returning stale // responses, so always acting as if must-revalidate is set respHeaders := http.Header{} reqHeaders := http.Header{} reqHeaders.Set("Cache-Control", "must-revalidate") - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestRespMustRevalidate(c *C) { +func TestRespMustRevalidate(t *testing.T) { + setup() + defer tearDownTest() respHeaders := http.Header{} respHeaders.Set("Cache-Control", "must-revalidate") reqHeaders := http.Header{} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestFreshExpiration(c *C) { +func TestFreshExpiration(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) reqHeaders := http.Header{} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } clock = &fakeClock{elapsed: 3 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestMaxAge(c *C) { +func TestMaxAge(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("cache-control", "max-age=2") reqHeaders := http.Header{} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } clock = &fakeClock{elapsed: 3 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestMaxAgeZero(c *C) { +func TestMaxAgeZero(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("cache-control", "max-age=0") reqHeaders := http.Header{} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestBothMaxAge(c *C) { +func TestBothMaxAge(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -432,10 +750,14 @@ func (s *S) TestBothMaxAge(c *C) { reqHeaders := http.Header{} reqHeaders.Set("cache-control", "max-age=0") - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestMinFreshWithExpires(c *C) { +func TestMinFreshWithExpires(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -443,14 +765,20 @@ func (s *S) TestMinFreshWithExpires(c *C) { reqHeaders := http.Header{} reqHeaders.Set("cache-control", "min-fresh=1") - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } reqHeaders = http.Header{} reqHeaders.Set("cache-control", "min-fresh=2") - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestEmptyMaxStale(c *C) { +func TestEmptyMaxStale(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -461,14 +789,20 @@ func (s *S) TestEmptyMaxStale(c *C) { clock = &fakeClock{elapsed: 10 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } clock = &fakeClock{elapsed: 60 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } } -func (s *S) TestMaxStaleValue(c *C) { +func TestMaxStaleValue(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -478,15 +812,21 @@ func (s *S) TestMaxStaleValue(c *C) { reqHeaders.Set("cache-control", "max-stale=20") clock = &fakeClock{elapsed: 5 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } clock = &fakeClock{elapsed: 15 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } clock = &fakeClock{elapsed: 30 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } func containsHeader(headers []string, header string) bool { @@ -498,25 +838,8 @@ func containsHeader(headers []string, header string) bool { return false } -type containsHeaderChecker struct { - *CheckerInfo -} - -func (c *containsHeaderChecker) Check(params []interface{}, names []string) (bool, string) { - items, ok := params[0].([]string) - if !ok { - return false, "Expected first param to be []string" - } - value, ok := params[1].(string) - if !ok { - return false, "Expected 2nd param to be string" - } - return containsHeader(items, value), "" -} - -var ContainsHeader Checker = &containsHeaderChecker{&CheckerInfo{Name: "Contains", Params: []string{"Container", "expected to contain"}}} - -func (s *S) TestGetEndToEndHeaders(c *C) { +func TestGetEndToEndHeaders(t *testing.T) { + setup() var ( headers http.Header end2end []string @@ -527,24 +850,39 @@ func (s *S) TestGetEndToEndHeaders(c *C) { headers.Set("te", "deflate") end2end = getEndToEndHeaders(headers) - c.Check(end2end, ContainsHeader, "content-type") - c.Check(end2end, Not(ContainsHeader), "te") + if !containsHeader(end2end, "content-type") { + t.FailNow() + } + if containsHeader(end2end, "te") { + t.FailNow() + } headers = http.Header{} headers.Set("connection", "content-type") headers.Set("content-type", "text/csv") headers.Set("te", "deflate") end2end = getEndToEndHeaders(headers) - c.Check(end2end, Not(ContainsHeader), "connection") - c.Check(end2end, Not(ContainsHeader), "content-type") - c.Check(end2end, Not(ContainsHeader), "te") + if containsHeader(end2end, "connection") { + t.FailNow() + } + if containsHeader(end2end, "content-type") { + t.FailNow() + } + if containsHeader(end2end, "te") { + t.FailNow() + } headers = http.Header{} end2end = getEndToEndHeaders(headers) - c.Check(end2end, HasLen, 0) + if len(end2end) != 0 { + t.FailNow() + } headers = http.Header{} headers.Set("connection", "content-type") end2end = getEndToEndHeaders(headers) - c.Check(end2end, HasLen, 0) + if len(end2end) != 0 { + t.FailNow() + } + tearDownTest() } diff --git a/memcache/appengine_test.go b/memcache/appengine_test.go index f4e01a0..181d5e5 100644 --- a/memcache/appengine_test.go +++ b/memcache/appengine_test.go @@ -7,20 +7,12 @@ import ( "testing" "appengine/aetest" - - . "gopkg.in/check.v1" ) -func Test(t *testing.T) { TestingT(t) } - -type S struct{} - -var _ = Suite(&S{}) - -func (s *S) Test(c *C) { +func TestAppEngine(t *testing.T) { ctx, err := aetest.NewContext(nil) if err != nil { - c.Fatal(err) + t.Fatal(err) } defer ctx.Close() @@ -29,17 +21,25 @@ func (s *S) Test(c *C) { key := "testKey" _, ok := cache.Get(key) - c.Assert(ok, Equals, false) + if ok != false { + t.Fatal("could retrieve non existing key") + } val := []byte("some bytes") cache.Set(key, val) retVal, ok := cache.Get(key) - c.Assert(ok, Equals, true) - c.Assert(bytes.Equal(retVal, val), Equals, true) + if ok != true { + t.Fatal("could not retrieve key i just added") + } + if bytes.Equal(retVal, val) != true { + t.Fatal("retrieved something different from what i put in") + } cache.Delete(key) _, ok = cache.Get(key) - c.Assert(ok, Equals, false) + if ok != false { + t.Fatal("retrieved deleted key") + } } diff --git a/memcache/memcache_test.go b/memcache/memcache_test.go index 5c01900..7ea51fa 100644 --- a/memcache/memcache_test.go +++ b/memcache/memcache_test.go @@ -7,45 +7,52 @@ import ( "fmt" "net" "testing" - - . "gopkg.in/check.v1" ) const testServer = "localhost:11211" -func Test(t *testing.T) { TestingT(t) } - -type S struct{} - -var _ = Suite(&S{}) - -func (s *S) SetUpSuite(c *C) { +func SetUpSuite() bool { conn, err := net.Dial("tcp", testServer) if err != nil { // TODO: rather than skip the test, fall back to a faked memcached server - c.Skip(fmt.Sprintf("skipping test; no server running at %s", testServer)) + fmt.Sprintf("skipping test; no server running at %s", testServer) + return false } conn.Write([]byte("flush_all\r\n")) // flush memcache conn.Close() + return true } -func (s *S) Test(c *C) { +func TestMemCache(t *testing.T) { + if !SetUpSuite() { + t.SkipNow() + } cache := New(testServer) + if cache == recover() { + t.SkipNow() + } key := "testKey" _, ok := cache.Get(key) - - c.Assert(ok, Equals, false) + if ok != false { + t.Fatal("retrieved key before adding it") + } val := []byte("some bytes") cache.Set(key, val) retVal, ok := cache.Get(key) - c.Assert(ok, Equals, true) - c.Assert(bytes.Equal(retVal, val), Equals, true) + if ok != true { + t.Fatal("could not retrieve an element i just added") + } + if bytes.Equal(retVal, val) != true { + t.Fatal("retrieved a different thing than what i put in") + } cache.Delete(key) _, ok = cache.Get(key) - c.Assert(ok, Equals, false) + if ok != false { + t.Fatal("deleted key still present") + } }