Increase coverage

pull/2159/head
darkweak 2 years ago
parent 778c8677b8
commit 3f8b5bfc5c
No known key found for this signature in database
GPG Key ID: 57052F3222742CA6
  1. 210
      filter/httpcache/configuration.go
  2. 113
      filter/httpcache/configuration_test.go
  3. 1
      filter/httpcache/httpcache.go
  4. 54
      filter/httpcache/httpcache_test.go

@ -1,7 +1,6 @@
package httpcache package httpcache
import ( import (
"encoding/json"
"regexp" "regexp"
"time" "time"
@ -10,7 +9,12 @@ import (
"github.com/go-kratos/kratos/v2/config" "github.com/go-kratos/kratos/v2/config"
) )
const configuration_key = "httpcache" const (
configurationKey = "httpcache"
path = "path"
url = "url"
configurationPK = "configuration"
)
func parseRecursively(values map[string]config.Value) map[string]interface{} { func parseRecursively(values map[string]config.Value) map[string]interface{} {
result := make(map[string]interface{}) result := make(map[string]interface{})
@ -19,15 +23,20 @@ func parseRecursively(values map[string]config.Value) map[string]interface{} {
result[key] = v result[key] = v
continue continue
} }
if v, e := value.Duration(); e == nil { switch value.Load().(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
if v, e := value.Int(); e == nil {
result[key] = v result[key] = v
continue continue
} }
case float32, float64:
if v, e := value.Float(); e == nil { if v, e := value.Float(); e == nil {
result[key] = v result[key] = v
continue continue
} }
if v, e := value.Int(); e == nil { }
svalue, _ := value.String()
if v, e := time.ParseDuration(svalue); e == nil {
result[key] = v result[key] = v
continue continue
} }
@ -40,28 +49,17 @@ func parseRecursively(values map[string]config.Value) map[string]interface{} {
return result return result
} }
// ParseConfiguration parse the Kratos configuration into a valid HTTP func parseAPI(apiConfiguration map[string]config.Value) configurationtypes.API {
// cache configuration object.
func ParseConfiguration(c config.Config) plugins.BaseConfiguration {
var configuration plugins.BaseConfiguration
values, _ := c.Value(configuration_key).Map()
for key, v := range values {
switch key {
case "api":
var a configurationtypes.API var a configurationtypes.API
var prometheusConfiguration, souinConfiguration, securityConfiguration map[string]config.Value var prometheusConfiguration, souinConfiguration map[string]config.Value
apiConfiguration, _ := v.Map()
for apiK, apiV := range apiConfiguration { for apiK, apiV := range apiConfiguration {
switch apiK { switch apiK {
case "prometheus": case "prometheus":
prometheusConfiguration, _ = apiV.Map() prometheusConfiguration, _ = apiV.Map()
case "souin": case "souin":
souinConfiguration, _ = apiV.Map() souinConfiguration, _ = apiV.Map()
case "security":
securityConfiguration, _ = apiV.Map()
} }
} }
if prometheusConfiguration != nil { if prometheusConfiguration != nil {
a.Prometheus = configurationtypes.APIEndpoint{} a.Prometheus = configurationtypes.APIEndpoint{}
@ -69,9 +67,6 @@ func ParseConfiguration(c config.Config) plugins.BaseConfiguration {
if prometheusConfiguration["basepath"] != nil { if prometheusConfiguration["basepath"] != nil {
a.Prometheus.BasePath, _ = prometheusConfiguration["basepath"].String() a.Prometheus.BasePath, _ = prometheusConfiguration["basepath"].String()
} }
if prometheusConfiguration["security"] != nil {
a.Prometheus.Security, _ = prometheusConfiguration["security"].Bool()
}
} }
if souinConfiguration != nil { if souinConfiguration != nil {
a.Souin = configurationtypes.APIEndpoint{} a.Souin = configurationtypes.APIEndpoint{}
@ -79,35 +74,14 @@ func ParseConfiguration(c config.Config) plugins.BaseConfiguration {
if souinConfiguration["basepath"] != nil { if souinConfiguration["basepath"] != nil {
a.Souin.BasePath, _ = souinConfiguration["basepath"].String() a.Souin.BasePath, _ = souinConfiguration["basepath"].String()
} }
if souinConfiguration["security"] != nil {
a.Souin.Security, _ = souinConfiguration["security"].Bool()
}
}
if securityConfiguration != nil {
a.Security = configurationtypes.SecurityAPI{}
a.Security.Enable = true
if securityConfiguration["basepath"] != nil {
a.Security.BasePath, _ = securityConfiguration["basepath"].String()
}
if securityConfiguration["users"] != nil {
users, _ := securityConfiguration["users"].Slice()
a.Security.Users = make([]configurationtypes.User, 0)
for _, user := range users {
currentUser, _ := user.Map()
username, _ := currentUser["username"].String()
password, _ := currentUser["password"].String()
a.Security.Users = append(a.Security.Users, configurationtypes.User{
Username: username,
Password: password,
})
}
} }
return a
} }
configuration.API = a
case "cache_keys": func parseCacheKeys(ccConfiguration map[string]config.Value) map[configurationtypes.RegValue]configurationtypes.Key {
cacheKeys := make(map[configurationtypes.RegValue]configurationtypes.Key) cacheKeys := make(map[configurationtypes.RegValue]configurationtypes.Key)
cacheKeysConfiguration, _ := v.Map() for cacheKeysConfigurationK, cacheKeysConfigurationV := range ccConfiguration {
for cacheKeysConfigurationK, cacheKeysConfigurationV := range cacheKeysConfiguration {
ck := configurationtypes.Key{} ck := configurationtypes.Key{}
cacheKeysConfigurationVMap, _ := cacheKeysConfigurationV.Map() cacheKeysConfigurationVMap, _ := cacheKeysConfigurationV.Map()
for cacheKeysConfigurationVMapK := range cacheKeysConfigurationVMap { for cacheKeysConfigurationVMapK := range cacheKeysConfigurationVMap {
@ -123,8 +97,11 @@ func ParseConfiguration(c config.Config) plugins.BaseConfiguration {
rg := regexp.MustCompile(cacheKeysConfigurationK) rg := regexp.MustCompile(cacheKeysConfigurationK)
cacheKeys[configurationtypes.RegValue{Regexp: rg}] = ck cacheKeys[configurationtypes.RegValue{Regexp: rg}] = ck
} }
configuration.CacheKeys = cacheKeys
case "default_cache": return cacheKeys
}
func parseDefaultCache(dcConfiguration map[string]config.Value) *configurationtypes.DefaultCache {
dc := configurationtypes.DefaultCache{ dc := configurationtypes.DefaultCache{
Distributed: false, Distributed: false,
Headers: []string{}, Headers: []string{},
@ -137,29 +114,35 @@ func ParseConfiguration(c config.Config) plugins.BaseConfiguration {
TTL: configurationtypes.Duration{}, TTL: configurationtypes.Duration{},
DefaultCacheControl: "", DefaultCacheControl: "",
} }
defaultCache, _ := v.Map() for defaultCacheK, defaultCacheV := range dcConfiguration {
for defaultCacheK, defaultCacheV := range defaultCache {
switch defaultCacheK { switch defaultCacheK {
case "allowed_http_verbs":
headers, _ := defaultCacheV.Slice()
dc.AllowedHTTPVerbs = make([]string, 0)
for _, header := range headers {
h, _ := header.String()
dc.AllowedHTTPVerbs = append(dc.AllowedHTTPVerbs, h)
}
case "badger": case "badger":
provider := configurationtypes.CacheProvider{} provider := configurationtypes.CacheProvider{}
badgerConfiguration, _ := v.Map() badgerConfiguration, _ := defaultCacheV.Map()
for badgerConfigurationK, badgerConfigurationV := range badgerConfiguration { for badgerConfigurationK, badgerConfigurationV := range badgerConfiguration {
switch badgerConfigurationK { switch badgerConfigurationK {
case "url": case url:
provider.URL, _ = badgerConfigurationV.String() provider.URL, _ = badgerConfigurationV.String()
case "path": case path:
provider.Path, _ = badgerConfigurationV.String() provider.Path, _ = badgerConfigurationV.String()
case "configuration": case configurationPK:
configMap, e := badgerConfigurationV.Map() configMap, e := badgerConfigurationV.Map()
if e == nil { if e == nil {
provider.Configuration = parseRecursively(configMap) provider.Configuration = parseRecursively(configMap)
} }
} }
} }
configuration.DefaultCache.Badger = provider dc.Badger = provider
case "cdn": case "cdn":
cdn := configurationtypes.CDN{} cdn := configurationtypes.CDN{}
cdnConfiguration, _ := v.Map() cdnConfiguration, _ := defaultCacheV.Map()
for cdnConfigurationK, cdnConfigurationV := range cdnConfiguration { for cdnConfigurationK, cdnConfigurationV := range cdnConfiguration {
switch cdnConfigurationK { switch cdnConfigurationK {
case "api_key": case "api_key":
@ -176,24 +159,24 @@ func ParseConfiguration(c config.Config) plugins.BaseConfiguration {
cdn.Strategy, _ = cdnConfigurationV.String() cdn.Strategy, _ = cdnConfigurationV.String()
} }
} }
configuration.DefaultCache.CDN = cdn dc.CDN = cdn
case "etcd": case "etcd":
provider := configurationtypes.CacheProvider{} provider := configurationtypes.CacheProvider{}
etcdConfiguration, _ := v.Map() etcdConfiguration, _ := defaultCacheV.Map()
for etcdConfigurationK, etcdConfigurationV := range etcdConfiguration { for etcdConfigurationK, etcdConfigurationV := range etcdConfiguration {
switch etcdConfigurationK { switch etcdConfigurationK {
case "url": case url:
provider.URL, _ = etcdConfigurationV.String() provider.URL, _ = etcdConfigurationV.String()
case "path": case path:
provider.Path, _ = etcdConfigurationV.String() provider.Path, _ = etcdConfigurationV.String()
case "configuration": case configurationPK:
configMap, e := etcdConfigurationV.Map() configMap, e := etcdConfigurationV.Map()
if e == nil { if e == nil {
provider.Configuration = parseRecursively(configMap) provider.Configuration = parseRecursively(configMap)
} }
} }
} }
configuration.DefaultCache.Etcd = provider dc.Etcd = provider
case "headers": case "headers":
headers, _ := defaultCacheV.Slice() headers, _ := defaultCacheV.Slice()
dc.Headers = make([]string, 0) dc.Headers = make([]string, 0)
@ -203,39 +186,39 @@ func ParseConfiguration(c config.Config) plugins.BaseConfiguration {
} }
case "nuts": case "nuts":
provider := configurationtypes.CacheProvider{} provider := configurationtypes.CacheProvider{}
nutsConfiguration, _ := v.Map() nutsConfiguration, _ := defaultCacheV.Map()
for nutsConfigurationK, nutsConfigurationV := range nutsConfiguration { for nutsConfigurationK, nutsConfigurationV := range nutsConfiguration {
switch nutsConfigurationK { switch nutsConfigurationK {
case "url": case url:
provider.URL, _ = nutsConfigurationV.String() provider.URL, _ = nutsConfigurationV.String()
case "path": case path:
provider.Path, _ = nutsConfigurationV.String() provider.Path, _ = nutsConfigurationV.String()
case "configuration": case configurationPK:
configMap, e := nutsConfigurationV.Map() configMap, e := nutsConfigurationV.Map()
if e == nil { if e == nil {
provider.Configuration = parseRecursively(configMap) provider.Configuration = parseRecursively(configMap)
} }
} }
} }
configuration.DefaultCache.Nuts = provider dc.Nuts = provider
case "olric": case "olric":
provider := configurationtypes.CacheProvider{} provider := configurationtypes.CacheProvider{}
olricConfiguration, _ := v.Map() olricConfiguration, _ := defaultCacheV.Map()
for olricConfigurationK, olricConfigurationV := range olricConfiguration { for olricConfigurationK, olricConfigurationV := range olricConfiguration {
switch olricConfigurationK { switch olricConfigurationK {
case "url": case url:
provider.URL, _ = olricConfigurationV.String() provider.URL, _ = olricConfigurationV.String()
case "path": case path:
provider.Path, _ = olricConfigurationV.String() provider.Path, _ = olricConfigurationV.String()
case "configuration": case configurationPK:
configMap, e := olricConfigurationV.Map() configMap, e := olricConfigurationV.Map()
if e == nil { if e == nil {
provider.Configuration = parseRecursively(configMap) provider.Configuration = parseRecursively(configMap)
} }
} }
} }
configuration.DefaultCache.Distributed = true dc.Distributed = true
configuration.DefaultCache.Olric = provider dc.Olric = provider
case "regex": case "regex":
regex, _ := defaultCacheV.Map() regex, _ := defaultCacheV.Map()
exclude, _ := regex["exclude"].String() exclude, _ := regex["exclude"].String()
@ -258,12 +241,12 @@ func ParseConfiguration(c config.Config) plugins.BaseConfiguration {
dc.DefaultCacheControl, _ = defaultCacheV.String() dc.DefaultCacheControl, _ = defaultCacheV.String()
} }
} }
configuration.DefaultCache = &dc
case "log_level": return &dc
configuration.LogLevel, _ = v.String() }
case "urls":
func parseURLs(urls map[string]config.Value) map[string]configurationtypes.URL {
u := make(map[string]configurationtypes.URL) u := make(map[string]configurationtypes.URL)
urls, _ := v.Map()
for urlK, urlV := range urls { for urlK, urlV := range urls {
currentURL := configurationtypes.URL{ currentURL := configurationtypes.URL{
@ -271,12 +254,14 @@ func ParseConfiguration(c config.Config) plugins.BaseConfiguration {
Headers: nil, Headers: nil,
} }
currentValue, _ := urlV.Map() currentValue, _ := urlV.Map()
if currentValue["headers"] != nil {
currentURL.Headers = make([]string, 0) currentURL.Headers = make([]string, 0)
headers, _ := currentValue["headers"].Slice() headers, _ := currentValue["headers"].Slice()
for _, header := range headers { for _, header := range headers {
h, _ := header.String() h, _ := header.String()
currentURL.Headers = append(currentURL.Headers, h) currentURL.Headers = append(currentURL.Headers, h)
} }
}
sttl, err := currentValue["ttl"].String() sttl, err := currentValue["ttl"].String()
ttl, _ := time.ParseDuration(sttl) ttl, _ := time.ParseDuration(sttl)
if err == nil { if err == nil {
@ -287,17 +272,66 @@ func ParseConfiguration(c config.Config) plugins.BaseConfiguration {
} }
u[urlK] = currentURL u[urlK] = currentURL
} }
configuration.URLs = u
return u
}
func parseSurrogateKeys(surrogates map[string]config.Value) map[string]configurationtypes.SurrogateKeys {
u := make(map[string]configurationtypes.SurrogateKeys)
for surrogateK, surrogateV := range surrogates {
surrogate := configurationtypes.SurrogateKeys{}
currentValue, _ := surrogateV.Map()
for key, value := range currentValue {
switch key {
case "headers":
surrogate.Headers = map[string]string{}
headers, e := value.Map()
if e == nil {
for hKey, hValue := range headers {
v, _ := hValue.String()
surrogate.Headers[hKey] = v
}
}
case "url":
surl, _ := currentValue["url"].String()
surrogate.URL = surl
}
}
u[surrogateK] = surrogate
}
return u
}
// ParseConfiguration parse the Kratos configuration into a valid HTTP
// cache configuration object.
func ParseConfiguration(c config.Config) plugins.BaseConfiguration {
var configuration plugins.BaseConfiguration
values, _ := c.Value(configurationKey).Map()
for key, v := range values {
switch key {
case "api":
apiConfiguration, _ := v.Map()
configuration.API = parseAPI(apiConfiguration)
case "cache_keys":
cacheKeysConfiguration, _ := v.Map()
configuration.CacheKeys = parseCacheKeys(cacheKeysConfiguration)
case "default_cache":
defaultCache, _ := v.Map()
configuration.DefaultCache = parseDefaultCache(defaultCache)
case "log_level":
configuration.LogLevel, _ = v.String()
case "urls":
urls, _ := v.Map()
configuration.URLs = parseURLs(urls)
case "ykeys": case "ykeys":
ykeys := make(map[string]configurationtypes.SurrogateKeys) ykeys, _ := v.Map()
d, _ := json.Marshal(v) configuration.Ykeys = parseSurrogateKeys(ykeys)
_ = json.Unmarshal(d, &ykeys)
configuration.Ykeys = ykeys
case "surrogate_keys": case "surrogate_keys":
ykeys := make(map[string]configurationtypes.SurrogateKeys) surrogates, _ := v.Map()
d, _ := json.Marshal(v) configuration.SurrogateKeys = parseSurrogateKeys(surrogates)
_ = json.Unmarshal(d, &ykeys)
configuration.Ykeys = ykeys
} }
} }

@ -0,0 +1,113 @@
package httpcache
import (
"io/ioutil"
"testing"
"time"
"github.com/go-kratos/kratos/v2/config"
"github.com/go-kratos/kratos/v2/config/file"
"gopkg.in/yaml.v2"
)
var dummyConfig = []byte(`
httpcache:
api:
basepath: /souin-api
prometheus:
basepath: /anything-for-prometheus-metrics
souin:
basepath: /anything-for-souin
cache_keys:
.+:
disable_body: true
disable_host: true
disable_method: true
default_cache:
allowed_http_verbs:
- GET
- POST
- HEAD
badger:
url: /badger/url
path: /badger/path
configuration:
SyncEnable: false
cdn:
api_key: XXXX
dynamic: true
hostname: XXXX
network: XXXX
provider: fastly
strategy: soft
etcd:
url: /etcd/url
path: /etcd/path
configuration:
SyncEnable: false
headers:
- Authorization
nuts:
url: /etcd/url
path: /etcd/path
configuration:
SyncEnable: false
ValueFloat: 1.123
ValueDuration: 1s
ValueInt: 2
ValueObject:
ValueFloat: 1.123
olric:
url: 'olric:3320'
path: /olric/path
configuration:
SyncEnable: false
regex:
exclude: 'ARegexHere'
ttl: 10s
stale: 10s
default_cache_control: no-store
log_level: debug
urls:
'https:\/\/domain.com\/first-.+':
ttl: 1000s
'https:\/\/domain.com\/second-route':
ttl: 10s
headers:
- Authorization
'https?:\/\/mysubdomain\.domain\.com':
ttl: 50s
default_cache_control: no-cache
headers:
- Authorization
- 'Content-Type'
ykeys:
The_First_Test:
headers:
Content-Type: '.+'
The_Second_Test:
url: 'the/second/.+'
surrogate_keys:
The_First_Test:
headers:
Content-Type: '.+'
The_Second_Test:
url: 'the/second/.+'
`)
func Test_ParseConfiguration(t *testing.T) {
filename := "/tmp/httpcache-" + time.Now().String() + ".yml"
ioutil.WriteFile(filename, dummyConfig, 0777)
c := config.New(
config.WithSource(file.NewSource(filename)),
config.WithDecoder(func(kv *config.KeyValue, v map[string]interface{}) error {
return yaml.Unmarshal(kv.Value, v)
}),
)
if err := c.Load(); err != nil {
panic(err)
}
_ = ParseConfiguration(c)
}

@ -91,6 +91,7 @@ func (s *httpcacheKratosPlugin) handle(next http.Handler) http.Handler {
combo.next.ServeHTTP(customWriter, r) combo.next.ServeHTTP(customWriter, r)
combo.req.Response = customWriter.Response combo.req.Response = customWriter.Response
// nolint
if combo.req.Response, e = s.Retriever.GetTransport().(*rfc.VaryTransport).UpdateCacheEventually(combo.req); e != nil { if combo.req.Response, e = s.Retriever.GetTransport().(*rfc.VaryTransport).UpdateCacheEventually(combo.req); e != nil {
return e return e
} }

@ -13,8 +13,7 @@ import (
"github.com/darkweak/souin/plugins" "github.com/darkweak/souin/plugins"
) )
var ( var devDefaultConfiguration = plugins.BaseConfiguration{
devDefaultConfiguration = plugins.BaseConfiguration{
API: configurationtypes.API{ API: configurationtypes.API{
BasePath: "/httpcache_api", BasePath: "/httpcache_api",
Prometheus: configurationtypes.APIEndpoint{ Prometheus: configurationtypes.APIEndpoint{
@ -35,13 +34,12 @@ var (
}, },
LogLevel: "debug", LogLevel: "debug",
} }
)
type next struct{} type next struct{}
func (n *next) ServeHTTP(rw http.ResponseWriter, r *http.Request) { func (n *next) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
rw.Write([]byte("Hello Kratos!")) _, _ = rw.Write([]byte("Hello Kratos!"))
} }
var nextFilter = &next{} var nextFilter = &next{}
@ -65,14 +63,18 @@ func Test_HttpcacheKratosPlugin_NewHTTPCacheFilter(t *testing.T) {
handler := NewHTTPCacheFilter(devDefaultConfiguration)(nextFilter) handler := NewHTTPCacheFilter(devDefaultConfiguration)(nextFilter)
req, res, res2 := prepare("/handled") req, res, res2 := prepare("/handled")
handler.ServeHTTP(res, req) handler.ServeHTTP(res, req)
if res.Result().Header.Get("Cache-Status") != "Souin; fwd=uri-miss; stored" { rs := res.Result()
rs.Body.Close()
if rs.Header.Get("Cache-Status") != "Souin; fwd=uri-miss; stored" {
t.Error("The response must contain a Cache-Status header with the stored directive.") t.Error("The response must contain a Cache-Status header with the stored directive.")
} }
handler.ServeHTTP(res2, req) handler.ServeHTTP(res2, req)
if res2.Result().Header.Get("Cache-Status") != "Souin; hit; ttl=0" { rs = res2.Result()
rs.Body.Close()
if rs.Header.Get("Cache-Status") != "Souin; hit; ttl=0" {
t.Error("The response must contain a Cache-Status header with the hit and ttl directives.") t.Error("The response must contain a Cache-Status header with the hit and ttl directives.")
} }
if res2.Result().Header.Get("Age") != "1" { if rs.Header.Get("Age") != "1" {
t.Error("The response must contain a Age header with the value 1.") t.Error("The response must contain a Age header with the value 1.")
} }
} }
@ -82,14 +84,18 @@ func Test_HttpcacheKratosPlugin_NewHTTPCacheFilter_Excluded(t *testing.T) {
handler := NewHTTPCacheFilter(devDefaultConfiguration)(nextFilter) handler := NewHTTPCacheFilter(devDefaultConfiguration)(nextFilter)
req, res, res2 := prepare("/excluded") req, res, res2 := prepare("/excluded")
handler.ServeHTTP(res, req) handler.ServeHTTP(res, req)
if res.Result().Header.Get("Cache-Status") != "Souin; fwd=uri-miss" { rs := res.Result()
rs.Body.Close()
if rs.Header.Get("Cache-Status") != "Souin; fwd=uri-miss" {
t.Error("The response must contain a Cache-Status header without the stored directive and with the uri-miss only.") t.Error("The response must contain a Cache-Status header without the stored directive and with the uri-miss only.")
} }
handler.ServeHTTP(res2, req) handler.ServeHTTP(res2, req)
if res2.Result().Header.Get("Cache-Status") != "Souin; fwd=uri-miss" { rs = res2.Result()
rs.Body.Close()
if rs.Header.Get("Cache-Status") != "Souin; fwd=uri-miss" {
t.Error("The response must contain a Cache-Status header without the stored directive and with the uri-miss only.") t.Error("The response must contain a Cache-Status header without the stored directive and with the uri-miss only.")
} }
if res2.Result().Header.Get("Age") != "" { if rs.Header.Get("Age") != "" {
t.Error("The response must not contain a Age header.") t.Error("The response must not contain a Age header.")
} }
} }
@ -102,15 +108,19 @@ func Test_HttpcacheKratosPlugin_NewHTTPCacheFilter_Mutation(t *testing.T) {
req, res, res2 := prepare("/handled") req, res, res2 := prepare("/handled")
req.Body = ioutil.NopCloser(bytes.NewBuffer([]byte(`{"query":"mutation":{something mutated}}`))) req.Body = ioutil.NopCloser(bytes.NewBuffer([]byte(`{"query":"mutation":{something mutated}}`)))
handler.ServeHTTP(res, req) handler.ServeHTTP(res, req)
if res.Result().Header.Get("Cache-Status") != "Souin; fwd=uri-miss" { rs := res.Result()
rs.Body.Close()
if rs.Header.Get("Cache-Status") != "Souin; fwd=uri-miss" {
t.Error("The response must contain a Cache-Status header without the stored directive and with the uri-miss only.") t.Error("The response must contain a Cache-Status header without the stored directive and with the uri-miss only.")
} }
req.Body = ioutil.NopCloser(bytes.NewBuffer([]byte(`{"query":"mutation":{something mutated}}`))) req.Body = ioutil.NopCloser(bytes.NewBuffer([]byte(`{"query":"mutation":{something mutated}}`)))
handler.ServeHTTP(res2, req) handler.ServeHTTP(res2, req)
if res2.Result().Header.Get("Cache-Status") != "Souin; fwd=uri-miss" { rs = res2.Result()
rs.Body.Close()
if rs.Header.Get("Cache-Status") != "Souin; fwd=uri-miss" {
t.Error("The response must contain a Cache-Status header without the stored directive and with the uri-miss only.") t.Error("The response must contain a Cache-Status header without the stored directive and with the uri-miss only.")
} }
if res2.Result().Header.Get("Age") != "" { if rs.Header.Get("Age") != "" {
t.Error("The response must not contain a Age header.") t.Error("The response must not contain a Age header.")
} }
} }
@ -120,26 +130,32 @@ func Test_HttpcacheKratosPlugin_NewHTTPCacheFilter_API(t *testing.T) {
handler := NewHTTPCacheFilter(devDefaultConfiguration)(nextFilter) handler := NewHTTPCacheFilter(devDefaultConfiguration)(nextFilter)
req, res, res2 := prepare("/httpcache_api/httpcache") req, res, res2 := prepare("/httpcache_api/httpcache")
handler.ServeHTTP(res, req) handler.ServeHTTP(res, req)
if res.Result().Header.Get("Content-Type") != "application/json" { rs := res.Result()
defer rs.Body.Close()
if rs.Header.Get("Content-Type") != "application/json" {
t.Error("The response must contain be in JSON.") t.Error("The response must contain be in JSON.")
} }
b, _ := ioutil.ReadAll(res.Result().Body) b, _ := ioutil.ReadAll(rs.Body)
res.Result().Body.Close() res.Result().Body.Close()
if string(b) != "[]" { if string(b) != "[]" {
t.Error("The response body must be an empty array because no request has been stored") t.Error("The response body must be an empty array because no request has been stored")
} }
req2 := httptest.NewRequest(http.MethodGet, "/handled", nil) req2 := httptest.NewRequest(http.MethodGet, "/handled", nil)
handler.ServeHTTP(res2, req2) handler.ServeHTTP(res2, req2)
if res2.Result().Header.Get("Cache-Status") != "Souin; fwd=uri-miss; stored" { rs = res2.Result()
rs.Body.Close()
if rs.Header.Get("Cache-Status") != "Souin; fwd=uri-miss; stored" {
t.Error("The response must contain a Cache-Status header with the stored directive.") t.Error("The response must contain a Cache-Status header with the stored directive.")
} }
res3 := httptest.NewRecorder() res3 := httptest.NewRecorder()
handler.ServeHTTP(res3, req) handler.ServeHTTP(res3, req)
if res3.Result().Header.Get("Content-Type") != "application/json" { rs = res3.Result()
rs.Body.Close()
if rs.Header.Get("Content-Type") != "application/json" {
t.Error("The response must contain be in JSON.") t.Error("The response must contain be in JSON.")
} }
b, _ = ioutil.ReadAll(res3.Result().Body) b, _ = ioutil.ReadAll(rs.Body)
res3.Result().Body.Close() rs.Body.Close()
var payload []string var payload []string
_ = json.Unmarshal(b, &payload) _ = json.Unmarshal(b, &payload)
if len(payload) != 2 { if len(payload) != 2 {

Loading…
Cancel
Save