test(middleware/metadata): supplement test and modify code style (#2448)

pull/2412/head
Jesse 2 years ago committed by GitHub
parent e3feea6eeb
commit b5482d1794
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 50
      middleware/metadata/metadata.go
  2. 77
      middleware/metadata/metadata_test.go

@ -51,16 +51,19 @@ func Server(opts ...Option) middleware.Middleware {
} }
return func(handler middleware.Handler) middleware.Handler { return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) { return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
if tr, ok := transport.FromServerContext(ctx); ok { tr, ok := transport.FromServerContext(ctx)
md := options.md.Clone() if !ok {
header := tr.RequestHeader() return handler(ctx, req)
for _, k := range header.Keys() { }
if options.hasPrefix(k) {
md.Set(k, header.Get(k)) md := options.md.Clone()
} header := tr.RequestHeader()
for _, k := range header.Keys() {
if options.hasPrefix(k) {
md.Set(k, header.Get(k))
} }
ctx = metadata.NewServerContext(ctx, md)
} }
ctx = metadata.NewServerContext(ctx, md)
return handler(ctx, req) return handler(ctx, req)
} }
} }
@ -76,25 +79,28 @@ func Client(opts ...Option) middleware.Middleware {
} }
return func(handler middleware.Handler) middleware.Handler { return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) { return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
if tr, ok := transport.FromClientContext(ctx); ok { tr, ok := transport.FromClientContext(ctx)
header := tr.RequestHeader() if !ok {
// x-md-local- return handler(ctx, req)
for k, v := range options.md { }
header := tr.RequestHeader()
// x-md-local-
for k, v := range options.md {
header.Set(k, v)
}
if md, ok := metadata.FromClientContext(ctx); ok {
for k, v := range md {
header.Set(k, v) header.Set(k, v)
} }
if md, ok := metadata.FromClientContext(ctx); ok { }
for k, v := range md { // x-md-global-
if md, ok := metadata.FromServerContext(ctx); ok {
for k, v := range md {
if options.hasPrefix(k) {
header.Set(k, v) header.Set(k, v)
} }
} }
// x-md-global-
if md, ok := metadata.FromServerContext(ctx); ok {
for k, v := range md {
if options.hasPrefix(k) {
header.Set(k, v)
}
}
}
} }
return handler(ctx, req) return handler(ctx, req)
} }

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"reflect"
"testing" "testing"
"github.com/go-kratos/kratos/v2/metadata" "github.com/go-kratos/kratos/v2/metadata"
@ -33,15 +34,18 @@ func (tr *testTransport) Operation() string { return "" }
func (tr *testTransport) RequestHeader() transport.Header { return tr.header } func (tr *testTransport) RequestHeader() transport.Header { return tr.header }
func (tr *testTransport) ReplyHeader() transport.Header { return tr.header } func (tr *testTransport) ReplyHeader() transport.Header { return tr.header }
var (
globalKey = "x-md-global-key"
globalValue = "global-value"
localKey = "x-md-local-key"
localValue = "local-value"
customKey = "x-md-local-custom"
customValue = "custom-value"
constKey = "x-md-local-const"
constValue = "x-md-local-const"
)
func TestSever(t *testing.T) { func TestSever(t *testing.T) {
var (
globalKey = "x-md-global-key"
globalValue = "global-value"
localKey = "x-md-local-key"
localValue = "local-value"
constKey = "x-md-local-const"
constValue = "x-md-local-const"
)
hs := func(ctx context.Context, in interface{}) (interface{}, error) { hs := func(ctx context.Context, in interface{}) (interface{}, error) {
md, ok := metadata.FromServerContext(ctx) md, ok := metadata.FromServerContext(ctx)
if !ok { if !ok {
@ -75,16 +79,6 @@ func TestSever(t *testing.T) {
} }
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
var (
globalKey = "x-md-global-key"
globalValue = "global-value"
localKey = "x-md-local-key"
localValue = "local-value"
customKey = "x-md-local-custom"
customValue = "custom-value"
constKey = "x-md-local-const"
constValue = "x-md-local-const"
)
hs := func(ctx context.Context, in interface{}) (interface{}, error) { hs := func(ctx context.Context, in interface{}) (interface{}, error) {
tr, ok := transport.FromClientContext(ctx) tr, ok := transport.FromClientContext(ctx)
if !ok { if !ok {
@ -127,13 +121,52 @@ func TestClient(t *testing.T) {
} }
} }
func Test_WithPropagatedPrefix(t *testing.T) { func TestWithConstants(t *testing.T) {
o := &options{ md := metadata.Metadata{
constKey: constValue,
}
options := &options{
md: metadata.Metadata{
"override": "override",
},
}
WithConstants(md)(options)
if !reflect.DeepEqual(md, options.md) {
t.Errorf("want: %v, got: %v", md, options.md)
}
}
func TestOptions_WithPropagatedPrefix(t *testing.T) {
options := &options{
prefix: []string{"override"}, prefix: []string{"override"},
} }
WithPropagatedPrefix("something", "another")(o) prefixes := []string{"something", "another"}
if len(o.prefix) != 2 { WithPropagatedPrefix(prefixes...)(options)
if !reflect.DeepEqual(prefixes, options.prefix) {
t.Error("The prefix must be overridden.") t.Error("The prefix must be overridden.")
} }
} }
func TestOptions_hasPrefix(t *testing.T) {
tests := []struct {
name string
options *options
key string
exists bool
}{
{"exists key upper", &options{prefix: []string{"prefix"}}, "PREFIX_true", true},
{"exists key lower", &options{prefix: []string{"prefix"}}, "prefix_true", true},
{"not exists key upper", &options{prefix: []string{"prefix"}}, "false_PREFIX", false},
{"not exists key lower", &options{prefix: []string{"prefix"}}, "false_prefix", false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
exists := test.options.hasPrefix(test.key)
if test.exists != exists {
t.Errorf("key: '%sr', not exists prefixs: %v", test.key, test.options.prefix)
}
})
}
}

Loading…
Cancel
Save