From 0a093449159a63ea40c2c2644fe3f27e28250f10 Mon Sep 17 00:00:00 2001 From: Quinn Slack Date: Wed, 4 Feb 2015 11:02:37 -0800 Subject: [PATCH 01/20] Don't store Range requests --- httpcache.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/httpcache.go b/httpcache.go index c563cf6..b88a9ae 100644 --- a/httpcache.go +++ b/httpcache.go @@ -135,8 +135,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error req = cloneRequest(req) cacheKey := cacheKey(req) cacheableMethod := req.Method == "GET" || req.Method == "HEAD" + cacheable := cacheableMethod && req.Header.Get("range") == "" var cachedResp *http.Response - if cacheableMethod { + if cacheable { cachedResp, err = CachedResponse(t.Cache, req) } else { // Need to invalidate an existing value @@ -148,6 +149,10 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error transport = http.DefaultTransport } + if !cacheable { + return transport.RoundTrip(req) + } + if cachedResp != nil && err == nil && cacheableMethod && req.Header.Get("range") == "" { if t.MarkCachedResponses { cachedResp.Header.Set(XFromCache, "1") From 49a985f3b6e13365d38a241012c61d4e2c85fe0d Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Thu, 19 Feb 2015 09:52:15 +0100 Subject: [PATCH 02/20] Add range request handling --- httpcache.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 6 deletions(-) diff --git a/httpcache.go b/httpcache.go index b88a9ae..6334374 100644 --- a/httpcache.go +++ b/httpcache.go @@ -11,8 +11,10 @@ import ( "bytes" "errors" "fmt" + "io/ioutil" "net/http" "net/http/httputil" + "strconv" "strings" "sync" "time" @@ -23,7 +25,9 @@ const ( fresh transparent // XFromCache is the header added to responses that are returned from the cache - XFromCache = "X-From-Cache" + XFromCache = "X-From-Cache" + RANGESEPARATOR = "-" + RANGETYPESEPARATOR = "=" ) // A Cache interface is used by the Transport to store and retrieve responses. @@ -51,7 +55,48 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) } b := bytes.NewBuffer(cachedVal) - return http.ReadResponse(bufio.NewReader(b), req) + returnResponse, err := http.ReadResponse(bufio.NewReader(b), req) + if err != nil { + return nil, fmt.Errorf("error loading response from cache: %s\n", err.Error()) + } + + rangeRaw := req.Header.Get("range") + if rangeRaw != "" { + tmp := strings.Split(rangeRaw, RANGETYPESEPARATOR) + // standard format is bytes=START-END + rangetype, rangevalue := tmp[0], tmp[1] + if rangetype != "bytes" { + fmt.Printf("range type %s not supported\n", rangetype) + return returnResponse, nil + } + // we need to read all the body now, close it, and replace it with another reader + // as there is currently no way of "resetting" a Body + body, err := ioutil.ReadAll(returnResponse.Body) + if err != nil { + fmt.Printf("error reading cached response body: %s\n", err.Error()) + return returnResponse, nil + } + returnResponse.Body.Close() + var rangedRequestStart, rangedRequestEnd int64 + //TODO(uovobw): handle corrupted/nonstandard request header + rangeList := strings.Split(rangevalue, RANGESEPARATOR) + // the range is in the form -VAL , the wanted range is (end-val)->end + if strings.HasPrefix(rangevalue, RANGESEPARATOR) { + rangedRequestEnd = int64(len(body)) + end, _ := strconv.ParseInt(rangeList[1], 10, 64) + rangedRequestStart = rangedRequestEnd - end + // the rang is in the form VAL-, the wanted range is val->end + } else if strings.HasSuffix(rangevalue, RANGESEPARATOR) { + rangedRequestStart, _ = strconv.ParseInt(rangeList[0], 10, 64) + rangedRequestEnd = int64(len(body)) + // normal case with START-END + } else { + rangedRequestStart, _ = strconv.ParseInt(rangeList[0], 10, 64) + rangedRequestEnd, _ = strconv.ParseInt(rangeList[1], 10, 64) + } + returnResponse.Body = ioutil.NopCloser(bytes.NewReader(body[rangedRequestStart:rangedRequestEnd])) + } + return returnResponse, nil } // MemoryCache is an implemtation of Cache that stores responses in an in-memory map. @@ -135,9 +180,8 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error req = cloneRequest(req) cacheKey := cacheKey(req) cacheableMethod := req.Method == "GET" || req.Method == "HEAD" - cacheable := cacheableMethod && req.Header.Get("range") == "" var cachedResp *http.Response - if cacheable { + if cacheableMethod { cachedResp, err = CachedResponse(t.Cache, req) } else { // Need to invalidate an existing value @@ -149,11 +193,11 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error transport = http.DefaultTransport } - if !cacheable { + if !cacheableMethod { return transport.RoundTrip(req) } - if cachedResp != nil && err == nil && cacheableMethod && req.Header.Get("range") == "" { + if cachedResp != nil && err == nil && cacheableMethod { if t.MarkCachedResponses { cachedResp.Header.Set(XFromCache, "1") } From 9c9fc5baacb56e664ec9ad29c37dcc8044cb02d3 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Thu, 19 Feb 2015 10:35:04 +0100 Subject: [PATCH 03/20] Add tests for ranged queries --- httpcache_test.go | 60 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/httpcache_test.go b/httpcache_test.go index f78acb0..0dec3ed 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -2,6 +2,7 @@ package httpcache import ( "fmt" + "io/ioutil" "net/http" "net/http/httptest" "strconv" @@ -90,6 +91,11 @@ func (s *S) SetUpSuite(c *C) { w.Header().Set("Vary", "X-Madeup-Header") w.Write([]byte("Some text content")) })) + mux.HandleFunc("/ranged", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")) + })) updateFieldsCounter := 0 mux.HandleFunc("/updatefields", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -113,6 +119,60 @@ func (s *S) TearDownTest(c *C) { clock = &realClock{} } +func (s *S) TestSuffixRangedQuery(c *C) { + req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + c.Assert(err, IsNil) + req.Header.Add("Range", "bytes=10-") + resp, err := s.client.Do(req) + defer resp.Body.Close() + c.Assert(err, IsNil) + data, err := ioutil.ReadAll(resp.Body) + c.Assert(len(data), Equals, 52) + resp2, err := s.client.Do(req) + defer resp2.Body.Close() + c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + c.Assert(err, IsNil) + data2, err := ioutil.ReadAll(resp2.Body) + c.Assert(len(data2), Equals, 42) + c.Assert(string(data2), Equals, "KLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") +} + +func (s *S) TestPrefixRangedQuery(c *C) { + req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + c.Assert(err, IsNil) + req.Header.Add("Range", "bytes=-10") + resp, err := s.client.Do(req) + defer resp.Body.Close() + c.Assert(err, IsNil) + data, err := ioutil.ReadAll(resp.Body) + c.Assert(len(data), Equals, 52) + resp2, err := s.client.Do(req) + defer resp2.Body.Close() + c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + c.Assert(err, IsNil) + data2, err := ioutil.ReadAll(resp2.Body) + c.Assert(len(data2), Equals, 10) + c.Assert(string(data2), Equals, "qrstuvwxyz") +} + +func (s *S) TestCompleteRangedQuery(c *C) { + req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + c.Assert(err, IsNil) + req.Header.Add("Range", "bytes=0-10") + resp, err := s.client.Do(req) + defer resp.Body.Close() + c.Assert(err, IsNil) + data, err := ioutil.ReadAll(resp.Body) + c.Assert(len(data), Equals, 52) + resp2, err := s.client.Do(req) + defer resp2.Body.Close() + c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + c.Assert(err, IsNil) + data2, err := ioutil.ReadAll(resp2.Body) + c.Assert(len(data2), Equals, 10) + c.Assert(string(data2), Equals, "ABCDEFGHIJ") +} + func (s *S) TestGetOnlyIfCachedHit(c *C) { req, err := http.NewRequest("GET", s.server.URL, nil) c.Assert(err, IsNil) From 16bd323623903befeff89d652be334ca21c9d9b0 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Thu, 19 Feb 2015 15:16:24 +0100 Subject: [PATCH 04/20] Rename constants to follow http://www.reddit.com/r/golang/comments/2az1fz/convention_for_nonexported_constans/ --- httpcache.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/httpcache.go b/httpcache.go index 6334374..a15c533 100644 --- a/httpcache.go +++ b/httpcache.go @@ -26,8 +26,8 @@ const ( transparent // XFromCache is the header added to responses that are returned from the cache XFromCache = "X-From-Cache" - RANGESEPARATOR = "-" - RANGETYPESEPARATOR = "=" + rangeSeparator = "-" + rangeTypeSeparator = "=" ) // A Cache interface is used by the Transport to store and retrieve responses. @@ -62,7 +62,7 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) rangeRaw := req.Header.Get("range") if rangeRaw != "" { - tmp := strings.Split(rangeRaw, RANGETYPESEPARATOR) + tmp := strings.Split(rangeRaw, rangeTypeSeparator) // standard format is bytes=START-END rangetype, rangevalue := tmp[0], tmp[1] if rangetype != "bytes" { @@ -79,14 +79,14 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) returnResponse.Body.Close() var rangedRequestStart, rangedRequestEnd int64 //TODO(uovobw): handle corrupted/nonstandard request header - rangeList := strings.Split(rangevalue, RANGESEPARATOR) + rangeList := strings.Split(rangevalue, rangeSeparator) // the range is in the form -VAL , the wanted range is (end-val)->end - if strings.HasPrefix(rangevalue, RANGESEPARATOR) { + if strings.HasPrefix(rangevalue, rangeSeparator) { rangedRequestEnd = int64(len(body)) end, _ := strconv.ParseInt(rangeList[1], 10, 64) rangedRequestStart = rangedRequestEnd - end // the rang is in the form VAL-, the wanted range is val->end - } else if strings.HasSuffix(rangevalue, RANGESEPARATOR) { + } else if strings.HasSuffix(rangevalue, rangeSeparator) { rangedRequestStart, _ = strconv.ParseInt(rangeList[0], 10, 64) rangedRequestEnd = int64(len(body)) // normal case with START-END From db4602613deb330a725587a6ec4c906f14cd3e21 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Thu, 19 Feb 2015 20:31:52 +0100 Subject: [PATCH 05/20] AppEngine test without gocheck --- memcache/appengine_test.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/memcache/appengine_test.go b/memcache/appengine_test.go index f4e01a0..181d5e5 100644 --- a/memcache/appengine_test.go +++ b/memcache/appengine_test.go @@ -7,20 +7,12 @@ import ( "testing" "appengine/aetest" - - . "gopkg.in/check.v1" ) -func Test(t *testing.T) { TestingT(t) } - -type S struct{} - -var _ = Suite(&S{}) - -func (s *S) Test(c *C) { +func TestAppEngine(t *testing.T) { ctx, err := aetest.NewContext(nil) if err != nil { - c.Fatal(err) + t.Fatal(err) } defer ctx.Close() @@ -29,17 +21,25 @@ func (s *S) Test(c *C) { key := "testKey" _, ok := cache.Get(key) - c.Assert(ok, Equals, false) + if ok != false { + t.Fatal("could retrieve non existing key") + } val := []byte("some bytes") cache.Set(key, val) retVal, ok := cache.Get(key) - c.Assert(ok, Equals, true) - c.Assert(bytes.Equal(retVal, val), Equals, true) + if ok != true { + t.Fatal("could not retrieve key i just added") + } + if bytes.Equal(retVal, val) != true { + t.Fatal("retrieved something different from what i put in") + } cache.Delete(key) _, ok = cache.Get(key) - c.Assert(ok, Equals, false) + if ok != false { + t.Fatal("retrieved deleted key") + } } From f0553fa36d878cdf36daaa9ac4816336d5b4db3c Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Thu, 19 Feb 2015 20:32:09 +0100 Subject: [PATCH 06/20] DiskCache test without gocheck --- diskcache/diskcache_test.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/diskcache/diskcache_test.go b/diskcache/diskcache_test.go index 896a341..dd70f95 100644 --- a/diskcache/diskcache_test.go +++ b/diskcache/diskcache_test.go @@ -5,20 +5,12 @@ import ( "io/ioutil" "os" "testing" - - . "gopkg.in/check.v1" ) -func Test(t *testing.T) { TestingT(t) } - -type S struct{} - -var _ = Suite(&S{}) - -func (s *S) Test(c *C) { +func TestDiskCache(t *testing.T) { tempDir, err := ioutil.TempDir("", "httpcache") if err != nil { - c.Fatalf("TempDir,: %v", err) + t.Fatalf("TempDir,: %v", err) } defer os.RemoveAll(tempDir) @@ -27,17 +19,25 @@ func (s *S) Test(c *C) { key := "testKey" _, ok := cache.Get(key) - c.Assert(ok, Equals, false) + if ok != false { + t.Fatal("Get() without Add()") + } val := []byte("some bytes") cache.Set(key, val) retVal, ok := cache.Get(key) - c.Assert(ok, Equals, true) - c.Assert(bytes.Equal(retVal, val), Equals, true) + if ok != true { + t.Fatal("did not retrieve the key i just set") + } + if bytes.Equal(retVal, val) != true { + t.Fatal("retrieved value not equal to the stored one") + } cache.Delete(key) _, ok = cache.Get(key) - c.Assert(ok, Equals, false) + if ok != false { + t.Fatal("Delete() key still present") + } } From 3b1bdb9e1624a692b3b2949e0b76c03479ad1b32 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Thu, 19 Feb 2015 20:32:21 +0100 Subject: [PATCH 07/20] MemCache test without gocheck --- memcache/memcache_test.go | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/memcache/memcache_test.go b/memcache/memcache_test.go index 5c01900..4b8f568 100644 --- a/memcache/memcache_test.go +++ b/memcache/memcache_test.go @@ -7,45 +7,45 @@ import ( "fmt" "net" "testing" - - . "gopkg.in/check.v1" ) const testServer = "localhost:11211" -func Test(t *testing.T) { TestingT(t) } - -type S struct{} - -var _ = Suite(&S{}) - -func (s *S) SetUpSuite(c *C) { +func SetUpSuite() { conn, err := net.Dial("tcp", testServer) if err != nil { // TODO: rather than skip the test, fall back to a faked memcached server - c.Skip(fmt.Sprintf("skipping test; no server running at %s", testServer)) + fmt.Sprintf("skipping test; no server running at %s", testServer) } conn.Write([]byte("flush_all\r\n")) // flush memcache conn.Close() } -func (s *S) Test(c *C) { +func TestMemCache(t *testing.T) { + SetUpSuite() cache := New(testServer) key := "testKey" _, ok := cache.Get(key) - - c.Assert(ok, Equals, false) + if ok != false { + t.Fatal("retrieved key before adding it") + } val := []byte("some bytes") cache.Set(key, val) retVal, ok := cache.Get(key) - c.Assert(ok, Equals, true) - c.Assert(bytes.Equal(retVal, val), Equals, true) + if ok != true { + t.Fatal("could not retrieve an element i just added") + } + if bytes.Equal(retVal, val) != true { + t.Fatal("retrieved a different thing than what i put in") + } cache.Delete(key) _, ok = cache.Get(key) - c.Assert(ok, Equals, false) + if ok != false { + t.Fatal("deleted key still present") + } } From a79b54868dd65086901027917ab6aec401cb9894 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Fri, 20 Feb 2015 16:49:36 +0100 Subject: [PATCH 08/20] Add module wide logger with configuration functions Clean up logging and avoid fmt --- httpcache.go | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/httpcache.go b/httpcache.go index a15c533..0b885f0 100644 --- a/httpcache.go +++ b/httpcache.go @@ -11,7 +11,9 @@ import ( "bytes" "errors" "fmt" + "io" "io/ioutil" + "log" "net/http" "net/http/httputil" "strconv" @@ -30,6 +32,12 @@ const ( rangeTypeSeparator = "=" ) +var logger *log.Logger + +func init() { + logger = log.New(ioutil.Discard, "httpcache", 0) +} + // A Cache interface is used by the Transport to store and retrieve responses. type Cache interface { // Get returns the []byte representation of a cached response and a bool @@ -66,14 +74,14 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) // standard format is bytes=START-END rangetype, rangevalue := tmp[0], tmp[1] if rangetype != "bytes" { - fmt.Printf("range type %s not supported\n", rangetype) + logger.Print("range type %s not supported", rangetype) return returnResponse, nil } // we need to read all the body now, close it, and replace it with another reader // as there is currently no way of "resetting" a Body body, err := ioutil.ReadAll(returnResponse.Body) if err != nil { - fmt.Printf("error reading cached response body: %s\n", err.Error()) + logger.Print("error reading cached response body: %s", err.Error()) return returnResponse, nil } returnResponse.Body.Close() @@ -83,11 +91,19 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) // the range is in the form -VAL , the wanted range is (end-val)->end if strings.HasPrefix(rangevalue, rangeSeparator) { rangedRequestEnd = int64(len(body)) - end, _ := strconv.ParseInt(rangeList[1], 10, 64) + end, err := strconv.ParseInt(rangeList[1], 10, 64) + if err != nil { + logger.Printf("error parsing range header %s: %s", rangeList[1], err.Error()) + return nil, err + } rangedRequestStart = rangedRequestEnd - end // the rang is in the form VAL-, the wanted range is val->end } else if strings.HasSuffix(rangevalue, rangeSeparator) { - rangedRequestStart, _ = strconv.ParseInt(rangeList[0], 10, 64) + rangedRequestStart, err = strconv.ParseInt(rangeList[0], 10, 64) + if err != nil { + logger.Printf("error parsing range header %s: %s", rangeList[1], err.Error()) + return nil, err + } rangedRequestEnd = int64(len(body)) // normal case with START-END } else { @@ -151,6 +167,12 @@ func NewTransport(c Cache) *Transport { return &Transport{Cache: c, MarkCachedResponses: true} } +// SetLogging has the same parameters as the log.New function and replaces the +// default logger that discards messages +func (t *Transport) SetLogging(out io.Writer, prefix string, flags int) { + logger = log.New(out, prefix, flags) +} + // Client returns an *http.Client that caches responses. func (t *Transport) Client() *http.Client { return &http.Client{Transport: t} From 772bb7b64aa369651973c1b47fc682ef4c6ca4eb Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Fri, 20 Feb 2015 16:52:46 +0100 Subject: [PATCH 09/20] Rename variables in camelCase --- httpcache.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/httpcache.go b/httpcache.go index 0b885f0..a6c9aa7 100644 --- a/httpcache.go +++ b/httpcache.go @@ -72,9 +72,9 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) if rangeRaw != "" { tmp := strings.Split(rangeRaw, rangeTypeSeparator) // standard format is bytes=START-END - rangetype, rangevalue := tmp[0], tmp[1] - if rangetype != "bytes" { - logger.Print("range type %s not supported", rangetype) + rangeType, rangeValue := tmp[0], tmp[1] + if rangeType != "bytes" { + logger.Print("range type %s not supported", rangeType) return returnResponse, nil } // we need to read all the body now, close it, and replace it with another reader @@ -85,32 +85,32 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) return returnResponse, nil } returnResponse.Body.Close() - var rangedRequestStart, rangedRequestEnd int64 + var rangeRequestStart, rangeRequestEnd int64 //TODO(uovobw): handle corrupted/nonstandard request header - rangeList := strings.Split(rangevalue, rangeSeparator) + rangeList := strings.Split(rangeValue, rangeSeparator) // the range is in the form -VAL , the wanted range is (end-val)->end - if strings.HasPrefix(rangevalue, rangeSeparator) { - rangedRequestEnd = int64(len(body)) + if strings.HasPrefix(rangeValue, rangeSeparator) { + rangeRequestEnd = int64(len(body)) end, err := strconv.ParseInt(rangeList[1], 10, 64) if err != nil { logger.Printf("error parsing range header %s: %s", rangeList[1], err.Error()) return nil, err } - rangedRequestStart = rangedRequestEnd - end + rangeRequestStart = rangeRequestEnd - end // the rang is in the form VAL-, the wanted range is val->end - } else if strings.HasSuffix(rangevalue, rangeSeparator) { - rangedRequestStart, err = strconv.ParseInt(rangeList[0], 10, 64) + } else if strings.HasSuffix(rangeValue, rangeSeparator) { + rangeRequestStart, err = strconv.ParseInt(rangeList[0], 10, 64) if err != nil { logger.Printf("error parsing range header %s: %s", rangeList[1], err.Error()) return nil, err } - rangedRequestEnd = int64(len(body)) + rangeRequestEnd = int64(len(body)) // normal case with START-END } else { - rangedRequestStart, _ = strconv.ParseInt(rangeList[0], 10, 64) - rangedRequestEnd, _ = strconv.ParseInt(rangeList[1], 10, 64) + rangeRequestStart, _ = strconv.ParseInt(rangeList[0], 10, 64) + rangeRequestEnd, _ = strconv.ParseInt(rangeList[1], 10, 64) } - returnResponse.Body = ioutil.NopCloser(bytes.NewReader(body[rangedRequestStart:rangedRequestEnd])) + returnResponse.Body = ioutil.NopCloser(bytes.NewReader(body[rangeRequestStart:rangeRequestEnd])) } return returnResponse, nil } From f5c62be202049c333ff998b915d9b8afb9d63a3d Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Fri, 20 Feb 2015 17:08:50 +0100 Subject: [PATCH 10/20] Add handling of comma separated ranges (like 3-4,8-9,-3) and only fulfilling the first one Added check for rangeStart to be less than rangeEnd --- httpcache.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/httpcache.go b/httpcache.go index a6c9aa7..1d55311 100644 --- a/httpcache.go +++ b/httpcache.go @@ -77,6 +77,13 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) logger.Print("range type %s not supported", rangeType) return returnResponse, nil } + // TODO(uovobw): handle comma-separated list of ranges + // in this case we simply split it and only handle the first range provided + if strings.Contains(tmp[1], ",") { + requestedRanges := strings.Split(tmp[1], ",") + logger.Printf("unsupported multiple ranges %s, only fulfilling %s", tmp[1], requestedRanges[0]) + rangeValue = requestedRanges[0] + } // we need to read all the body now, close it, and replace it with another reader // as there is currently no way of "resetting" a Body body, err := ioutil.ReadAll(returnResponse.Body) @@ -86,7 +93,6 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) } returnResponse.Body.Close() var rangeRequestStart, rangeRequestEnd int64 - //TODO(uovobw): handle corrupted/nonstandard request header rangeList := strings.Split(rangeValue, rangeSeparator) // the range is in the form -VAL , the wanted range is (end-val)->end if strings.HasPrefix(rangeValue, rangeSeparator) { @@ -110,6 +116,11 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) rangeRequestStart, _ = strconv.ParseInt(rangeList[0], 10, 64) rangeRequestEnd, _ = strconv.ParseInt(rangeList[1], 10, 64) } + + if rangeRequestStart >= rangeRequestEnd { + logger.Printf("received non valid ranges start %d end %d", rangeRequestStart, rangeRequestEnd) + return nil, fmt.Errorf("non valid ranges specified in range request") + } returnResponse.Body = ioutil.NopCloser(bytes.NewReader(body[rangeRequestStart:rangeRequestEnd])) } return returnResponse, nil From cb4834c4b55c39fed4b1a01800553faf5a0a25a8 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Fri, 20 Feb 2015 17:12:16 +0100 Subject: [PATCH 11/20] Add missing checks for ParseInt errors --- httpcache.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/httpcache.go b/httpcache.go index 1d55311..10b746c 100644 --- a/httpcache.go +++ b/httpcache.go @@ -113,8 +113,16 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) rangeRequestEnd = int64(len(body)) // normal case with START-END } else { - rangeRequestStart, _ = strconv.ParseInt(rangeList[0], 10, 64) - rangeRequestEnd, _ = strconv.ParseInt(rangeList[1], 10, 64) + rangeRequestStart, err = strconv.ParseInt(rangeList[0], 10, 64) + if err != nil { + logger.Printf("error parsing range start %s: %s", rangeList[0], err.Error()) + return nil, err + } + rangeRequestEnd, err = strconv.ParseInt(rangeList[1], 10, 64) + if err != nil { + logger.Printf("error parsing range end %s: %s", rangeList[1], err.Error()) + return nil, err + } } if rangeRequestStart >= rangeRequestEnd { From b72b80bf8aac68f99aeb95f2e8a95903fd6d9b8f Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Fri, 20 Feb 2015 19:50:43 +0100 Subject: [PATCH 12/20] Add findRange to exctract ranges in CachedResponse Added validateRanges to handle the case we have a subrange of an already existing query --- httpcache.go | 178 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 126 insertions(+), 52 deletions(-) diff --git a/httpcache.go b/httpcache.go index 10b746c..eac168a 100644 --- a/httpcache.go +++ b/httpcache.go @@ -16,6 +16,7 @@ import ( "log" "net/http" "net/http/httputil" + "regexp" "strconv" "strings" "sync" @@ -68,72 +69,142 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) return nil, fmt.Errorf("error loading response from cache: %s\n", err.Error()) } - rangeRaw := req.Header.Get("range") - if rangeRaw != "" { - tmp := strings.Split(rangeRaw, rangeTypeSeparator) - // standard format is bytes=START-END - rangeType, rangeValue := tmp[0], tmp[1] - if rangeType != "bytes" { - logger.Print("range type %s not supported", rangeType) - return returnResponse, nil + if req.Header.Get("range") != "" { + strContentLength := returnResponse.Header.Get("content-length") + contentLength, err := strconv.ParseInt(strContentLength, 10, 64) + if err != nil { + return nil, fmt.Errorf("response loaded from cache has null or malformed content-length: %s", contentLength) + } + rangeRequestStart, rangeRequestEnd, err := findRanges(req, contentLength) + if err != nil { + return nil, err } - // TODO(uovobw): handle comma-separated list of ranges - // in this case we simply split it and only handle the first range provided - if strings.Contains(tmp[1], ",") { - requestedRanges := strings.Split(tmp[1], ",") - logger.Printf("unsupported multiple ranges %s, only fulfilling %s", tmp[1], requestedRanges[0]) - rangeValue = requestedRanges[0] + if !validateRanges(rangeRequestStart, rangeRequestEnd, returnResponse) { + return nil, nil } - // we need to read all the body now, close it, and replace it with another reader - // as there is currently no way of "resetting" a Body + body, err := ioutil.ReadAll(returnResponse.Body) if err != nil { logger.Print("error reading cached response body: %s", err.Error()) return returnResponse, nil } returnResponse.Body.Close() - var rangeRequestStart, rangeRequestEnd int64 - rangeList := strings.Split(rangeValue, rangeSeparator) - // the range is in the form -VAL , the wanted range is (end-val)->end - if strings.HasPrefix(rangeValue, rangeSeparator) { - rangeRequestEnd = int64(len(body)) - end, err := strconv.ParseInt(rangeList[1], 10, 64) - if err != nil { - logger.Printf("error parsing range header %s: %s", rangeList[1], err.Error()) - return nil, err - } - rangeRequestStart = rangeRequestEnd - end - // the rang is in the form VAL-, the wanted range is val->end - } else if strings.HasSuffix(rangeValue, rangeSeparator) { - rangeRequestStart, err = strconv.ParseInt(rangeList[0], 10, 64) - if err != nil { - logger.Printf("error parsing range header %s: %s", rangeList[1], err.Error()) - return nil, err - } - rangeRequestEnd = int64(len(body)) - // normal case with START-END - } else { - rangeRequestStart, err = strconv.ParseInt(rangeList[0], 10, 64) - if err != nil { - logger.Printf("error parsing range start %s: %s", rangeList[0], err.Error()) - return nil, err - } - rangeRequestEnd, err = strconv.ParseInt(rangeList[1], 10, 64) - if err != nil { - logger.Printf("error parsing range end %s: %s", rangeList[1], err.Error()) - return nil, err - } - } - if rangeRequestStart >= rangeRequestEnd { - logger.Printf("received non valid ranges start %d end %d", rangeRequestStart, rangeRequestEnd) - return nil, fmt.Errorf("non valid ranges specified in range request") - } returnResponse.Body = ioutil.NopCloser(bytes.NewReader(body[rangeRequestStart:rangeRequestEnd])) } return returnResponse, nil } +// findRanges parses the range header value and the content length and returns (start, end, error) +// it is not exported since it does not implement multiple ranges and only accepts bytes= ranges +func findRanges(r *http.Request, totalLength int64) (start, end int64, err error) { + rawRange := r.Header.Get("range") + if rawRange == "" { + return -1, -1, fmt.Errorf("not a ranged request") + } + if !strings.HasPrefix(rawRange, "bytes=") { + return -1, -1, fmt.Errorf("non-bytes request %s range type unsupported", rawRange) + } + if strings.Contains(rawRange, ",") { + logger.Printf("unsupported multiple ranges, only fulfilling the first one: %s", rawRange) + } + re := regexp.MustCompile("bytes=([0-9]*)-([0-9]*)") + matchedValues := re.FindStringSubmatch(rawRange)[1:] + strStart := matchedValues[0] + strEnd := matchedValues[1] + // range in the form STRSTART- + if strEnd == "" { + end = totalLength + start, err = strconv.ParseInt(strStart, 10, 64) + if err != nil { + return -1, -1, err + } + // range in the form -STREND + } else if strStart == "" { + end = totalLength + start, err = strconv.ParseInt(strEnd, 10, 64) + if err != nil { + return -1, -1, err + } + start = totalLength - start + // range in the form STRSTART-STREND + } else { + start, err = strconv.ParseInt(strStart, 10, 64) + if err != nil { + return -1, -1, err + } + end, err = strconv.ParseInt(strEnd, 10, 64) + if err != nil { + return -1, -1, err + } + } + if start >= end { + return -1, -1, fmt.Errorf("invalid start %d >= end %d!", start, end) + } + return start, end, nil +} + +// validateRanges checks that a cached request for a given response is within data that has been already loaded +func validateRanges(start, end int64, resp *http.Response) (ok bool) { + // if the response cites partial content we need to compare the partial content we have stored + // with the ranges we require and, if not compatbile, fetch it again + // http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html section 14.16 + if resp.StatusCode == http.StatusPartialContent { + rawContentRange := resp.Header.Get("content-range") + if rawContentRange == "" { + logger.Printf("request for %s is stored as 206-partial-content but has no content-range header!", resp.Request.URL.String()) + return false + } + if !strings.Contains(rawContentRange, "bytes") { + logger.Printf("non-satisfiable range type in %s", rawContentRange) + return false + } + // the format is START-END/TOTAL or START-END/* if TOTAL is unknown + // if we find * we re-fetch the request as most probably the content was ephemereal + // or is highly probable iot has changed + if strings.Contains(rawContentRange, "*") { + return false + } + re := regexp.MustCompile("bytes ([0-9]+)-([0-9]+)/([0-9]+)") + // the first element is always the full match, skip it + matchedValues := re.FindStringSubmatch(rawContentRange)[1:] + currentStart, err := strconv.ParseInt(matchedValues[0], 10, 64) + if err != nil { + logger.Printf("cached response has malformed content-range header %s", rawContentRange) + return false + } + currentEnd, err := strconv.ParseInt(matchedValues[1], 10, 64) + if err != nil { + logger.Printf("cached response has malformed content-range header %s", rawContentRange) + return false + } + total, err := strconv.ParseInt(matchedValues[2], 10, 64) + if err != nil { + logger.Printf("cached response has malformed content-range header %s", rawContentRange) + return false + } + // validate the request ranges against the response headers + if start < currentStart || end > currentEnd || end > total { + logger.Printf("start: %d, currentStart: %d, end: %d, currentEnd: %d, total: %d", start, currentStart, end, currentEnd, total) + return false + } + return true + // the response is full content, use the content-length header to verify ranges + } else { + contentLength, err := strconv.ParseInt(resp.Header.Get("content-length"), 10, 64) + if err != nil { + logger.Printf("stored response has malformed or invalid content length %s", contentLength) + return false + } + if end > contentLength { + return false + } + return true + } + logger.Print("we should never ever reach this line") + return false +} + // MemoryCache is an implemtation of Cache that stores responses in an in-memory map. type MemoryCache struct { mu sync.RWMutex @@ -224,6 +295,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error var cachedResp *http.Response if cacheableMethod { cachedResp, err = CachedResponse(t.Cache, req) + if err != nil { + fmt.Print(err) + } } else { // Need to invalidate an existing value t.Cache.Delete(cacheKey) From 919461fca554b86e15196ab51f7f26c5775ac860 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Fri, 20 Feb 2015 19:51:21 +0100 Subject: [PATCH 13/20] Add tests for range queries --- httpcache_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/httpcache_test.go b/httpcache_test.go index 0dec3ed..df0fc70 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -92,9 +92,15 @@ func (s *S) SetUpSuite(c *C) { w.Write([]byte("Some text content")) })) mux.HandleFunc("/ranged", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testData := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" w.Header().Set("Cache-Control", "max-age=3600") w.Header().Set("Content-Type", "text/plain") - w.Write([]byte("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")) + start, end, err := findRanges(r, int64(len(testData))) + if err == nil { + w.Write([]byte(testData)[start:end]) + } else { + w.Write([]byte(testData)) + } })) updateFieldsCounter := 0 @@ -122,12 +128,15 @@ func (s *S) TearDownTest(c *C) { func (s *S) TestSuffixRangedQuery(c *C) { req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) c.Assert(err, IsNil) - req.Header.Add("Range", "bytes=10-") + //req.Header.Add("Range", "bytes=10-") resp, err := s.client.Do(req) defer resp.Body.Close() c.Assert(err, IsNil) data, err := ioutil.ReadAll(resp.Body) c.Assert(len(data), Equals, 52) + c.Assert(string(data), Equals, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") + + req.Header.Add("Range", "bytes=10-") resp2, err := s.client.Do(req) defer resp2.Body.Close() c.Assert(resp2.Header.Get(XFromCache), Equals, "1") @@ -140,12 +149,14 @@ func (s *S) TestSuffixRangedQuery(c *C) { func (s *S) TestPrefixRangedQuery(c *C) { req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) c.Assert(err, IsNil) - req.Header.Add("Range", "bytes=-10") resp, err := s.client.Do(req) defer resp.Body.Close() c.Assert(err, IsNil) data, err := ioutil.ReadAll(resp.Body) c.Assert(len(data), Equals, 52) + c.Assert(string(data), Equals, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") + + req.Header.Add("Range", "bytes=-10") resp2, err := s.client.Do(req) defer resp2.Body.Close() c.Assert(resp2.Header.Get(XFromCache), Equals, "1") @@ -163,7 +174,7 @@ func (s *S) TestCompleteRangedQuery(c *C) { defer resp.Body.Close() c.Assert(err, IsNil) data, err := ioutil.ReadAll(resp.Body) - c.Assert(len(data), Equals, 52) + c.Assert(len(data), Equals, 10) resp2, err := s.client.Do(req) defer resp2.Body.Close() c.Assert(resp2.Header.Get(XFromCache), Equals, "1") @@ -173,6 +184,51 @@ func (s *S) TestCompleteRangedQuery(c *C) { c.Assert(string(data2), Equals, "ABCDEFGHIJ") } +func (s *S) TestPartialSubrangeRangedQuery(c *C) { + req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + c.Assert(err, IsNil) + req.Header.Add("Range", "bytes=0-10") + resp, err := s.client.Do(req) + defer resp.Body.Close() + c.Assert(err, IsNil) + data, err := ioutil.ReadAll(resp.Body) + c.Assert(len(data), Equals, 10) + + req2, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + c.Assert(err, IsNil) + req2.Header.Add("Range", "bytes=4-6") + resp2, err := s.client.Do(req2) + defer resp2.Body.Close() + c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + c.Assert(err, IsNil) + data2, err := ioutil.ReadAll(resp2.Body) + c.Assert(len(data2), Equals, 2) + c.Assert(string(data2), Equals, "EF") + + // test failing subrange outside previously held one + req3, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + c.Assert(err, IsNil) + req3.Header.Add("Range", "bytes=8-15") + resp3, err := s.client.Do(req3) + defer resp3.Body.Close() + c.Assert(resp3.Header.Get(XFromCache), Equals, "") + c.Assert(err, IsNil) + data3, err := ioutil.ReadAll(resp3.Body) + c.Assert(len(data3), Equals, 7) + c.Assert(string(data3), Equals, "IJKLMNO") +} + +func (s *S) TestMultipleSubrangeRangedQuery(c *C) { + req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) + c.Assert(err, IsNil) + req.Header.Add("Range", "bytes=0-10,15-40") + resp, err := s.client.Do(req) + defer resp.Body.Close() + c.Assert(err, IsNil) + data, err := ioutil.ReadAll(resp.Body) + c.Assert(len(data), Equals, 10) +} + func (s *S) TestGetOnlyIfCachedHit(c *C) { req, err := http.NewRequest("GET", s.server.URL, nil) c.Assert(err, IsNil) From 1980876e0cdbec20a8682d33c03af4ec5717e547 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Fri, 20 Feb 2015 20:31:07 +0100 Subject: [PATCH 14/20] Add check to the setup function to skip the test if no memcached server is available --- memcache/memcache_test.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/memcache/memcache_test.go b/memcache/memcache_test.go index 4b8f568..7ea51fa 100644 --- a/memcache/memcache_test.go +++ b/memcache/memcache_test.go @@ -11,19 +11,26 @@ import ( const testServer = "localhost:11211" -func SetUpSuite() { +func SetUpSuite() bool { conn, err := net.Dial("tcp", testServer) if err != nil { // TODO: rather than skip the test, fall back to a faked memcached server fmt.Sprintf("skipping test; no server running at %s", testServer) + return false } conn.Write([]byte("flush_all\r\n")) // flush memcache conn.Close() + return true } func TestMemCache(t *testing.T) { - SetUpSuite() + if !SetUpSuite() { + t.SkipNow() + } cache := New(testServer) + if cache == recover() { + t.SkipNow() + } key := "testKey" _, ok := cache.Get(key) From 617d34088128196233671f65e63451a992bd03b5 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Fri, 20 Feb 2015 21:48:04 +0100 Subject: [PATCH 15/20] Remove gocheck from tests and change all related functions --- httpcache_test.go | 419 +++++++++++++++++++++++++++++----------------- 1 file changed, 267 insertions(+), 152 deletions(-) diff --git a/httpcache_test.go b/httpcache_test.go index f78acb0..f1e32a5 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -1,26 +1,21 @@ package httpcache import ( - "fmt" "net/http" "net/http/httptest" "strconv" "testing" "time" - - . "gopkg.in/check.v1" ) -var _ = fmt.Print - -func Test(t *testing.T) { TestingT(t) } - type S struct { server *httptest.Server client http.Client transport *Transport } +var s S + type fakeClock struct { elapsed time.Duration } @@ -29,12 +24,11 @@ func (c *fakeClock) since(t time.Time) time.Duration { return c.elapsed } -var _ = Suite(&S{}) - -func (s *S) SetUpSuite(c *C) { - t := NewMemoryCacheTransport() - client := http.Client{Transport: t} - s.transport = t +func TestAll(t *testing.T) { + s = S{} + tp := NewMemoryCacheTransport() + client := http.Client{Transport: tp} + s.transport = tp s.client = client mux := http.NewServeMux() @@ -102,155 +96,229 @@ func (s *S) SetUpSuite(c *C) { w.Write([]byte("Some text content")) } })) -} -func (s *S) TearDownSuite(c *C) { + // testing does not provide tear up/down functions for the suite, + // we need to invoke them manually + testGetOnlyIfCachedHit(t) + tearDownTest() + testGetOnlyIfCachedMiss(t) + tearDownTest() + testGetNoStoreRequest(t) + tearDownTest() + testGetNoStoreResponse(t) + tearDownTest() + testGetWithEtag(t) + tearDownTest() + testGetWithLastModified(t) + tearDownTest() + testGetWithVary(t) + tearDownTest() + testGetWithDoubleVary(t) + tearDownTest() + testGetWith2VaryHeaders(t) + tearDownTest() + testGetVaryUnused(t) + tearDownTest() + testUpdateFields(t) + tearDownTest() + testParseCacheControl(t) + tearDownTest() + testNoCacheRequestExpiration(t) + tearDownTest() + testNoCacheResponseExpiration(t) + tearDownTest() + testReqMustRevalidate(t) + tearDownTest() + testRespMustRevalidate(t) + tearDownTest() + testFreshExpiration(t) + tearDownTest() + testMaxAge(t) + tearDownTest() + testMaxAgeZero(t) + tearDownTest() + testBothMaxAge(t) + tearDownTest() + testMinFreshWithExpires(t) + tearDownTest() + testEmptyMaxStale(t) + tearDownTest() + testMaxStaleValue(t) + tearDownTest() + testGetEndToEndHeaders(t) + tearDownTest() + s.server.Close() + } -func (s *S) TearDownTest(c *C) { +func tearDownTest() { s.transport.Cache = NewMemoryCache() clock = &realClock{} } -func (s *S) TestGetOnlyIfCachedHit(c *C) { +func testGetOnlyIfCachedHit(t *testing.T) { req, err := http.NewRequest("GET", s.server.URL, nil) - c.Assert(err, IsNil) + if err != nil { + t.FailNow() + } resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(resp.Header.Get(XFromCache), Equals, "") + if resp.Header.Get(XFromCache) != "" { + t.FailNow() + } req2, err2 := http.NewRequest("GET", s.server.URL, nil) req2.Header.Add("cache-control", "only-if-cached") resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") - c.Assert(resp2.StatusCode, Equals, 200) + if err2 != nil || resp2.Header.Get(XFromCache) != "1" || resp2.StatusCode != http.StatusOK { + t.FailNow() + } } -func (s *S) TestGetOnlyIfCachedMiss(c *C) { +func testGetOnlyIfCachedMiss(t *testing.T) { req, err := http.NewRequest("GET", s.server.URL, nil) req.Header.Add("cache-control", "only-if-cached") resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get(XFromCache), Equals, "") - c.Assert(resp.StatusCode, Equals, 504) + if err != nil || resp.Header.Get(XFromCache) != "" || resp.StatusCode != 504 { + t.FailNow() + } } -func (s *S) TestGetNoStoreRequest(c *C) { +func testGetNoStoreRequest(t *testing.T) { req, err := http.NewRequest("GET", s.server.URL, nil) req.Header.Add("Cache-Control", "no-store") resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get(XFromCache), Equals, "") + if err != nil || resp.Header.Get(XFromCache) != "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "") + if err2 != nil || resp2.Header.Get(XFromCache) != "" { + t.FailNow() + } } -func (s *S) TestGetNoStoreResponse(c *C) { +func testGetNoStoreResponse(t *testing.T) { req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get(XFromCache), Equals, "") + if err != nil || resp.Header.Get(XFromCache) != "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "") + if err2 != nil || resp2.Header.Get(XFromCache) != "" { + t.FailNow() + } } -func (s *S) TestGetWithEtag(c *C) { +func testGetWithEtag(t *testing.T) { req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get(XFromCache), Equals, "") + if err != nil || resp.Header.Get(XFromCache) != "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") - + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } // additional assertions to verify that 304 response is converted properly - c.Assert(resp2.Status, Equals, "200 OK") + if resp2.Status != "200 OK" { + t.FailNow() + } + _, ok := resp2.Header["Connection"] - c.Assert(ok, Equals, false) + if ok { + t.FailNow() + } } -func (s *S) TestGetWithLastModified(c *C) { +func testGetWithLastModified(t *testing.T) { req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get(XFromCache), Equals, "") + if err != nil || resp.Header.Get(XFromCache) != "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } } -func (s *S) TestGetWithVary(c *C) { +func testGetWithVary(t *testing.T) { req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) req.Header.Set("Accept", "text/plain") resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get("Vary"), Equals, "Accept") + if err != nil || resp.Header.Get("Vary") != "Accept" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } req.Header.Set("Accept", "text/html") resp3, err3 := s.client.Do(req) defer resp3.Body.Close() - c.Assert(err3, IsNil) - c.Assert(resp3.Header.Get(XFromCache), Equals, "") + if err3 != nil || resp3.Header.Get(XFromCache) != "" { + t.FailNow() + } req.Header.Set("Accept", "") resp4, err4 := s.client.Do(req) defer resp4.Body.Close() - c.Assert(err4, IsNil) - c.Assert(resp4.Header.Get(XFromCache), Equals, "") + if err4 != nil || resp4.Header.Get(XFromCache) != "" { + t.FailNow() + } } -func (s *S) TestGetWithDoubleVary(c *C) { +func testGetWithDoubleVary(t *testing.T) { req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) req.Header.Set("Accept", "text/plain") req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get("Vary"), Not(Equals), "") + if err != nil || resp.Header.Get("Vary") == "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } req.Header.Set("Accept-Language", "") resp3, err3 := s.client.Do(req) defer resp3.Body.Close() - c.Assert(err3, IsNil) - c.Assert(resp3.Header.Get(XFromCache), Equals, "") + if err3 != nil || resp3.Header.Get(XFromCache) != "" { + t.FailNow() + } req.Header.Set("Accept-Language", "da") resp4, err4 := s.client.Do(req) defer resp4.Body.Close() - c.Assert(err4, IsNil) - c.Assert(resp4.Header.Get(XFromCache), Equals, "") + if err4 != nil || resp4.Header.Get(XFromCache) != "" { + t.FailNow() + } } -func (s *S) TestGetWith2VaryHeaders(c *C) { +func testGetWith2VaryHeaders(t *testing.T) { // Tests that multiple Vary headers' comma-separated lists are // merged. See https://github.com/gregjones/httpcache/issues/27. const ( @@ -262,169 +330,204 @@ func (s *S) TestGetWith2VaryHeaders(c *C) { req.Header.Set("Accept-Language", acceptLanguage) resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get("Vary"), Not(Equals), "") + if err != nil || resp.Header.Get("Vary") == "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } req.Header.Set("Accept-Language", "") resp3, err3 := s.client.Do(req) defer resp3.Body.Close() - c.Assert(err3, IsNil) - c.Assert(resp3.Header.Get(XFromCache), Equals, "") + if err3 != nil || resp3.Header.Get(XFromCache) != "" { + t.FailNow() + } req.Header.Set("Accept-Language", "da") resp4, err4 := s.client.Do(req) defer resp4.Body.Close() - c.Assert(err4, IsNil) - c.Assert(resp4.Header.Get(XFromCache), Equals, "") + if err4 != nil || resp4.Header.Get(XFromCache) != "" { + t.FailNow() + } req.Header.Set("Accept-Language", acceptLanguage) req.Header.Set("Accept", "") resp5, err5 := s.client.Do(req) defer resp5.Body.Close() - c.Assert(err5, IsNil) - c.Assert(resp5.Header.Get(XFromCache), Equals, "") + if err5 != nil || resp5.Header.Get(XFromCache) != "" { + t.FailNow() + } req.Header.Set("Accept", "image/png") resp6, err6 := s.client.Do(req) defer resp6.Body.Close() - c.Assert(err6, IsNil) - c.Assert(resp6.Header.Get(XFromCache), Equals, "") + if err6 != nil || resp6.Header.Get(XFromCache) != "" { + t.FailNow() + } resp7, err7 := s.client.Do(req) defer resp7.Body.Close() - c.Assert(err7, IsNil) - c.Assert(resp7.Header.Get(XFromCache), Equals, "1") + if err7 != nil || resp7.Header.Get(XFromCache) != "1" { + t.FailNow() + } } -func (s *S) TestGetVaryUnused(c *C) { +func testGetVaryUnused(t *testing.T) { req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) req.Header.Set("Accept", "text/plain") resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) - c.Assert(resp.Header.Get("Vary"), Not(Equals), "") + if err != nil || resp.Header.Get("Vary") == "" { + t.FailNow() + } resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } } -func (s *S) TestUpdateFields(c *C) { +func testUpdateFields(t *testing.T) { req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) resp, err := s.client.Do(req) defer resp.Body.Close() - c.Assert(err, IsNil) + if err != nil { + t.FailNow() + } counter := resp.Header.Get("x-counter") resp2, err2 := s.client.Do(req) defer resp2.Body.Close() - c.Assert(err2, IsNil) - c.Assert(resp2.Header.Get(XFromCache), Equals, "1") + if err2 != nil || resp2.Header.Get(XFromCache) != "1" { + t.FailNow() + } counter2 := resp2.Header.Get("x-counter") - c.Assert(counter, Not(Equals), counter2) + if counter == counter2 { + t.FailNow() + } } -func (s *S) TestParseCacheControl(c *C) { +func testParseCacheControl(t *testing.T) { h := http.Header{} for _ = range parseCacheControl(h) { - c.Fatal("cacheControl should be empty") + t.Fatal("cacheControl should be empty") } h.Set("cache-control", "no-cache") cc := parseCacheControl(h) if _, ok := cc["foo"]; ok { - c.Error("Value shouldn't exist") + t.Error("Value shouldn't exist") } if nocache, ok := cc["no-cache"]; ok { - c.Assert(nocache, Equals, "") + if nocache != "" { + t.FailNow() + } } h.Set("cache-control", "no-cache, max-age=3600") cc = parseCacheControl(h) - c.Assert(cc["no-cache"], Equals, "") - c.Assert(cc["max-age"], Equals, "3600") + if cc["no-cache"] != "" || cc["max-age"] != "3600" { + t.FailNow() + } } -func (s *S) TestNoCacheRequestExpiration(c *C) { +func testNoCacheRequestExpiration(t *testing.T) { respHeaders := http.Header{} respHeaders.Set("Cache-Control", "max-age=7200") reqHeaders := http.Header{} reqHeaders.Set("Cache-Control", "no-cache") - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, transparent) + if getFreshness(respHeaders, reqHeaders) != transparent { + t.FailNow() + } } -func (s *S) TestNoCacheResponseExpiration(c *C) { +func testNoCacheResponseExpiration(t *testing.T) { respHeaders := http.Header{} respHeaders.Set("Cache-Control", "no-cache") respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") reqHeaders := http.Header{} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestReqMustRevalidate(c *C) { +func testReqMustRevalidate(t *testing.T) { // not paying attention to request setting max-stale means never returning stale // responses, so always acting as if must-revalidate is set respHeaders := http.Header{} reqHeaders := http.Header{} reqHeaders.Set("Cache-Control", "must-revalidate") - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestRespMustRevalidate(c *C) { +func testRespMustRevalidate(t *testing.T) { respHeaders := http.Header{} respHeaders.Set("Cache-Control", "must-revalidate") reqHeaders := http.Header{} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestFreshExpiration(c *C) { +func testFreshExpiration(t *testing.T) { now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) reqHeaders := http.Header{} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } clock = &fakeClock{elapsed: 3 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestMaxAge(c *C) { +func testMaxAge(t *testing.T) { now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("cache-control", "max-age=2") reqHeaders := http.Header{} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } clock = &fakeClock{elapsed: 3 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestMaxAgeZero(c *C) { +func testMaxAgeZero(t *testing.T) { now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("cache-control", "max-age=0") reqHeaders := http.Header{} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestBothMaxAge(c *C) { +func testBothMaxAge(t *testing.T) { now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -432,10 +535,12 @@ func (s *S) TestBothMaxAge(c *C) { reqHeaders := http.Header{} reqHeaders.Set("cache-control", "max-age=0") - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestMinFreshWithExpires(c *C) { +func testMinFreshWithExpires(t *testing.T) { now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -443,14 +548,18 @@ func (s *S) TestMinFreshWithExpires(c *C) { reqHeaders := http.Header{} reqHeaders.Set("cache-control", "min-fresh=1") - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } reqHeaders = http.Header{} reqHeaders.Set("cache-control", "min-fresh=2") - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } -func (s *S) TestEmptyMaxStale(c *C) { +func testEmptyMaxStale(t *testing.T) { now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -461,14 +570,18 @@ func (s *S) TestEmptyMaxStale(c *C) { clock = &fakeClock{elapsed: 10 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } clock = &fakeClock{elapsed: 60 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } } -func (s *S) TestMaxStaleValue(c *C) { +func testMaxStaleValue(t *testing.T) { now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -478,15 +591,21 @@ func (s *S) TestMaxStaleValue(c *C) { reqHeaders.Set("cache-control", "max-stale=20") clock = &fakeClock{elapsed: 5 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } clock = &fakeClock{elapsed: 15 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, fresh) + if getFreshness(respHeaders, reqHeaders) != fresh { + t.FailNow() + } clock = &fakeClock{elapsed: 30 * time.Second} - c.Assert(getFreshness(respHeaders, reqHeaders), Equals, stale) + if getFreshness(respHeaders, reqHeaders) != stale { + t.FailNow() + } } func containsHeader(headers []string, header string) bool { @@ -498,25 +617,7 @@ func containsHeader(headers []string, header string) bool { return false } -type containsHeaderChecker struct { - *CheckerInfo -} - -func (c *containsHeaderChecker) Check(params []interface{}, names []string) (bool, string) { - items, ok := params[0].([]string) - if !ok { - return false, "Expected first param to be []string" - } - value, ok := params[1].(string) - if !ok { - return false, "Expected 2nd param to be string" - } - return containsHeader(items, value), "" -} - -var ContainsHeader Checker = &containsHeaderChecker{&CheckerInfo{Name: "Contains", Params: []string{"Container", "expected to contain"}}} - -func (s *S) TestGetEndToEndHeaders(c *C) { +func testGetEndToEndHeaders(t *testing.T) { var ( headers http.Header end2end []string @@ -527,24 +628,38 @@ func (s *S) TestGetEndToEndHeaders(c *C) { headers.Set("te", "deflate") end2end = getEndToEndHeaders(headers) - c.Check(end2end, ContainsHeader, "content-type") - c.Check(end2end, Not(ContainsHeader), "te") + if !containsHeader(end2end, "content-type") { + t.FailNow() + } + if containsHeader(end2end, "te") { + t.FailNow() + } headers = http.Header{} headers.Set("connection", "content-type") headers.Set("content-type", "text/csv") headers.Set("te", "deflate") end2end = getEndToEndHeaders(headers) - c.Check(end2end, Not(ContainsHeader), "connection") - c.Check(end2end, Not(ContainsHeader), "content-type") - c.Check(end2end, Not(ContainsHeader), "te") + if containsHeader(end2end, "connection") { + t.FailNow() + } + if containsHeader(end2end, "content-type") { + t.FailNow() + } + if containsHeader(end2end, "te") { + t.FailNow() + } headers = http.Header{} end2end = getEndToEndHeaders(headers) - c.Check(end2end, HasLen, 0) + if len(end2end) != 0 { + t.FailNow() + } headers = http.Header{} headers.Set("connection", "content-type") end2end = getEndToEndHeaders(headers) - c.Check(end2end, HasLen, 0) + if len(end2end) != 0 { + t.FailNow() + } } From b6d8e1333ddb9b691a5a671e2906085810cfce02 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Mon, 23 Feb 2015 19:46:58 +0100 Subject: [PATCH 16/20] Refactor as per pull request 3: https://github.com/sourcegraph/httpcache/pull/3/ --- httpcache_test.go | 156 +++++++++++++++++++++++----------------------- 1 file changed, 77 insertions(+), 79 deletions(-) diff --git a/httpcache_test.go b/httpcache_test.go index f1e32a5..77614cd 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -24,7 +24,7 @@ func (c *fakeClock) since(t time.Time) time.Duration { return c.elapsed } -func TestAll(t *testing.T) { +func setup() { s = S{} tp := NewMemoryCacheTransport() client := http.Client{Transport: tp} @@ -96,73 +96,25 @@ func TestAll(t *testing.T) { w.Write([]byte("Some text content")) } })) - - // testing does not provide tear up/down functions for the suite, - // we need to invoke them manually - testGetOnlyIfCachedHit(t) - tearDownTest() - testGetOnlyIfCachedMiss(t) - tearDownTest() - testGetNoStoreRequest(t) - tearDownTest() - testGetNoStoreResponse(t) - tearDownTest() - testGetWithEtag(t) - tearDownTest() - testGetWithLastModified(t) - tearDownTest() - testGetWithVary(t) - tearDownTest() - testGetWithDoubleVary(t) - tearDownTest() - testGetWith2VaryHeaders(t) - tearDownTest() - testGetVaryUnused(t) - tearDownTest() - testUpdateFields(t) - tearDownTest() - testParseCacheControl(t) - tearDownTest() - testNoCacheRequestExpiration(t) - tearDownTest() - testNoCacheResponseExpiration(t) - tearDownTest() - testReqMustRevalidate(t) - tearDownTest() - testRespMustRevalidate(t) - tearDownTest() - testFreshExpiration(t) - tearDownTest() - testMaxAge(t) - tearDownTest() - testMaxAgeZero(t) - tearDownTest() - testBothMaxAge(t) - tearDownTest() - testMinFreshWithExpires(t) - tearDownTest() - testEmptyMaxStale(t) - tearDownTest() - testMaxStaleValue(t) - tearDownTest() - testGetEndToEndHeaders(t) - tearDownTest() - - s.server.Close() - } func tearDownTest() { s.transport.Cache = NewMemoryCache() clock = &realClock{} + s.server.Close() } -func testGetOnlyIfCachedHit(t *testing.T) { +func TestGetOnlyIfCachedHit(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL, nil) if err != nil { t.FailNow() } resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err.Error()) + } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.FailNow() @@ -177,7 +129,9 @@ func testGetOnlyIfCachedHit(t *testing.T) { } } -func testGetOnlyIfCachedMiss(t *testing.T) { +func TestGetOnlyIfCachedMiss(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL, nil) req.Header.Add("cache-control", "only-if-cached") resp, err := s.client.Do(req) @@ -187,7 +141,9 @@ func testGetOnlyIfCachedMiss(t *testing.T) { } } -func testGetNoStoreRequest(t *testing.T) { +func TestGetNoStoreRequest(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL, nil) req.Header.Add("Cache-Control", "no-store") resp, err := s.client.Do(req) @@ -203,7 +159,9 @@ func testGetNoStoreRequest(t *testing.T) { } } -func testGetNoStoreResponse(t *testing.T) { +func TestGetNoStoreResponse(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) resp, err := s.client.Do(req) defer resp.Body.Close() @@ -218,7 +176,9 @@ func testGetNoStoreResponse(t *testing.T) { } } -func testGetWithEtag(t *testing.T) { +func TestGetWithEtag(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) resp, err := s.client.Do(req) defer resp.Body.Close() @@ -242,7 +202,9 @@ func testGetWithEtag(t *testing.T) { } } -func testGetWithLastModified(t *testing.T) { +func TestGetWithLastModified(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) resp, err := s.client.Do(req) defer resp.Body.Close() @@ -257,7 +219,9 @@ func testGetWithLastModified(t *testing.T) { } } -func testGetWithVary(t *testing.T) { +func TestGetWithVary(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) req.Header.Set("Accept", "text/plain") resp, err := s.client.Do(req) @@ -287,7 +251,9 @@ func testGetWithVary(t *testing.T) { } } -func testGetWithDoubleVary(t *testing.T) { +func TestGetWithDoubleVary(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) req.Header.Set("Accept", "text/plain") req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") @@ -318,7 +284,9 @@ func testGetWithDoubleVary(t *testing.T) { } } -func testGetWith2VaryHeaders(t *testing.T) { +func TestGetWith2VaryHeaders(t *testing.T) { + setup() + defer tearDownTest() // Tests that multiple Vary headers' comma-separated lists are // merged. See https://github.com/gregjones/httpcache/issues/27. const ( @@ -376,7 +344,9 @@ func testGetWith2VaryHeaders(t *testing.T) { } } -func testGetVaryUnused(t *testing.T) { +func TestGetVaryUnused(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) req.Header.Set("Accept", "text/plain") resp, err := s.client.Do(req) @@ -392,7 +362,9 @@ func testGetVaryUnused(t *testing.T) { } } -func testUpdateFields(t *testing.T) { +func TestUpdateFields(t *testing.T) { + setup() + defer tearDownTest() req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) resp, err := s.client.Do(req) defer resp.Body.Close() @@ -413,7 +385,9 @@ func testUpdateFields(t *testing.T) { } } -func testParseCacheControl(t *testing.T) { +func TestParseCacheControl(t *testing.T) { + setup() + defer tearDownTest() h := http.Header{} for _ = range parseCacheControl(h) { t.Fatal("cacheControl should be empty") @@ -437,7 +411,9 @@ func testParseCacheControl(t *testing.T) { } } -func testNoCacheRequestExpiration(t *testing.T) { +func TestNoCacheRequestExpiration(t *testing.T) { + setup() + defer tearDownTest() respHeaders := http.Header{} respHeaders.Set("Cache-Control", "max-age=7200") reqHeaders := http.Header{} @@ -448,7 +424,9 @@ func testNoCacheRequestExpiration(t *testing.T) { } } -func testNoCacheResponseExpiration(t *testing.T) { +func TestNoCacheResponseExpiration(t *testing.T) { + setup() + defer tearDownTest() respHeaders := http.Header{} respHeaders.Set("Cache-Control", "no-cache") respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") @@ -459,7 +437,9 @@ func testNoCacheResponseExpiration(t *testing.T) { } } -func testReqMustRevalidate(t *testing.T) { +func TestReqMustRevalidate(t *testing.T) { + setup() + defer tearDownTest() // not paying attention to request setting max-stale means never returning stale // responses, so always acting as if must-revalidate is set respHeaders := http.Header{} @@ -471,7 +451,9 @@ func testReqMustRevalidate(t *testing.T) { } } -func testRespMustRevalidate(t *testing.T) { +func TestRespMustRevalidate(t *testing.T) { + setup() + defer tearDownTest() respHeaders := http.Header{} respHeaders.Set("Cache-Control", "must-revalidate") reqHeaders := http.Header{} @@ -481,7 +463,9 @@ func testRespMustRevalidate(t *testing.T) { } } -func testFreshExpiration(t *testing.T) { +func TestFreshExpiration(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -498,7 +482,9 @@ func testFreshExpiration(t *testing.T) { } } -func testMaxAge(t *testing.T) { +func TestMaxAge(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -515,7 +501,9 @@ func testMaxAge(t *testing.T) { } } -func testMaxAgeZero(t *testing.T) { +func TestMaxAgeZero(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -527,7 +515,9 @@ func testMaxAgeZero(t *testing.T) { } } -func testBothMaxAge(t *testing.T) { +func TestBothMaxAge(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -540,7 +530,9 @@ func testBothMaxAge(t *testing.T) { } } -func testMinFreshWithExpires(t *testing.T) { +func TestMinFreshWithExpires(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -559,7 +551,9 @@ func testMinFreshWithExpires(t *testing.T) { } } -func testEmptyMaxStale(t *testing.T) { +func TestEmptyMaxStale(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -581,7 +575,9 @@ func testEmptyMaxStale(t *testing.T) { } } -func testMaxStaleValue(t *testing.T) { +func TestMaxStaleValue(t *testing.T) { + setup() + defer tearDownTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -617,7 +613,8 @@ func containsHeader(headers []string, header string) bool { return false } -func testGetEndToEndHeaders(t *testing.T) { +func TestGetEndToEndHeaders(t *testing.T) { + setup() var ( headers http.Header end2end []string @@ -662,4 +659,5 @@ func testGetEndToEndHeaders(t *testing.T) { if len(end2end) != 0 { t.FailNow() } + tearDownTest() } From 20404899876c585f3b9aa7cf04b8546e2912f9c3 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Tue, 24 Feb 2015 00:51:32 +0100 Subject: [PATCH 17/20] Compiled regexp in init() instead that on demand Reformat SetLogging() to SetLogger that takes a *log.Logger Check more errors in the code Set content-range header for ranged responses Add more checks to the tests (for ranges and headers) Avoid answering with a partial range to a multiple-range request, answer by refetching the request again --- httpcache.go | 42 +++++++++++++++++++++--------------------- httpcache_test.go | 42 ++++++++++++++++++++++++++++++++++++------ 2 files changed, 57 insertions(+), 27 deletions(-) diff --git a/httpcache.go b/httpcache.go index eac168a..3c1b538 100644 --- a/httpcache.go +++ b/httpcache.go @@ -11,7 +11,6 @@ import ( "bytes" "errors" "fmt" - "io" "io/ioutil" "log" "net/http" @@ -33,10 +32,14 @@ const ( rangeTypeSeparator = "=" ) -var logger *log.Logger +var ( + logger *log.Logger + bytesRangeRegexp *regexp.Regexp +) func init() { logger = log.New(ioutil.Discard, "httpcache", 0) + bytesRangeRegexp = regexp.MustCompile("bytes=([0-9]*)-([0-9]*)") } // A Cache interface is used by the Transport to store and retrieve responses. @@ -91,6 +94,7 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) returnResponse.Body.Close() returnResponse.Body = ioutil.NopCloser(bytes.NewReader(body[rangeRequestStart:rangeRequestEnd])) + returnResponse.Header.Set("content-range", fmt.Sprintf("bytes %d-%d/%d", rangeRequestStart, rangeRequestEnd, contentLength)) } return returnResponse, nil } @@ -106,10 +110,9 @@ func findRanges(r *http.Request, totalLength int64) (start, end int64, err error return -1, -1, fmt.Errorf("non-bytes request %s range type unsupported", rawRange) } if strings.Contains(rawRange, ",") { - logger.Printf("unsupported multiple ranges, only fulfilling the first one: %s", rawRange) + return -1, -1, fmt.Errorf("unsupported multiple ranges: %s", rawRange) } - re := regexp.MustCompile("bytes=([0-9]*)-([0-9]*)") - matchedValues := re.FindStringSubmatch(rawRange)[1:] + matchedValues := bytesRangeRegexp.FindStringSubmatch(rawRange)[1:] strStart := matchedValues[0] strEnd := matchedValues[1] // range in the form STRSTART- @@ -190,19 +193,16 @@ func validateRanges(start, end int64, resp *http.Response) (ok bool) { } return true // the response is full content, use the content-length header to verify ranges - } else { - contentLength, err := strconv.ParseInt(resp.Header.Get("content-length"), 10, 64) - if err != nil { - logger.Printf("stored response has malformed or invalid content length %s", contentLength) - return false - } - if end > contentLength { - return false - } - return true } - logger.Print("we should never ever reach this line") - return false + contentLength, err := strconv.ParseInt(resp.Header.Get("content-length"), 10, 64) + if err != nil { + logger.Printf("stored response has malformed or invalid content length %s", contentLength) + return false + } + if end > contentLength { + return false + } + return true } // MemoryCache is an implemtation of Cache that stores responses in an in-memory map. @@ -257,10 +257,10 @@ func NewTransport(c Cache) *Transport { return &Transport{Cache: c, MarkCachedResponses: true} } -// SetLogging has the same parameters as the log.New function and replaces the -// default logger that discards messages -func (t *Transport) SetLogging(out io.Writer, prefix string, flags int) { - logger = log.New(out, prefix, flags) +// SetLogger takes a *log.Logger and replaces the current one that discards all messages +// this method is not thread safe +func (t *Transport) SetLogger(l *log.Logger) { + logger = l } // Client returns an *http.Client that caches responses. diff --git a/httpcache_test.go b/httpcache_test.go index df0fc70..c6f2d88 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -97,6 +97,7 @@ func (s *S) SetUpSuite(c *C) { w.Header().Set("Content-Type", "text/plain") start, end, err := findRanges(r, int64(len(testData))) if err == nil { + w.Header().Set("content-range", fmt.Sprintf("bytes %d-%d/%d", start, end, len(testData))) w.Write([]byte(testData)[start:end]) } else { w.Write([]byte(testData)) @@ -128,13 +129,13 @@ func (s *S) TearDownTest(c *C) { func (s *S) TestSuffixRangedQuery(c *C) { req, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) c.Assert(err, IsNil) - //req.Header.Add("Range", "bytes=10-") resp, err := s.client.Do(req) defer resp.Body.Close() c.Assert(err, IsNil) data, err := ioutil.ReadAll(resp.Body) c.Assert(len(data), Equals, 52) c.Assert(string(data), Equals, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") + c.Assert(resp.Header.Get(XFromCache), Equals, "") req.Header.Add("Range", "bytes=10-") resp2, err := s.client.Do(req) @@ -144,6 +145,7 @@ func (s *S) TestSuffixRangedQuery(c *C) { data2, err := ioutil.ReadAll(resp2.Body) c.Assert(len(data2), Equals, 42) c.Assert(string(data2), Equals, "KLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") + c.Assert(resp2.Header.Get("content-range"), Equals, "bytes 10-52/52") } func (s *S) TestPrefixRangedQuery(c *C) { @@ -153,8 +155,11 @@ func (s *S) TestPrefixRangedQuery(c *C) { defer resp.Body.Close() c.Assert(err, IsNil) data, err := ioutil.ReadAll(resp.Body) + c.Assert(err, IsNil) c.Assert(len(data), Equals, 52) c.Assert(string(data), Equals, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") + c.Assert(resp.Header.Get("content-range"), Equals, "") + c.Assert(resp.Header.Get(XFromCache), Equals, "") req.Header.Add("Range", "bytes=-10") resp2, err := s.client.Do(req) @@ -162,8 +167,10 @@ func (s *S) TestPrefixRangedQuery(c *C) { c.Assert(resp2.Header.Get(XFromCache), Equals, "1") c.Assert(err, IsNil) data2, err := ioutil.ReadAll(resp2.Body) + c.Assert(err, IsNil) c.Assert(len(data2), Equals, 10) c.Assert(string(data2), Equals, "qrstuvwxyz") + c.Assert(resp2.Header.Get("content-range"), Equals, "bytes 42-52/52") } func (s *S) TestCompleteRangedQuery(c *C) { @@ -174,7 +181,9 @@ func (s *S) TestCompleteRangedQuery(c *C) { defer resp.Body.Close() c.Assert(err, IsNil) data, err := ioutil.ReadAll(resp.Body) + c.Assert(err, IsNil) c.Assert(len(data), Equals, 10) + c.Assert(string(data), Equals, "ABCDEFGHIJ") resp2, err := s.client.Do(req) defer resp2.Body.Close() c.Assert(resp2.Header.Get(XFromCache), Equals, "1") @@ -182,6 +191,7 @@ func (s *S) TestCompleteRangedQuery(c *C) { data2, err := ioutil.ReadAll(resp2.Body) c.Assert(len(data2), Equals, 10) c.Assert(string(data2), Equals, "ABCDEFGHIJ") + c.Assert(resp2.Header.Get("content-range"), Equals, "bytes 0-10/10") } func (s *S) TestPartialSubrangeRangedQuery(c *C) { @@ -189,21 +199,25 @@ func (s *S) TestPartialSubrangeRangedQuery(c *C) { c.Assert(err, IsNil) req.Header.Add("Range", "bytes=0-10") resp, err := s.client.Do(req) - defer resp.Body.Close() c.Assert(err, IsNil) + defer resp.Body.Close() data, err := ioutil.ReadAll(resp.Body) c.Assert(len(data), Equals, 10) + c.Assert(string(data), Equals, "ABCDEFGHIJ") + c.Assert(resp.Header.Get("content-range"), Equals, "bytes 0-10/52") req2, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) c.Assert(err, IsNil) req2.Header.Add("Range", "bytes=4-6") resp2, err := s.client.Do(req2) + c.Assert(err, IsNil) defer resp2.Body.Close() c.Assert(resp2.Header.Get(XFromCache), Equals, "1") - c.Assert(err, IsNil) data2, err := ioutil.ReadAll(resp2.Body) + c.Assert(err, IsNil) c.Assert(len(data2), Equals, 2) c.Assert(string(data2), Equals, "EF") + c.Assert(resp2.Header.Get("content-range"), Equals, "bytes 4-6/10") // test failing subrange outside previously held one req3, err := http.NewRequest("GET", s.server.URL+"/ranged", nil) @@ -214,8 +228,10 @@ func (s *S) TestPartialSubrangeRangedQuery(c *C) { c.Assert(resp3.Header.Get(XFromCache), Equals, "") c.Assert(err, IsNil) data3, err := ioutil.ReadAll(resp3.Body) + c.Assert(err, IsNil) c.Assert(len(data3), Equals, 7) c.Assert(string(data3), Equals, "IJKLMNO") + c.Assert(resp3.Header.Get("content-range"), Equals, "bytes 8-15/52") } func (s *S) TestMultipleSubrangeRangedQuery(c *C) { @@ -223,21 +239,25 @@ func (s *S) TestMultipleSubrangeRangedQuery(c *C) { c.Assert(err, IsNil) req.Header.Add("Range", "bytes=0-10,15-40") resp, err := s.client.Do(req) - defer resp.Body.Close() c.Assert(err, IsNil) + defer resp.Body.Close() data, err := ioutil.ReadAll(resp.Body) - c.Assert(len(data), Equals, 10) + c.Assert(err, IsNil) + c.Assert(len(data), Equals, 52) + c.Assert(string(data), Equals, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") } func (s *S) TestGetOnlyIfCachedHit(c *C) { req, err := http.NewRequest("GET", s.server.URL, nil) c.Assert(err, IsNil) resp, err := s.client.Do(req) + c.Assert(err, IsNil) defer resp.Body.Close() c.Assert(resp.Header.Get(XFromCache), Equals, "") req2, err2 := http.NewRequest("GET", s.server.URL, nil) req2.Header.Add("cache-control", "only-if-cached") + c.Assert(err2, IsNil) resp2, err2 := s.client.Do(req) defer resp2.Body.Close() c.Assert(err2, IsNil) @@ -247,16 +267,18 @@ func (s *S) TestGetOnlyIfCachedHit(c *C) { func (s *S) TestGetOnlyIfCachedMiss(c *C) { req, err := http.NewRequest("GET", s.server.URL, nil) + c.Assert(err, IsNil) req.Header.Add("cache-control", "only-if-cached") resp, err := s.client.Do(req) - defer resp.Body.Close() c.Assert(err, IsNil) + defer resp.Body.Close() c.Assert(resp.Header.Get(XFromCache), Equals, "") c.Assert(resp.StatusCode, Equals, 504) } func (s *S) TestGetNoStoreRequest(c *C) { req, err := http.NewRequest("GET", s.server.URL, nil) + c.Assert(err, IsNil) req.Header.Add("Cache-Control", "no-store") resp, err := s.client.Do(req) defer resp.Body.Close() @@ -271,6 +293,7 @@ func (s *S) TestGetNoStoreRequest(c *C) { func (s *S) TestGetNoStoreResponse(c *C) { req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) + c.Assert(err, IsNil) resp, err := s.client.Do(req) defer resp.Body.Close() c.Assert(err, IsNil) @@ -284,6 +307,7 @@ func (s *S) TestGetNoStoreResponse(c *C) { func (s *S) TestGetWithEtag(c *C) { req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) + c.Assert(err, IsNil) resp, err := s.client.Do(req) defer resp.Body.Close() c.Assert(err, IsNil) @@ -302,6 +326,7 @@ func (s *S) TestGetWithEtag(c *C) { func (s *S) TestGetWithLastModified(c *C) { req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) + c.Assert(err, IsNil) resp, err := s.client.Do(req) defer resp.Body.Close() c.Assert(err, IsNil) @@ -315,6 +340,7 @@ func (s *S) TestGetWithLastModified(c *C) { func (s *S) TestGetWithVary(c *C) { req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) + c.Assert(err, IsNil) req.Header.Set("Accept", "text/plain") resp, err := s.client.Do(req) defer resp.Body.Close() @@ -341,6 +367,7 @@ func (s *S) TestGetWithVary(c *C) { func (s *S) TestGetWithDoubleVary(c *C) { req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) + c.Assert(err, IsNil) req.Header.Set("Accept", "text/plain") req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") resp, err := s.client.Do(req) @@ -374,6 +401,7 @@ func (s *S) TestGetWith2VaryHeaders(c *C) { acceptLanguage = "da, en-gb;q=0.8, en;q=0.7" ) req, err := http.NewRequest("GET", s.server.URL+"/2varyheaders", nil) + c.Assert(err, IsNil) req.Header.Set("Accept", accept) req.Header.Set("Accept-Language", acceptLanguage) resp, err := s.client.Do(req) @@ -419,6 +447,7 @@ func (s *S) TestGetWith2VaryHeaders(c *C) { func (s *S) TestGetVaryUnused(c *C) { req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) + c.Assert(err, IsNil) req.Header.Set("Accept", "text/plain") resp, err := s.client.Do(req) defer resp.Body.Close() @@ -433,6 +462,7 @@ func (s *S) TestGetVaryUnused(c *C) { func (s *S) TestUpdateFields(c *C) { req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) + c.Assert(err, IsNil) resp, err := s.client.Do(req) defer resp.Body.Close() c.Assert(err, IsNil) From 736f33acd56801bcd593e4068b431991295cf869 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Tue, 24 Feb 2015 01:26:26 +0100 Subject: [PATCH 18/20] Go vet of the package --- httpcache.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/httpcache.go b/httpcache.go index 3c1b538..3ff2434 100644 --- a/httpcache.go +++ b/httpcache.go @@ -76,7 +76,7 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) strContentLength := returnResponse.Header.Get("content-length") contentLength, err := strconv.ParseInt(strContentLength, 10, 64) if err != nil { - return nil, fmt.Errorf("response loaded from cache has null or malformed content-length: %s", contentLength) + return nil, fmt.Errorf("response loaded from cache has null or malformed content-length: %d", contentLength) } rangeRequestStart, rangeRequestEnd, err := findRanges(req, contentLength) if err != nil { @@ -88,7 +88,7 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) body, err := ioutil.ReadAll(returnResponse.Body) if err != nil { - logger.Print("error reading cached response body: %s", err.Error()) + logger.Printf("error reading cached response body: %s", err.Error()) return returnResponse, nil } returnResponse.Body.Close() @@ -196,7 +196,7 @@ func validateRanges(start, end int64, resp *http.Response) (ok bool) { } contentLength, err := strconv.ParseInt(resp.Header.Get("content-length"), 10, 64) if err != nil { - logger.Printf("stored response has malformed or invalid content length %s", contentLength) + logger.Printf("stored response has malformed or invalid content length %d", contentLength) return false } if end > contentLength { From 00bcec30929f7fc3a21bea5ea86e5167be8d4de4 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Tue, 24 Feb 2015 01:27:50 +0100 Subject: [PATCH 19/20] Golint the package --- httpcache.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/httpcache.go b/httpcache.go index 3ff2434..96c05fb 100644 --- a/httpcache.go +++ b/httpcache.go @@ -142,7 +142,7 @@ func findRanges(r *http.Request, totalLength int64) (start, end int64, err error } } if start >= end { - return -1, -1, fmt.Errorf("invalid start %d >= end %d!", start, end) + return -1, -1, fmt.Errorf("invalid start %d >= end %d", start, end) } return start, end, nil } @@ -526,7 +526,7 @@ func getEndToEndHeaders(respHeaders http.Header) []string { } } endToEndHeaders := []string{} - for respHeader, _ := range respHeaders { + for respHeader := range respHeaders { if _, ok := hopByHopHeaders[respHeader]; !ok { endToEndHeaders = append(endToEndHeaders, respHeader) } From d50d571f3c3d973a628a9bb01d8e2fd73b115f53 Mon Sep 17 00:00:00 2001 From: Andrea Lusuardi - uovobw Date: Wed, 18 Mar 2015 18:07:11 +0100 Subject: [PATCH 20/20] Add cacheproxy library Add cmd/cacheproxy that contains a standalone caching proxy --- cacheproxy/cmd/cacheproxy/main.go | 41 +++++++++++++++++++++++++++++++ cacheproxy/proxy.go | 32 ++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 cacheproxy/cmd/cacheproxy/main.go create mode 100644 cacheproxy/proxy.go diff --git a/cacheproxy/cmd/cacheproxy/main.go b/cacheproxy/cmd/cacheproxy/main.go new file mode 100644 index 0000000..4705023 --- /dev/null +++ b/cacheproxy/cmd/cacheproxy/main.go @@ -0,0 +1,41 @@ +package main + +import ( + "flag" + "log" + "net/http" + "net/url" + "os" + + "github.com/gorilla/handlers" + "github.com/uovobw/httpcache/cacheproxy" +) + +var ( + bindTo = flag.String("bind", "0.0.0.0:8080", "address to bind to") + debug = flag.Bool("debug", false, "enable debugging") + target = flag.String("target", "", "base url to cache") +) + +func init() { + flag.Parse() + flag.VisitAll(func(f *flag.Flag) { + log.Printf("%s=%v", f.Name, f.Value) + }) + if *target == "" { + log.Fatalln("you must specify a target url") + } + if *debug { + logger := log.New(os.Stdout, "", 0) + cacheproxy.SetLogger(logger) + } +} + +func main() { + URL, err := url.Parse(*target) + if err != nil { + log.Fatal(err) + } + proxy := cacheproxy.NewSingleHostReverseProxy(URL) + log.Fatal(http.ListenAndServe(*bindTo, handlers.CombinedLoggingHandler(os.Stdout, proxy))) +} diff --git a/cacheproxy/proxy.go b/cacheproxy/proxy.go new file mode 100644 index 0000000..7cc1234 --- /dev/null +++ b/cacheproxy/proxy.go @@ -0,0 +1,32 @@ +package cacheproxy + +import ( + "github.com/uovobw/httpcache" + "log" + "net/http" + "net/http/httputil" + "net/url" +) + +var ( + memoryCache = httpcache.NewMemoryCache() + transport = httpcache.NewTransport(memoryCache) +) + +// NewSingleHostReverseProxy wraps net/http/httputil.NewSingleHostReverseProxy +// and sets the Host header based on the target URL. +func NewSingleHostReverseProxy(url *url.URL) *httputil.ReverseProxy { + proxy := httputil.NewSingleHostReverseProxy(url) + oldDirector := proxy.Director + proxy.Director = func(r *http.Request) { + oldDirector(r) + r.Host = url.Host + } + proxy.Transport = transport + return proxy +} + +// SetLogger wraps httpcache.SetLogger +func SetLogger(l *log.Logger) { + transport.SetLogger(l) +}