diff --git a/filter/httpcache/configuration.go b/filter/httpcache/configuration.go index fb707226f..7f55107ce 100644 --- a/filter/httpcache/configuration.go +++ b/filter/httpcache/configuration.go @@ -1,7 +1,6 @@ package httpcache import ( - "encoding/json" "regexp" "time" @@ -10,7 +9,12 @@ import ( "github.com/go-kratos/kratos/v2/config" ) -const configuration_key = "httpcache" +const ( + configurationKey = "httpcache" + path = "path" + url = "url" + configurationPK = "configuration" +) func parseRecursively(values map[string]config.Value) map[string]interface{} { result := make(map[string]interface{}) @@ -19,15 +23,20 @@ func parseRecursively(values map[string]config.Value) map[string]interface{} { result[key] = v continue } - if v, e := value.Duration(); e == nil { - result[key] = v - continue - } - if v, e := value.Float(); e == nil { - result[key] = v - continue + switch value.Load().(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + if v, e := value.Int(); e == nil { + result[key] = v + continue + } + case float32, float64: + if v, e := value.Float(); e == nil { + result[key] = v + continue + } } - if v, e := value.Int(); e == nil { + svalue, _ := value.String() + if v, e := time.ParseDuration(svalue); e == nil { result[key] = v continue } @@ -40,264 +49,289 @@ func parseRecursively(values map[string]config.Value) map[string]interface{} { return result } -// ParseConfiguration parse the Kratos configuration into a valid HTTP -// cache configuration object. -func ParseConfiguration(c config.Config) plugins.BaseConfiguration { - var configuration plugins.BaseConfiguration +func parseAPI(apiConfiguration map[string]config.Value) configurationtypes.API { + var a configurationtypes.API + var prometheusConfiguration, souinConfiguration map[string]config.Value - values, _ := c.Value(configuration_key).Map() - for key, v := range values { - switch key { - case "api": - var a configurationtypes.API - var prometheusConfiguration, souinConfiguration, securityConfiguration map[string]config.Value - apiConfiguration, _ := v.Map() - for apiK, apiV := range apiConfiguration { - switch apiK { - case "prometheus": - prometheusConfiguration, _ = apiV.Map() - case "souin": - souinConfiguration, _ = apiV.Map() - case "security": - securityConfiguration, _ = apiV.Map() - } + for apiK, apiV := range apiConfiguration { + switch apiK { + case "prometheus": + prometheusConfiguration, _ = apiV.Map() + case "souin": + souinConfiguration, _ = apiV.Map() + } + } + if prometheusConfiguration != nil { + a.Prometheus = configurationtypes.APIEndpoint{} + a.Prometheus.Enable = true + if prometheusConfiguration["basepath"] != nil { + a.Prometheus.BasePath, _ = prometheusConfiguration["basepath"].String() + } + } + if souinConfiguration != nil { + a.Souin = configurationtypes.APIEndpoint{} + a.Souin.Enable = true + if souinConfiguration["basepath"] != nil { + a.Souin.BasePath, _ = souinConfiguration["basepath"].String() + } + } + return a +} + +func parseCacheKeys(ccConfiguration map[string]config.Value) map[configurationtypes.RegValue]configurationtypes.Key { + cacheKeys := make(map[configurationtypes.RegValue]configurationtypes.Key) + for cacheKeysConfigurationK, cacheKeysConfigurationV := range ccConfiguration { + ck := configurationtypes.Key{} + cacheKeysConfigurationVMap, _ := cacheKeysConfigurationV.Map() + for cacheKeysConfigurationVMapK := range cacheKeysConfigurationVMap { + switch cacheKeysConfigurationVMapK { + case "disable_body": + ck.DisableBody = true + case "disable_host": + ck.DisableHost = true + case "disable_method": + ck.DisableMethod = true } - if prometheusConfiguration != nil { - a.Prometheus = configurationtypes.APIEndpoint{} - a.Prometheus.Enable = true - if prometheusConfiguration["basepath"] != nil { - a.Prometheus.BasePath, _ = prometheusConfiguration["basepath"].String() - } - if prometheusConfiguration["security"] != nil { - a.Prometheus.Security, _ = prometheusConfiguration["security"].Bool() - } + } + rg := regexp.MustCompile(cacheKeysConfigurationK) + cacheKeys[configurationtypes.RegValue{Regexp: rg}] = ck + } + + return cacheKeys +} + +func parseDefaultCache(dcConfiguration map[string]config.Value) *configurationtypes.DefaultCache { + dc := configurationtypes.DefaultCache{ + Distributed: false, + Headers: []string{}, + Olric: configurationtypes.CacheProvider{ + URL: "", + Path: "", + Configuration: nil, + }, + Regex: configurationtypes.Regex{}, + TTL: configurationtypes.Duration{}, + DefaultCacheControl: "", + } + for defaultCacheK, defaultCacheV := range dcConfiguration { + switch defaultCacheK { + case "allowed_http_verbs": + headers, _ := defaultCacheV.Slice() + dc.AllowedHTTPVerbs = make([]string, 0) + for _, header := range headers { + h, _ := header.String() + dc.AllowedHTTPVerbs = append(dc.AllowedHTTPVerbs, h) } - if souinConfiguration != nil { - a.Souin = configurationtypes.APIEndpoint{} - a.Souin.Enable = true - if souinConfiguration["basepath"] != nil { - a.Souin.BasePath, _ = souinConfiguration["basepath"].String() + case "badger": + provider := configurationtypes.CacheProvider{} + badgerConfiguration, _ := defaultCacheV.Map() + for badgerConfigurationK, badgerConfigurationV := range badgerConfiguration { + switch badgerConfigurationK { + case url: + provider.URL, _ = badgerConfigurationV.String() + case path: + provider.Path, _ = badgerConfigurationV.String() + case configurationPK: + configMap, e := badgerConfigurationV.Map() + if e == nil { + provider.Configuration = parseRecursively(configMap) + } } - if souinConfiguration["security"] != nil { - a.Souin.Security, _ = souinConfiguration["security"].Bool() + } + dc.Badger = provider + case "cdn": + cdn := configurationtypes.CDN{} + cdnConfiguration, _ := defaultCacheV.Map() + for cdnConfigurationK, cdnConfigurationV := range cdnConfiguration { + switch cdnConfigurationK { + case "api_key": + cdn.APIKey, _ = cdnConfigurationV.String() + case "dynamic": + cdn.Dynamic = true + case "hostname": + cdn.Hostname, _ = cdnConfigurationV.String() + case "network": + cdn.Network, _ = cdnConfigurationV.String() + case "provider": + cdn.Provider, _ = cdnConfigurationV.String() + case "strategy": + cdn.Strategy, _ = cdnConfigurationV.String() } } - if securityConfiguration != nil { - a.Security = configurationtypes.SecurityAPI{} - a.Security.Enable = true - if securityConfiguration["basepath"] != nil { - a.Security.BasePath, _ = securityConfiguration["basepath"].String() + dc.CDN = cdn + case "etcd": + provider := configurationtypes.CacheProvider{} + etcdConfiguration, _ := defaultCacheV.Map() + for etcdConfigurationK, etcdConfigurationV := range etcdConfiguration { + switch etcdConfigurationK { + case url: + provider.URL, _ = etcdConfigurationV.String() + case path: + provider.Path, _ = etcdConfigurationV.String() + case configurationPK: + configMap, e := etcdConfigurationV.Map() + if e == nil { + provider.Configuration = parseRecursively(configMap) + } } - if securityConfiguration["users"] != nil { - users, _ := securityConfiguration["users"].Slice() - a.Security.Users = make([]configurationtypes.User, 0) - for _, user := range users { - currentUser, _ := user.Map() - username, _ := currentUser["username"].String() - password, _ := currentUser["password"].String() - a.Security.Users = append(a.Security.Users, configurationtypes.User{ - Username: username, - Password: password, - }) + } + dc.Etcd = provider + case "headers": + headers, _ := defaultCacheV.Slice() + dc.Headers = make([]string, 0) + for _, header := range headers { + h, _ := header.String() + dc.Headers = append(dc.Headers, h) + } + case "nuts": + provider := configurationtypes.CacheProvider{} + nutsConfiguration, _ := defaultCacheV.Map() + for nutsConfigurationK, nutsConfigurationV := range nutsConfiguration { + switch nutsConfigurationK { + case url: + provider.URL, _ = nutsConfigurationV.String() + case path: + provider.Path, _ = nutsConfigurationV.String() + case configurationPK: + configMap, e := nutsConfigurationV.Map() + if e == nil { + provider.Configuration = parseRecursively(configMap) } } } - configuration.API = a - case "cache_keys": - cacheKeys := make(map[configurationtypes.RegValue]configurationtypes.Key) - cacheKeysConfiguration, _ := v.Map() - for cacheKeysConfigurationK, cacheKeysConfigurationV := range cacheKeysConfiguration { - ck := configurationtypes.Key{} - cacheKeysConfigurationVMap, _ := cacheKeysConfigurationV.Map() - for cacheKeysConfigurationVMapK := range cacheKeysConfigurationVMap { - switch cacheKeysConfigurationVMapK { - case "disable_body": - ck.DisableBody = true - case "disable_host": - ck.DisableHost = true - case "disable_method": - ck.DisableMethod = true + dc.Nuts = provider + case "olric": + provider := configurationtypes.CacheProvider{} + olricConfiguration, _ := defaultCacheV.Map() + for olricConfigurationK, olricConfigurationV := range olricConfiguration { + switch olricConfigurationK { + case url: + provider.URL, _ = olricConfigurationV.String() + case path: + provider.Path, _ = olricConfigurationV.String() + case configurationPK: + configMap, e := olricConfigurationV.Map() + if e == nil { + provider.Configuration = parseRecursively(configMap) } } - rg := regexp.MustCompile(cacheKeysConfigurationK) - cacheKeys[configurationtypes.RegValue{Regexp: rg}] = ck } - configuration.CacheKeys = cacheKeys - case "default_cache": - dc := configurationtypes.DefaultCache{ - Distributed: false, - Headers: []string{}, - Olric: configurationtypes.CacheProvider{ - URL: "", - Path: "", - Configuration: nil, - }, - Regex: configurationtypes.Regex{}, - TTL: configurationtypes.Duration{}, - DefaultCacheControl: "", + dc.Distributed = true + dc.Olric = provider + case "regex": + regex, _ := defaultCacheV.Map() + exclude, _ := regex["exclude"].String() + if exclude != "" { + dc.Regex = configurationtypes.Regex{Exclude: exclude} } - defaultCache, _ := v.Map() - for defaultCacheK, defaultCacheV := range defaultCache { - switch defaultCacheK { - case "badger": - provider := configurationtypes.CacheProvider{} - badgerConfiguration, _ := v.Map() - for badgerConfigurationK, badgerConfigurationV := range badgerConfiguration { - switch badgerConfigurationK { - case "url": - provider.URL, _ = badgerConfigurationV.String() - case "path": - provider.Path, _ = badgerConfigurationV.String() - case "configuration": - configMap, e := badgerConfigurationV.Map() - if e == nil { - provider.Configuration = parseRecursively(configMap) - } - } - } - configuration.DefaultCache.Badger = provider - case "cdn": - cdn := configurationtypes.CDN{} - cdnConfiguration, _ := v.Map() - for cdnConfigurationK, cdnConfigurationV := range cdnConfiguration { - switch cdnConfigurationK { - case "api_key": - cdn.APIKey, _ = cdnConfigurationV.String() - case "dynamic": - cdn.Dynamic = true - case "hostname": - cdn.Hostname, _ = cdnConfigurationV.String() - case "network": - cdn.Network, _ = cdnConfigurationV.String() - case "provider": - cdn.Provider, _ = cdnConfigurationV.String() - case "strategy": - cdn.Strategy, _ = cdnConfigurationV.String() - } - } - configuration.DefaultCache.CDN = cdn - case "etcd": - provider := configurationtypes.CacheProvider{} - etcdConfiguration, _ := v.Map() - for etcdConfigurationK, etcdConfigurationV := range etcdConfiguration { - switch etcdConfigurationK { - case "url": - provider.URL, _ = etcdConfigurationV.String() - case "path": - provider.Path, _ = etcdConfigurationV.String() - case "configuration": - configMap, e := etcdConfigurationV.Map() - if e == nil { - provider.Configuration = parseRecursively(configMap) - } - } - } - configuration.DefaultCache.Etcd = provider - case "headers": - headers, _ := defaultCacheV.Slice() - dc.Headers = make([]string, 0) - for _, header := range headers { - h, _ := header.String() - dc.Headers = append(dc.Headers, h) - } - case "nuts": - provider := configurationtypes.CacheProvider{} - nutsConfiguration, _ := v.Map() - for nutsConfigurationK, nutsConfigurationV := range nutsConfiguration { - switch nutsConfigurationK { - case "url": - provider.URL, _ = nutsConfigurationV.String() - case "path": - provider.Path, _ = nutsConfigurationV.String() - case "configuration": - configMap, e := nutsConfigurationV.Map() - if e == nil { - provider.Configuration = parseRecursively(configMap) - } - } - } - configuration.DefaultCache.Nuts = provider - case "olric": - provider := configurationtypes.CacheProvider{} - olricConfiguration, _ := v.Map() - for olricConfigurationK, olricConfigurationV := range olricConfiguration { - switch olricConfigurationK { - case "url": - provider.URL, _ = olricConfigurationV.String() - case "path": - provider.Path, _ = olricConfigurationV.String() - case "configuration": - configMap, e := olricConfigurationV.Map() - if e == nil { - provider.Configuration = parseRecursively(configMap) - } - } - } - configuration.DefaultCache.Distributed = true - configuration.DefaultCache.Olric = provider - case "regex": - regex, _ := defaultCacheV.Map() - exclude, _ := regex["exclude"].String() - if exclude != "" { - dc.Regex = configurationtypes.Regex{Exclude: exclude} - } - case "ttl": - sttl, err := defaultCacheV.String() - ttl, _ := time.ParseDuration(sttl) - if err == nil { - dc.TTL = configurationtypes.Duration{Duration: ttl} - } - case "stale": - sstale, err := defaultCacheV.String() - stale, _ := time.ParseDuration(sstale) - if err == nil { - dc.Stale = configurationtypes.Duration{Duration: stale} + case "ttl": + sttl, err := defaultCacheV.String() + ttl, _ := time.ParseDuration(sttl) + if err == nil { + dc.TTL = configurationtypes.Duration{Duration: ttl} + } + case "stale": + sstale, err := defaultCacheV.String() + stale, _ := time.ParseDuration(sstale) + if err == nil { + dc.Stale = configurationtypes.Duration{Duration: stale} + } + case "default_cache_control": + dc.DefaultCacheControl, _ = defaultCacheV.String() + } + } + + return &dc +} + +func parseURLs(urls map[string]config.Value) map[string]configurationtypes.URL { + u := make(map[string]configurationtypes.URL) + + for urlK, urlV := range urls { + currentURL := configurationtypes.URL{ + TTL: configurationtypes.Duration{}, + Headers: nil, + } + currentValue, _ := urlV.Map() + if currentValue["headers"] != nil { + currentURL.Headers = make([]string, 0) + headers, _ := currentValue["headers"].Slice() + for _, header := range headers { + h, _ := header.String() + currentURL.Headers = append(currentURL.Headers, h) + } + } + sttl, err := currentValue["ttl"].String() + ttl, _ := time.ParseDuration(sttl) + if err == nil { + currentURL.TTL = configurationtypes.Duration{Duration: ttl} + } + if _, exists := currentValue["default_cache_control"]; exists { + currentURL.DefaultCacheControl, _ = currentValue["default_cache_control"].String() + } + u[urlK] = currentURL + } + + return u +} + +func parseSurrogateKeys(surrogates map[string]config.Value) map[string]configurationtypes.SurrogateKeys { + u := make(map[string]configurationtypes.SurrogateKeys) + + for surrogateK, surrogateV := range surrogates { + surrogate := configurationtypes.SurrogateKeys{} + currentValue, _ := surrogateV.Map() + for key, value := range currentValue { + switch key { + case "headers": + surrogate.Headers = map[string]string{} + headers, e := value.Map() + if e == nil { + for hKey, hValue := range headers { + v, _ := hValue.String() + surrogate.Headers[hKey] = v } - case "default_cache_control": - dc.DefaultCacheControl, _ = defaultCacheV.String() } + case "url": + surl, _ := currentValue["url"].String() + surrogate.URL = surl } - configuration.DefaultCache = &dc + } + u[surrogateK] = surrogate + } + + return u +} + +// ParseConfiguration parse the Kratos configuration into a valid HTTP +// cache configuration object. +func ParseConfiguration(c config.Config) plugins.BaseConfiguration { + var configuration plugins.BaseConfiguration + + values, _ := c.Value(configurationKey).Map() + for key, v := range values { + switch key { + case "api": + apiConfiguration, _ := v.Map() + configuration.API = parseAPI(apiConfiguration) + case "cache_keys": + cacheKeysConfiguration, _ := v.Map() + configuration.CacheKeys = parseCacheKeys(cacheKeysConfiguration) + case "default_cache": + defaultCache, _ := v.Map() + configuration.DefaultCache = parseDefaultCache(defaultCache) case "log_level": configuration.LogLevel, _ = v.String() case "urls": - u := make(map[string]configurationtypes.URL) urls, _ := v.Map() - - for urlK, urlV := range urls { - currentURL := configurationtypes.URL{ - TTL: configurationtypes.Duration{}, - Headers: nil, - } - currentValue, _ := urlV.Map() - currentURL.Headers = make([]string, 0) - headers, _ := currentValue["headers"].Slice() - for _, header := range headers { - h, _ := header.String() - currentURL.Headers = append(currentURL.Headers, h) - } - sttl, err := currentValue["ttl"].String() - ttl, _ := time.ParseDuration(sttl) - if err == nil { - currentURL.TTL = configurationtypes.Duration{Duration: ttl} - } - if _, exists := currentValue["default_cache_control"]; exists { - currentURL.DefaultCacheControl, _ = currentValue["default_cache_control"].String() - } - u[urlK] = currentURL - } - configuration.URLs = u + configuration.URLs = parseURLs(urls) case "ykeys": - ykeys := make(map[string]configurationtypes.SurrogateKeys) - d, _ := json.Marshal(v) - _ = json.Unmarshal(d, &ykeys) - configuration.Ykeys = ykeys + ykeys, _ := v.Map() + configuration.Ykeys = parseSurrogateKeys(ykeys) case "surrogate_keys": - ykeys := make(map[string]configurationtypes.SurrogateKeys) - d, _ := json.Marshal(v) - _ = json.Unmarshal(d, &ykeys) - configuration.Ykeys = ykeys + surrogates, _ := v.Map() + configuration.SurrogateKeys = parseSurrogateKeys(surrogates) } } diff --git a/filter/httpcache/configuration_test.go b/filter/httpcache/configuration_test.go new file mode 100644 index 000000000..2be5f1611 --- /dev/null +++ b/filter/httpcache/configuration_test.go @@ -0,0 +1,113 @@ +package httpcache + +import ( + "io/ioutil" + "testing" + "time" + + "github.com/go-kratos/kratos/v2/config" + "github.com/go-kratos/kratos/v2/config/file" + "gopkg.in/yaml.v2" +) + +var dummyConfig = []byte(` +httpcache: + api: + basepath: /souin-api + prometheus: + basepath: /anything-for-prometheus-metrics + souin: + basepath: /anything-for-souin + cache_keys: + .+: + disable_body: true + disable_host: true + disable_method: true + default_cache: + allowed_http_verbs: + - GET + - POST + - HEAD + badger: + url: /badger/url + path: /badger/path + configuration: + SyncEnable: false + cdn: + api_key: XXXX + dynamic: true + hostname: XXXX + network: XXXX + provider: fastly + strategy: soft + etcd: + url: /etcd/url + path: /etcd/path + configuration: + SyncEnable: false + headers: + - Authorization + nuts: + url: /etcd/url + path: /etcd/path + configuration: + SyncEnable: false + ValueFloat: 1.123 + ValueDuration: 1s + ValueInt: 2 + ValueObject: + ValueFloat: 1.123 + olric: + url: 'olric:3320' + path: /olric/path + configuration: + SyncEnable: false + regex: + exclude: 'ARegexHere' + ttl: 10s + stale: 10s + default_cache_control: no-store + log_level: debug + urls: + 'https:\/\/domain.com\/first-.+': + ttl: 1000s + 'https:\/\/domain.com\/second-route': + ttl: 10s + headers: + - Authorization + 'https?:\/\/mysubdomain\.domain\.com': + ttl: 50s + default_cache_control: no-cache + headers: + - Authorization + - 'Content-Type' + ykeys: + The_First_Test: + headers: + Content-Type: '.+' + The_Second_Test: + url: 'the/second/.+' + surrogate_keys: + The_First_Test: + headers: + Content-Type: '.+' + The_Second_Test: + url: 'the/second/.+' +`) + +func Test_ParseConfiguration(t *testing.T) { + filename := "/tmp/httpcache-" + time.Now().String() + ".yml" + ioutil.WriteFile(filename, dummyConfig, 0777) + c := config.New( + config.WithSource(file.NewSource(filename)), + config.WithDecoder(func(kv *config.KeyValue, v map[string]interface{}) error { + return yaml.Unmarshal(kv.Value, v) + }), + ) + + if err := c.Load(); err != nil { + panic(err) + } + + _ = ParseConfiguration(c) +} diff --git a/filter/httpcache/httpcache.go b/filter/httpcache/httpcache.go index 84d62d777..394993219 100644 --- a/filter/httpcache/httpcache.go +++ b/filter/httpcache/httpcache.go @@ -91,6 +91,7 @@ func (s *httpcacheKratosPlugin) handle(next http.Handler) http.Handler { combo.next.ServeHTTP(customWriter, r) combo.req.Response = customWriter.Response + // nolint if combo.req.Response, e = s.Retriever.GetTransport().(*rfc.VaryTransport).UpdateCacheEventually(combo.req); e != nil { return e } diff --git a/filter/httpcache/httpcache_test.go b/filter/httpcache/httpcache_test.go index 4e780495b..b777cb1d3 100644 --- a/filter/httpcache/httpcache_test.go +++ b/filter/httpcache/httpcache_test.go @@ -13,35 +13,33 @@ import ( "github.com/darkweak/souin/plugins" ) -var ( - devDefaultConfiguration = plugins.BaseConfiguration{ - API: configurationtypes.API{ - BasePath: "/httpcache_api", - Prometheus: configurationtypes.APIEndpoint{ - Enable: true, - }, - Souin: configurationtypes.APIEndpoint{ - BasePath: "/httpcache", - Enable: true, - }, +var devDefaultConfiguration = plugins.BaseConfiguration{ + API: configurationtypes.API{ + BasePath: "/httpcache_api", + Prometheus: configurationtypes.APIEndpoint{ + Enable: true, }, - DefaultCache: &configurationtypes.DefaultCache{ - Regex: configurationtypes.Regex{ - Exclude: "/excluded", - }, - TTL: configurationtypes.Duration{ - Duration: time.Second, - }, + Souin: configurationtypes.APIEndpoint{ + BasePath: "/httpcache", + Enable: true, }, - LogLevel: "debug", - } -) + }, + DefaultCache: &configurationtypes.DefaultCache{ + Regex: configurationtypes.Regex{ + Exclude: "/excluded", + }, + TTL: configurationtypes.Duration{ + Duration: time.Second, + }, + }, + LogLevel: "debug", +} type next struct{} func (n *next) ServeHTTP(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusOK) - rw.Write([]byte("Hello Kratos!")) + _, _ = rw.Write([]byte("Hello Kratos!")) } var nextFilter = &next{} @@ -65,14 +63,18 @@ func Test_HttpcacheKratosPlugin_NewHTTPCacheFilter(t *testing.T) { handler := NewHTTPCacheFilter(devDefaultConfiguration)(nextFilter) req, res, res2 := prepare("/handled") handler.ServeHTTP(res, req) - if res.Result().Header.Get("Cache-Status") != "Souin; fwd=uri-miss; stored" { + rs := res.Result() + rs.Body.Close() + if rs.Header.Get("Cache-Status") != "Souin; fwd=uri-miss; stored" { t.Error("The response must contain a Cache-Status header with the stored directive.") } handler.ServeHTTP(res2, req) - if res2.Result().Header.Get("Cache-Status") != "Souin; hit; ttl=0" { + rs = res2.Result() + rs.Body.Close() + if rs.Header.Get("Cache-Status") != "Souin; hit; ttl=0" { t.Error("The response must contain a Cache-Status header with the hit and ttl directives.") } - if res2.Result().Header.Get("Age") != "1" { + if rs.Header.Get("Age") != "1" { t.Error("The response must contain a Age header with the value 1.") } } @@ -82,14 +84,18 @@ func Test_HttpcacheKratosPlugin_NewHTTPCacheFilter_Excluded(t *testing.T) { handler := NewHTTPCacheFilter(devDefaultConfiguration)(nextFilter) req, res, res2 := prepare("/excluded") handler.ServeHTTP(res, req) - if res.Result().Header.Get("Cache-Status") != "Souin; fwd=uri-miss" { + rs := res.Result() + rs.Body.Close() + if rs.Header.Get("Cache-Status") != "Souin; fwd=uri-miss" { t.Error("The response must contain a Cache-Status header without the stored directive and with the uri-miss only.") } handler.ServeHTTP(res2, req) - if res2.Result().Header.Get("Cache-Status") != "Souin; fwd=uri-miss" { + rs = res2.Result() + rs.Body.Close() + if rs.Header.Get("Cache-Status") != "Souin; fwd=uri-miss" { t.Error("The response must contain a Cache-Status header without the stored directive and with the uri-miss only.") } - if res2.Result().Header.Get("Age") != "" { + if rs.Header.Get("Age") != "" { t.Error("The response must not contain a Age header.") } } @@ -102,15 +108,19 @@ func Test_HttpcacheKratosPlugin_NewHTTPCacheFilter_Mutation(t *testing.T) { req, res, res2 := prepare("/handled") req.Body = ioutil.NopCloser(bytes.NewBuffer([]byte(`{"query":"mutation":{something mutated}}`))) handler.ServeHTTP(res, req) - if res.Result().Header.Get("Cache-Status") != "Souin; fwd=uri-miss" { + rs := res.Result() + rs.Body.Close() + if rs.Header.Get("Cache-Status") != "Souin; fwd=uri-miss" { t.Error("The response must contain a Cache-Status header without the stored directive and with the uri-miss only.") } req.Body = ioutil.NopCloser(bytes.NewBuffer([]byte(`{"query":"mutation":{something mutated}}`))) handler.ServeHTTP(res2, req) - if res2.Result().Header.Get("Cache-Status") != "Souin; fwd=uri-miss" { + rs = res2.Result() + rs.Body.Close() + if rs.Header.Get("Cache-Status") != "Souin; fwd=uri-miss" { t.Error("The response must contain a Cache-Status header without the stored directive and with the uri-miss only.") } - if res2.Result().Header.Get("Age") != "" { + if rs.Header.Get("Age") != "" { t.Error("The response must not contain a Age header.") } } @@ -120,26 +130,32 @@ func Test_HttpcacheKratosPlugin_NewHTTPCacheFilter_API(t *testing.T) { handler := NewHTTPCacheFilter(devDefaultConfiguration)(nextFilter) req, res, res2 := prepare("/httpcache_api/httpcache") handler.ServeHTTP(res, req) - if res.Result().Header.Get("Content-Type") != "application/json" { + rs := res.Result() + defer rs.Body.Close() + if rs.Header.Get("Content-Type") != "application/json" { t.Error("The response must contain be in JSON.") } - b, _ := ioutil.ReadAll(res.Result().Body) + b, _ := ioutil.ReadAll(rs.Body) res.Result().Body.Close() if string(b) != "[]" { t.Error("The response body must be an empty array because no request has been stored") } req2 := httptest.NewRequest(http.MethodGet, "/handled", nil) handler.ServeHTTP(res2, req2) - if res2.Result().Header.Get("Cache-Status") != "Souin; fwd=uri-miss; stored" { + rs = res2.Result() + rs.Body.Close() + if rs.Header.Get("Cache-Status") != "Souin; fwd=uri-miss; stored" { t.Error("The response must contain a Cache-Status header with the stored directive.") } res3 := httptest.NewRecorder() handler.ServeHTTP(res3, req) - if res3.Result().Header.Get("Content-Type") != "application/json" { + rs = res3.Result() + rs.Body.Close() + if rs.Header.Get("Content-Type") != "application/json" { t.Error("The response must contain be in JSON.") } - b, _ = ioutil.ReadAll(res3.Result().Body) - res3.Result().Body.Close() + b, _ = ioutil.ReadAll(rs.Body) + rs.Body.Close() var payload []string _ = json.Unmarshal(b, &payload) if len(payload) != 2 {