parent
60b1e593f1
commit
51a3a32502
@ -0,0 +1,127 @@ |
||||
package metadata |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"net/http" |
||||
"testing" |
||||
|
||||
"github.com/go-kratos/kratos/v2/metadata" |
||||
"github.com/go-kratos/kratos/v2/transport" |
||||
) |
||||
|
||||
type headerCarrier http.Header |
||||
|
||||
func (hc headerCarrier) Get(key string) string { return http.Header(hc).Get(key) } |
||||
|
||||
func (hc headerCarrier) Set(key string, value string) { http.Header(hc).Set(key, value) } |
||||
|
||||
// Keys lists the keys stored in this carrier.
|
||||
func (hc headerCarrier) Keys() []string { |
||||
keys := make([]string, 0, len(hc)) |
||||
for k := range http.Header(hc) { |
||||
keys = append(keys, k) |
||||
} |
||||
return keys |
||||
} |
||||
|
||||
type testTransport struct{ header headerCarrier } |
||||
|
||||
func (tr *testTransport) Kind() transport.Kind { return transport.KindHTTP } |
||||
func (tr *testTransport) Endpoint() string { return "" } |
||||
func (tr *testTransport) Operation() string { return "" } |
||||
func (tr *testTransport) Header() transport.Header { return tr.header } |
||||
|
||||
func TestSever(t *testing.T) { |
||||
var ( |
||||
globalKey = "x-md-global-key" |
||||
globalValue = "global-value" |
||||
localKey = "x-md-local-key" |
||||
localValue = "local-value" |
||||
constKey = "x-md-local-const" |
||||
constValue = "x-md-local-const" |
||||
) |
||||
hs := func(ctx context.Context, in interface{}) (interface{}, error) { |
||||
md, ok := metadata.FromServerContext(ctx) |
||||
if !ok { |
||||
return nil, errors.New("no md") |
||||
} |
||||
if md.Get(constKey) != constValue { |
||||
return nil, errors.New("const not equal") |
||||
} |
||||
if md.Get(globalKey) != globalValue { |
||||
return nil, errors.New("global not equal") |
||||
} |
||||
if md.Get(localKey) != localValue { |
||||
return nil, errors.New("local not equal") |
||||
} |
||||
return in, nil |
||||
} |
||||
hc := headerCarrier{} |
||||
hc.Set(globalKey, globalValue) |
||||
hc.Set(localKey, localValue) |
||||
ctx := transport.NewServerContext(context.Background(), &testTransport{hc}) |
||||
// const md
|
||||
constMD := metadata.New() |
||||
constMD.Set(constKey, constValue) |
||||
reply, err := Server(WithConstants(constMD))(hs)(ctx, "foo") |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if reply.(string) != "foo" { |
||||
t.Fatalf("want foo got %v", reply) |
||||
} |
||||
} |
||||
|
||||
func TestClient(t *testing.T) { |
||||
var ( |
||||
globalKey = "x-md-global-key" |
||||
globalValue = "global-value" |
||||
localKey = "x-md-local-key" |
||||
localValue = "local-value" |
||||
customKey = "x-md-local-custom" |
||||
customValue = "custom-value" |
||||
constKey = "x-md-local-const" |
||||
constValue = "x-md-local-const" |
||||
) |
||||
hs := func(ctx context.Context, in interface{}) (interface{}, error) { |
||||
tr, ok := transport.FromClientContext(ctx) |
||||
if !ok { |
||||
return nil, errors.New("no md") |
||||
} |
||||
if tr.Header().Get(constKey) != constValue { |
||||
return nil, errors.New("const not equal") |
||||
} |
||||
if tr.Header().Get(customKey) != customValue { |
||||
return nil, errors.New("custom not equal") |
||||
} |
||||
if tr.Header().Get(globalKey) != globalValue { |
||||
return nil, errors.New("global not equal") |
||||
} |
||||
if tr.Header().Get(localKey) != "" { |
||||
return nil, errors.New("local must empty") |
||||
} |
||||
return in, nil |
||||
} |
||||
// server md
|
||||
serverMD := metadata.New() |
||||
serverMD.Set(globalKey, globalValue) |
||||
serverMD.Set(localKey, localValue) |
||||
ctx := metadata.NewServerContext(context.Background(), serverMD) |
||||
// client md
|
||||
clientMD := metadata.New() |
||||
clientMD.Set(customKey, customValue) |
||||
ctx = metadata.NewClientContext(ctx, clientMD) |
||||
// transport carrier
|
||||
ctx = transport.NewClientContext(ctx, &testTransport{headerCarrier{}}) |
||||
// const md
|
||||
constMD := metadata.New() |
||||
constMD.Set(constKey, constValue) |
||||
reply, err := Client(WithConstants(constMD))(hs)(ctx, "bar") |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if reply.(string) != "bar" { |
||||
t.Fatalf("want foo got %v", reply) |
||||
} |
||||
} |
Loading…
Reference in new issue