diff --git a/middleware/metadata/metadata.go b/middleware/metadata/metadata.go index 5c29189f9..da9e43a96 100644 --- a/middleware/metadata/metadata.go +++ b/middleware/metadata/metadata.go @@ -51,16 +51,19 @@ func Server(opts ...Option) middleware.Middleware { } return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (reply interface{}, err error) { - if tr, ok := transport.FromServerContext(ctx); ok { - md := options.md.Clone() - header := tr.RequestHeader() - for _, k := range header.Keys() { - if options.hasPrefix(k) { - md.Set(k, header.Get(k)) - } + tr, ok := transport.FromServerContext(ctx) + if !ok { + return handler(ctx, req) + } + + md := options.md.Clone() + header := tr.RequestHeader() + for _, k := range header.Keys() { + if options.hasPrefix(k) { + md.Set(k, header.Get(k)) } - ctx = metadata.NewServerContext(ctx, md) } + ctx = metadata.NewServerContext(ctx, md) return handler(ctx, req) } } @@ -76,25 +79,28 @@ func Client(opts ...Option) middleware.Middleware { } return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (reply interface{}, err error) { - if tr, ok := transport.FromClientContext(ctx); ok { - header := tr.RequestHeader() - // x-md-local- - for k, v := range options.md { + tr, ok := transport.FromClientContext(ctx) + if !ok { + return handler(ctx, req) + } + + header := tr.RequestHeader() + // x-md-local- + for k, v := range options.md { + header.Set(k, v) + } + if md, ok := metadata.FromClientContext(ctx); ok { + for k, v := range md { header.Set(k, v) } - if md, ok := metadata.FromClientContext(ctx); ok { - for k, v := range md { + } + // x-md-global- + if md, ok := metadata.FromServerContext(ctx); ok { + for k, v := range md { + if options.hasPrefix(k) { header.Set(k, v) } } - // x-md-global- - if md, ok := metadata.FromServerContext(ctx); ok { - for k, v := range md { - if options.hasPrefix(k) { - header.Set(k, v) - } - } - } } return handler(ctx, req) } diff --git a/middleware/metadata/metadata_test.go b/middleware/metadata/metadata_test.go index e46a94f50..b083f9c10 100644 --- a/middleware/metadata/metadata_test.go +++ b/middleware/metadata/metadata_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "reflect" "testing" "github.com/go-kratos/kratos/v2/metadata" @@ -33,15 +34,18 @@ func (tr *testTransport) Operation() string { return "" } func (tr *testTransport) RequestHeader() transport.Header { return tr.header } func (tr *testTransport) ReplyHeader() transport.Header { return tr.header } +var ( + globalKey = "x-md-global-key" + globalValue = "global-value" + localKey = "x-md-local-key" + localValue = "local-value" + customKey = "x-md-local-custom" + customValue = "custom-value" + constKey = "x-md-local-const" + constValue = "x-md-local-const" +) + func TestSever(t *testing.T) { - var ( - globalKey = "x-md-global-key" - globalValue = "global-value" - localKey = "x-md-local-key" - localValue = "local-value" - constKey = "x-md-local-const" - constValue = "x-md-local-const" - ) hs := func(ctx context.Context, in interface{}) (interface{}, error) { md, ok := metadata.FromServerContext(ctx) if !ok { @@ -75,16 +79,6 @@ func TestSever(t *testing.T) { } func TestClient(t *testing.T) { - var ( - globalKey = "x-md-global-key" - globalValue = "global-value" - localKey = "x-md-local-key" - localValue = "local-value" - customKey = "x-md-local-custom" - customValue = "custom-value" - constKey = "x-md-local-const" - constValue = "x-md-local-const" - ) hs := func(ctx context.Context, in interface{}) (interface{}, error) { tr, ok := transport.FromClientContext(ctx) if !ok { @@ -127,13 +121,52 @@ func TestClient(t *testing.T) { } } -func Test_WithPropagatedPrefix(t *testing.T) { - o := &options{ +func TestWithConstants(t *testing.T) { + md := metadata.Metadata{ + constKey: constValue, + } + options := &options{ + md: metadata.Metadata{ + "override": "override", + }, + } + + WithConstants(md)(options) + if !reflect.DeepEqual(md, options.md) { + t.Errorf("want: %v, got: %v", md, options.md) + } +} + +func TestOptions_WithPropagatedPrefix(t *testing.T) { + options := &options{ prefix: []string{"override"}, } - WithPropagatedPrefix("something", "another")(o) + prefixes := []string{"something", "another"} - if len(o.prefix) != 2 { + WithPropagatedPrefix(prefixes...)(options) + if !reflect.DeepEqual(prefixes, options.prefix) { t.Error("The prefix must be overridden.") } } + +func TestOptions_hasPrefix(t *testing.T) { + tests := []struct { + name string + options *options + key string + exists bool + }{ + {"exists key upper", &options{prefix: []string{"prefix"}}, "PREFIX_true", true}, + {"exists key lower", &options{prefix: []string{"prefix"}}, "prefix_true", true}, + {"not exists key upper", &options{prefix: []string{"prefix"}}, "false_PREFIX", false}, + {"not exists key lower", &options{prefix: []string{"prefix"}}, "false_prefix", false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + exists := test.options.hasPrefix(test.key) + if test.exists != exists { + t.Errorf("key: '%sr', not exists prefixs: %v", test.key, test.options.prefix) + } + }) + } +}