diff --git a/internal/matcher/middleware.go b/internal/matcher/middleware.go new file mode 100644 index 000000000..8d5681820 --- /dev/null +++ b/internal/matcher/middleware.go @@ -0,0 +1,62 @@ +package matcher + +import ( + "sort" + "strings" + + "github.com/go-kratos/kratos/v2/middleware" +) + +// Matcher is a middleware matcher. +type Matcher interface { + Use(ms ...middleware.Middleware) + Add(selector string, ms ...middleware.Middleware) + Match(operation string) []middleware.Middleware +} + +// New new a middleware matcher. +func New() Matcher { + return &matcher{ + matchs: make(map[string][]middleware.Middleware), + } +} + +type matcher struct { + prefix []string + defaults []middleware.Middleware + matchs map[string][]middleware.Middleware +} + +func (m *matcher) Use(ms ...middleware.Middleware) { + m.defaults = ms +} + +func (m *matcher) Add(selector string, ms ...middleware.Middleware) { + if strings.HasSuffix(selector, "*") { + selector = strings.TrimSuffix(selector, "*") + m.prefix = append(m.prefix, selector) + // sort the prefix: + // - /foo/bar + // - /foo + sort.Slice(m.prefix, func(i, j int) bool { + return m.prefix[i] > m.prefix[j] + }) + } + m.matchs[selector] = ms +} + +func (m *matcher) Match(operation string) []middleware.Middleware { + ms := make([]middleware.Middleware, 0, len(m.defaults)) + if len(m.defaults) > 0 { + ms = append(ms, m.defaults...) + } + if next, ok := m.matchs[operation]; ok { + return append(ms, next...) + } + for _, prefix := range m.prefix { + if strings.HasPrefix(operation, prefix) { + return append(ms, m.matchs[prefix]...) + } + } + return ms +} diff --git a/internal/matcher/middleware_test.go b/internal/matcher/middleware_test.go new file mode 100644 index 000000000..9c948aebc --- /dev/null +++ b/internal/matcher/middleware_test.go @@ -0,0 +1,62 @@ +package matcher + +import ( + "context" + "testing" + + "github.com/go-kratos/kratos/v2/middleware" +) + +func logging(module string) middleware.Middleware { + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (reply interface{}, err error) { + return module, nil + } + } +} + +func equal(ms []middleware.Middleware, modules ...string) bool { + if len(ms) == 0 { + return false + } + for i, m := range ms { + x, _ := m(nil)(nil, nil) + if x != modules[i] { + return false + } + } + return true +} + +func TestMatcher(t *testing.T) { + m := New() + m.Use(logging("logging")) + m.Add("*", logging("*")) + m.Add("/foo/*", logging("foo/*")) + m.Add("/foo/bar/*", logging("foo/bar/*")) + m.Add("/foo/bar", logging("foo/bar")) + + if ms := m.Match("/"); len(ms) != 2 { + t.Fatal("not equal") + } else if !equal(ms, "logging", "*") { + t.Fatal("not equal") + } + + if ms := m.Match("/foo/xxx"); len(ms) != 2 { + t.Fatal("not equal") + } else if !equal(ms, "logging", "foo/*") { + t.Fatal("not equal") + } + + if ms := m.Match("/foo/bar"); len(ms) != 2 { + t.Fatal("not equal") + } else if !equal(ms, "logging", "foo/bar") { + t.Fatal("not equal") + } + + if ms := m.Match("/foo/bar/x"); len(ms) != 2 { + t.Fatal("not equal") + } else if !equal(ms, "logging", "foo/bar/*") { + t.Fatal("not equal") + } +} diff --git a/transport/grpc/interceptor.go b/transport/grpc/interceptor.go index 26daedfb8..ccd83ea69 100644 --- a/transport/grpc/interceptor.go +++ b/transport/grpc/interceptor.go @@ -33,8 +33,8 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { h := func(ctx context.Context, req interface{}) (interface{}, error) { return handler(ctx, req) } - if len(s.middleware) > 0 { - h = middleware.Chain(s.middleware...)(h) + if next := s.middleware.Match(tr.Operation()); len(next) > 0 { + h = middleware.Chain(next...)(h) } reply, err := h(ctx, req) if len(replyHeader) > 0 { diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 2b5bde745..b66ba932d 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -8,6 +8,7 @@ import ( "time" "github.com/go-kratos/kratos/v2/internal/endpoint" + "github.com/go-kratos/kratos/v2/internal/matcher" apimd "github.com/go-kratos/kratos/v2/api/metadata" @@ -62,7 +63,7 @@ func Logger(logger log.Logger) ServerOption { // Middleware with server middleware. func Middleware(m ...middleware.Middleware) ServerOption { return func(s *Server) { - s.middleware = m + s.middleware.Use(m...) } } @@ -112,7 +113,7 @@ type Server struct { address string endpoint *url.URL timeout time.Duration - middleware []middleware.Middleware + middleware matcher.Matcher unaryInts []grpc.UnaryServerInterceptor streamInts []grpc.StreamServerInterceptor grpcOpts []grpc.ServerOption @@ -123,11 +124,12 @@ type Server struct { // NewServer creates a gRPC server by options. func NewServer(opts ...ServerOption) *Server { srv := &Server{ - baseCtx: context.Background(), - network: "tcp", - address: ":0", - timeout: 1 * time.Second, - health: health.NewServer(), + baseCtx: context.Background(), + network: "tcp", + address: ":0", + timeout: 1 * time.Second, + health: health.NewServer(), + middleware: matcher.New(), } for _, o := range opts { o(srv) @@ -163,6 +165,15 @@ func NewServer(opts ...ServerOption) *Server { return srv } +// Use uses a service middleware with selector. +// selector: +// - '/*' +// - '/helloworld.v1.Greeter/*' +// - '/helloworld.v1.Greeter/SayHello' +func (s *Server) Use(selector string, m ...middleware.Middleware) { + s.middleware.Add(selector, m...) +} + // Endpoint return a real address to registry endpoint. // examples: // grpc://127.0.0.1:9000?isSecure=false diff --git a/transport/grpc/server_test.go b/transport/grpc/server_test.go index 368c32953..fd1ae6757 100644 --- a/transport/grpc/server_test.go +++ b/transport/grpc/server_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/internal/matcher" pb "github.com/go-kratos/kratos/v2/internal/testdata/helloworld" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" @@ -198,17 +199,6 @@ func TestTimeout(t *testing.T) { } } -func TestMiddleware(t *testing.T) { - o := &Server{} - v := []middleware.Middleware{ - func(middleware.Handler) middleware.Handler { return nil }, - } - Middleware(v...)(o) - if !reflect.DeepEqual(v, o.middleware) { - t.Errorf("expect %v, got %v", v, o.middleware) - } -} - func TestTLSConfig(t *testing.T) { o := &Server{} v := &tls.Config{} @@ -273,9 +263,10 @@ func TestServer_unaryServerInterceptor(t *testing.T) { srv := &Server{ baseCtx: context.Background(), endpoint: u, - middleware: []middleware.Middleware{EmptyMiddleware()}, timeout: time.Duration(10), + middleware: matcher.New(), } + srv.middleware.Use(EmptyMiddleware()) req := &struct{}{} rv, err := srv.unaryServerInterceptor()(context.TODO(), req, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (i interface{}, e error) { return &testResp{Data: "hi"}, nil diff --git a/transport/http/context.go b/transport/http/context.go index bc54134f5..347d93ffe 100644 --- a/transport/http/context.go +++ b/transport/http/context.go @@ -10,6 +10,7 @@ import ( "time" "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport/http/binding" "github.com/gorilla/mux" ) @@ -89,7 +90,10 @@ func (c *wrapper) Query() url.Values { func (c *wrapper) Request() *http.Request { return c.req } func (c *wrapper) Response() http.ResponseWriter { return c.res } func (c *wrapper) Middleware(h middleware.Handler) middleware.Handler { - return middleware.Chain(c.router.srv.ms...)(h) + if tr, ok := transport.FromServerContext(c.req.Context()); ok { + return middleware.Chain(c.router.srv.middleware.Match(tr.Operation())...)(h) + } + return middleware.Chain(c.router.srv.middleware.Match(c.req.URL.Path)...)(h) } func (c *wrapper) Bind(v interface{}) error { return c.router.srv.dec(c.req, v) } func (c *wrapper) BindVars(v interface{}) error { return binding.BindQuery(c.Vars(), v) } diff --git a/transport/http/server.go b/transport/http/server.go index 9bb9ffea5..4767e79b4 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -10,6 +10,7 @@ import ( "time" "github.com/go-kratos/kratos/v2/internal/endpoint" + "github.com/go-kratos/kratos/v2/internal/matcher" "github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/log" @@ -58,7 +59,7 @@ func Logger(logger log.Logger) ServerOption { // Middleware with service middleware option. func Middleware(m ...middleware.Middleware) ServerOption { return func(o *Server) { - o.ms = m + o.middleware.Use(m...) } } @@ -124,7 +125,7 @@ type Server struct { address string timeout time.Duration filters []FilterFunc - ms []middleware.Middleware + middleware matcher.Matcher dec DecodeRequestFunc enc EncodeResponseFunc ene EncodeErrorFunc @@ -138,6 +139,7 @@ func NewServer(opts ...ServerOption) *Server { network: "tcp", address: ":0", timeout: 1 * time.Second, + middleware: matcher.New(), dec: DefaultRequestDecoder, enc: DefaultResponseEncoder, ene: DefaultErrorEncoder, @@ -157,6 +159,15 @@ func NewServer(opts ...ServerOption) *Server { return srv } +// Use uses a service middleware with selector. +// selector: +// - '/*' +// - '/helloworld.v1.Greeter/*' +// - '/helloworld.v1.Greeter/SayHello' +func (s *Server) Use(selector string, m ...middleware.Middleware) { + s.middleware.Add(selector, m...) +} + // WalkRoute walks the router and all its sub-routers, calling walkFn for each route in the tree. func (s *Server) WalkRoute(fn WalkRouteFunc) error { return s.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { @@ -229,10 +240,10 @@ func (s *Server) filter() mux.MiddlewareFunc { tr := &Transport{ operation: pathTemplate, + pathTemplate: pathTemplate, reqHeader: headerCarrier(req.Header), replyHeader: headerCarrier(w.Header()), request: req, - pathTemplate: pathTemplate, } if s.endpoint != nil { tr.endpoint = s.endpoint.String() diff --git a/transport/http/server_test.go b/transport/http/server_test.go index a6f05080a..05bc8399e 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -14,7 +14,6 @@ import ( "time" "github.com/go-kratos/kratos/v2/errors" - "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/internal/host" ) @@ -313,17 +312,6 @@ func TestLogger(t *testing.T) { // todo } -func TestMiddleware(t *testing.T) { - o := &Server{} - v := []middleware.Middleware{ - func(middleware.Handler) middleware.Handler { return nil }, - } - Middleware(v...)(o) - if !reflect.DeepEqual(v, o.ms) { - t.Errorf("expected %v got %v", v, o.ms) - } -} - func TestRequestDecoder(t *testing.T) { o := &Server{} v := func(*http.Request, interface{}) error { return nil }