package selector import ( "context" "fmt" "strings" "testing" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" "github.com/stretchr/testify/assert" ) var _ transport.Transporter = &Transport{} type Transport struct { kind transport.Kind endpoint string operation string headers *mockHeader } func (tr *Transport) Kind() transport.Kind { return tr.kind } func (tr *Transport) Endpoint() string { return tr.endpoint } func (tr *Transport) Operation() string { return tr.operation } func (tr *Transport) RequestHeader() transport.Header { return tr.headers } func (tr *Transport) ReplyHeader() transport.Header { return nil } type mockHeader struct { m map[string]string } func (m *mockHeader) Get(key string) string { return m.m[key] } func (m *mockHeader) Set(key, value string) { m.m[key] = value } func (m *mockHeader) Keys() []string { keys := make([]string, 0, len(m.m)) for k := range m.m { keys = append(keys, k) } return keys } func TestMatch(t *testing.T) { tests := []struct { name string ctx context.Context }{ // TODO: Add test cases. { name: "/hello/world", ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/hello/world"}), }, { name: "/hi/world", ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/hi/world"}), }, { name: "/test/1234", ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/test/1234"}), }, { name: "/example/kratos", ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/example/kratos"}), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { next := func(ctx context.Context, req interface{}) (interface{}, error) { t.Log(req) return "reply", nil } next = Server(testMiddleware).Prefix("/hello/").Regex(`/test/[0-9]+`). Path("/example/kratos").Build()(next) _, _ = next(test.ctx, test.name) }) } } func TestMatchClient(t *testing.T) { tests := []struct { name string ctx context.Context }{ // TODO: Add test cases. { name: "/hello/world", ctx: transport.NewClientContext(context.Background(), &Transport{operation: "/hello/world"}), }, { name: "/hi/world", ctx: transport.NewClientContext(context.Background(), &Transport{operation: "/hi/world"}), }, { name: "/test/1234", ctx: transport.NewClientContext(context.Background(), &Transport{operation: "/test/1234"}), }, { name: "/example/kratos", ctx: transport.NewClientContext(context.Background(), &Transport{operation: "/example/kratos"}), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { next := func(ctx context.Context, req interface{}) (interface{}, error) { t.Log(req) return "reply", nil } next = Client(testMiddleware).Prefix("/hello/").Regex(`/test/[0-9]+`). Path("/example/kratos").Build()(next) _, _ = next(test.ctx, test.name) }) } } func TestFunc(t *testing.T) { tests := []struct { name string ctx context.Context }{ { name: "/hello.Update/world", ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/hello.Update/world"}), }, { name: "/hi.Create/world", ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/hi.Create/world"}), }, { name: "/test.Name/1234", ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/test.Name/1234"}), }, { name: "/go-kratos.dev/kratos", ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/go-kratos.dev/kratos"}), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { next := func(ctx context.Context, req interface{}) (interface{}, error) { t.Log(req) return "reply", nil } next = Server(testMiddleware).Match(func(ctx context.Context, operation string) bool { if strings.HasPrefix(operation, "/go-kratos.dev") || strings.HasSuffix(operation, "world") { return true } return false }).Build()(next) reply, err := next(test.ctx, test.name) assert.Equal(t, reply, "reply") assert.Nil(t, err) }) } } func TestHeaderFunc(t *testing.T) { tests := []struct { name string ctx context.Context }{ { name: "/hello.Update/world", ctx: transport.NewServerContext(context.Background(), &Transport{ operation: "/hello.Update/world", headers: &mockHeader{map[string]string{"X-Test": "test"}}, }), }, { name: "/hi.Create/world", ctx: transport.NewServerContext(context.Background(), &Transport{ operation: "/hi.Create/world", headers: &mockHeader{map[string]string{"X-Test": "test2", "go-kratos": "kratos"}}, }), }, { name: "/test.Name/1234", ctx: transport.NewServerContext(context.Background(), &Transport{ operation: "/test.Name/1234", headers: &mockHeader{map[string]string{"X-Test": "test3"}}, }), }, { name: "/go-kratos.dev/kratos", ctx: transport.NewServerContext(context.Background(), &Transport{ operation: "/go-kratos.dev/kratos", headers: &mockHeader{map[string]string{"X-Test": "test"}}, }), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { next := func(ctx context.Context, req interface{}) (interface{}, error) { t.Log(req) return "reply", nil } next = Server(testMiddleware).Match(func(ctx context.Context, operation string) bool { tr, ok := transport.FromServerContext(ctx) if !ok { return false } if tr.RequestHeader().Get("X-Test") == "test" { return true } if tr.RequestHeader().Get("go-kratos") == "kratos" { return true } return false }).Build()(next) reply, err := next(test.ctx, test.name) assert.Equal(t, reply, "reply") assert.Nil(t, err) }) } } func testMiddleware(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (reply interface{}, err error) { fmt.Println("before") reply, err = handler(ctx, req) fmt.Println("after") return } }