diff --git a/middleware/selector/selector.go b/middleware/selector/selector.go index 34aa54f89..53eaf4b60 100644 --- a/middleware/selector/selector.go +++ b/middleware/selector/selector.go @@ -11,7 +11,7 @@ import ( type ( transporter func(ctx context.Context) (transport.Transporter, bool) - match func(operation string) bool + MatchFunc func(operation string) bool ) var ( @@ -32,6 +32,7 @@ type Builder struct { prefix []string regex []string path []string + match MatchFunc ms []middleware.Middleware } @@ -64,6 +65,12 @@ func (b *Builder) Path(path ...string) *Builder { return b } +// Match is with Builder's match +func (b *Builder) Match(fn MatchFunc) *Builder { + b.match = fn + return b +} + // Build is Builder's Build, for example: Server().Path(m1,m2).Build() func (b *Builder) Build() middleware.Middleware { var transporter func(ctx context.Context) (transport.Transporter, bool) @@ -72,11 +79,11 @@ func (b *Builder) Build() middleware.Middleware { } else { transporter = serverTransporter } - return selector(transporter, b.match, b.ms...) + return selector(transporter, b.matchs, b.ms...) } -// match is match operation compliance Builder -func (b *Builder) match(operation string) bool { +// matchs is match operation compliance Builder +func (b *Builder) matchs(operation string) bool { for _, prefix := range b.prefix { if prefixMatch(prefix, operation) { return true @@ -92,11 +99,15 @@ func (b *Builder) match(operation string) bool { return true } } + + if b.match != nil { + return b.match(operation) + } return false } // selector middleware -func selector(transporter transporter, match match, ms ...middleware.Middleware) middleware.Middleware { +func selector(transporter transporter, match MatchFunc, ms ...middleware.Middleware) middleware.Middleware { return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (reply interface{}, err error) { info, ok := transporter(ctx) diff --git a/middleware/selector/selector_test.go b/middleware/selector/selector_test.go index 4554e295a..aec1a45aa 100644 --- a/middleware/selector/selector_test.go +++ b/middleware/selector/selector_test.go @@ -3,10 +3,12 @@ package selector import ( "context" "fmt" + "strings" "testing" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" + "github.com/stretchr/testify/assert" ) var _ transport.Transporter = &Transport{} @@ -109,6 +111,47 @@ func TestMatchClient(t *testing.T) { } } +func TestFunc(t *testing.T) { + tests := []struct { + name string + ctx context.Context + }{ + { + name: "/hello.Update/world", + ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/hello.Update/world"}), + }, + { + name: "/hi.Create/world", + ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/hi.Create/world"}), + }, + { + name: "/test.Name/1234", + ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/test.Name/1234"}), + }, + { + name: "/go-kratos.dev/kratos", + ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/go-kratos.dev/kratos"}), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + next := func(ctx context.Context, req interface{}) (interface{}, error) { + t.Log(req) + return "reply", nil + } + next = Server(testMiddleware).Match(func(operation string) bool { + if strings.HasPrefix(operation, "/go-kratos.dev") || strings.HasSuffix(operation, "world") { + return true + } + return false + }).Build()(next) + reply, err := next(test.ctx, test.name) + assert.Equal(t, reply, "reply") + assert.Nil(t, err) + }) + } +} + func testMiddleware(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (reply interface{}, err error) { fmt.Println("before")