From 51a3a3250259eceedb1968ea0ff024dcb8582029 Mon Sep 17 00:00:00 2001 From: Tony Chen Date: Tue, 15 Jun 2021 21:41:55 +0800 Subject: [PATCH] middleware/metadata: add md test (#1064) * add metadata test --- middleware/metadata/metadata.go | 34 +++---- middleware/metadata/metadata_test.go | 127 +++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 15 deletions(-) create mode 100644 middleware/metadata/metadata_test.go diff --git a/middleware/metadata/metadata.go b/middleware/metadata/metadata.go index aaa566743..d251ec568 100644 --- a/middleware/metadata/metadata.go +++ b/middleware/metadata/metadata.go @@ -17,6 +17,16 @@ type options struct { md metadata.Metadata } +func (o *options) hasPrefix(key string) bool { + k := strings.ToLower(key) + for _, prefix := range o.prefix { + if strings.HasPrefix(k, prefix) { + return true + } + } + return false +} + // WithConstants with constant metadata key value. func WithConstants(md metadata.Metadata) Option { return func(o *options) { @@ -33,22 +43,19 @@ func WithPropagatedPrefix(prefix ...string) Option { // Server is middleware server-side metadata. func Server(opts ...Option) middleware.Middleware { - options := options{ + options := &options{ prefix: []string{"x-md-"}, // x-md-global-, x-md-local } for _, o := range opts { - o(&options) + o(options) } return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (reply interface{}, err error) { if tr, ok := transport.FromServerContext(ctx); ok { - md := metadata.New() + md := options.md.Clone() for _, k := range tr.Header().Keys() { - for _, prefix := range options.prefix { - if strings.HasPrefix(strings.ToLower(k), prefix) { - md.Set(k, tr.Header().Get(k)) - break - } + if options.hasPrefix(k) { + md.Set(k, tr.Header().Get(k)) } } ctx = metadata.NewServerContext(ctx, md) @@ -60,11 +67,11 @@ func Server(opts ...Option) middleware.Middleware { // Client is middleware client-side metadata. func Client(opts ...Option) middleware.Middleware { - options := options{ + options := &options{ prefix: []string{"x-md-global-"}, } for _, o := range opts { - o(&options) + o(options) } return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (reply interface{}, err error) { @@ -81,11 +88,8 @@ func Client(opts ...Option) middleware.Middleware { // x-md-global- if md, ok := metadata.FromServerContext(ctx); ok { for k, v := range md { - for _, prefix := range options.prefix { - if strings.HasPrefix(k, prefix) { - tr.Header().Set(k, v) - break - } + if options.hasPrefix(k) { + tr.Header().Set(k, v) } } } diff --git a/middleware/metadata/metadata_test.go b/middleware/metadata/metadata_test.go new file mode 100644 index 000000000..917dea59c --- /dev/null +++ b/middleware/metadata/metadata_test.go @@ -0,0 +1,127 @@ +package metadata + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/go-kratos/kratos/v2/metadata" + "github.com/go-kratos/kratos/v2/transport" +) + +type headerCarrier http.Header + +func (hc headerCarrier) Get(key string) string { return http.Header(hc).Get(key) } + +func (hc headerCarrier) Set(key string, value string) { http.Header(hc).Set(key, value) } + +// Keys lists the keys stored in this carrier. +func (hc headerCarrier) Keys() []string { + keys := make([]string, 0, len(hc)) + for k := range http.Header(hc) { + keys = append(keys, k) + } + return keys +} + +type testTransport struct{ header headerCarrier } + +func (tr *testTransport) Kind() transport.Kind { return transport.KindHTTP } +func (tr *testTransport) Endpoint() string { return "" } +func (tr *testTransport) Operation() string { return "" } +func (tr *testTransport) Header() transport.Header { return tr.header } + +func TestSever(t *testing.T) { + var ( + globalKey = "x-md-global-key" + globalValue = "global-value" + localKey = "x-md-local-key" + localValue = "local-value" + constKey = "x-md-local-const" + constValue = "x-md-local-const" + ) + hs := func(ctx context.Context, in interface{}) (interface{}, error) { + md, ok := metadata.FromServerContext(ctx) + if !ok { + return nil, errors.New("no md") + } + if md.Get(constKey) != constValue { + return nil, errors.New("const not equal") + } + if md.Get(globalKey) != globalValue { + return nil, errors.New("global not equal") + } + if md.Get(localKey) != localValue { + return nil, errors.New("local not equal") + } + return in, nil + } + hc := headerCarrier{} + hc.Set(globalKey, globalValue) + hc.Set(localKey, localValue) + ctx := transport.NewServerContext(context.Background(), &testTransport{hc}) + // const md + constMD := metadata.New() + constMD.Set(constKey, constValue) + reply, err := Server(WithConstants(constMD))(hs)(ctx, "foo") + if err != nil { + t.Fatal(err) + } + if reply.(string) != "foo" { + t.Fatalf("want foo got %v", reply) + } +} + +func TestClient(t *testing.T) { + var ( + globalKey = "x-md-global-key" + globalValue = "global-value" + localKey = "x-md-local-key" + localValue = "local-value" + customKey = "x-md-local-custom" + customValue = "custom-value" + constKey = "x-md-local-const" + constValue = "x-md-local-const" + ) + hs := func(ctx context.Context, in interface{}) (interface{}, error) { + tr, ok := transport.FromClientContext(ctx) + if !ok { + return nil, errors.New("no md") + } + if tr.Header().Get(constKey) != constValue { + return nil, errors.New("const not equal") + } + if tr.Header().Get(customKey) != customValue { + return nil, errors.New("custom not equal") + } + if tr.Header().Get(globalKey) != globalValue { + return nil, errors.New("global not equal") + } + if tr.Header().Get(localKey) != "" { + return nil, errors.New("local must empty") + } + return in, nil + } + // server md + serverMD := metadata.New() + serverMD.Set(globalKey, globalValue) + serverMD.Set(localKey, localValue) + ctx := metadata.NewServerContext(context.Background(), serverMD) + // client md + clientMD := metadata.New() + clientMD.Set(customKey, customValue) + ctx = metadata.NewClientContext(ctx, clientMD) + // transport carrier + ctx = transport.NewClientContext(ctx, &testTransport{headerCarrier{}}) + // const md + constMD := metadata.New() + constMD.Set(constKey, constValue) + reply, err := Client(WithConstants(constMD))(hs)(ctx, "bar") + if err != nil { + t.Fatal(err) + } + if reply.(string) != "bar" { + t.Fatalf("want foo got %v", reply) + } +}