feat(middleware): add selector matcher (#2239)

* feat(middleware): add selector matcher


Co-authored-by: chenzhihui <chenzhihui@bilibili.com>
pull/2240/head
Tony Chen 2 years ago committed by GitHub
parent 377356d04d
commit f3b0da3f04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 62
      internal/matcher/middleware.go
  2. 62
      internal/matcher/middleware_test.go
  3. 4
      transport/grpc/interceptor.go
  4. 15
      transport/grpc/server.go
  5. 15
      transport/grpc/server_test.go
  6. 6
      transport/http/context.go
  7. 17
      transport/http/server.go
  8. 12
      transport/http/server_test.go

@ -0,0 +1,62 @@
package matcher
import (
"sort"
"strings"
"github.com/go-kratos/kratos/v2/middleware"
)
// Matcher is a middleware matcher.
type Matcher interface {
Use(ms ...middleware.Middleware)
Add(selector string, ms ...middleware.Middleware)
Match(operation string) []middleware.Middleware
}
// New new a middleware matcher.
func New() Matcher {
return &matcher{
matchs: make(map[string][]middleware.Middleware),
}
}
type matcher struct {
prefix []string
defaults []middleware.Middleware
matchs map[string][]middleware.Middleware
}
func (m *matcher) Use(ms ...middleware.Middleware) {
m.defaults = ms
}
func (m *matcher) Add(selector string, ms ...middleware.Middleware) {
if strings.HasSuffix(selector, "*") {
selector = strings.TrimSuffix(selector, "*")
m.prefix = append(m.prefix, selector)
// sort the prefix:
// - /foo/bar
// - /foo
sort.Slice(m.prefix, func(i, j int) bool {
return m.prefix[i] > m.prefix[j]
})
}
m.matchs[selector] = ms
}
func (m *matcher) Match(operation string) []middleware.Middleware {
ms := make([]middleware.Middleware, 0, len(m.defaults))
if len(m.defaults) > 0 {
ms = append(ms, m.defaults...)
}
if next, ok := m.matchs[operation]; ok {
return append(ms, next...)
}
for _, prefix := range m.prefix {
if strings.HasPrefix(operation, prefix) {
return append(ms, m.matchs[prefix]...)
}
}
return ms
}

@ -0,0 +1,62 @@
package matcher
import (
"context"
"testing"
"github.com/go-kratos/kratos/v2/middleware"
)
func logging(module string) middleware.Middleware {
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
return module, nil
}
}
}
func equal(ms []middleware.Middleware, modules ...string) bool {
if len(ms) == 0 {
return false
}
for i, m := range ms {
x, _ := m(nil)(nil, nil)
if x != modules[i] {
return false
}
}
return true
}
func TestMatcher(t *testing.T) {
m := New()
m.Use(logging("logging"))
m.Add("*", logging("*"))
m.Add("/foo/*", logging("foo/*"))
m.Add("/foo/bar/*", logging("foo/bar/*"))
m.Add("/foo/bar", logging("foo/bar"))
if ms := m.Match("/"); len(ms) != 2 {
t.Fatal("not equal")
} else if !equal(ms, "logging", "*") {
t.Fatal("not equal")
}
if ms := m.Match("/foo/xxx"); len(ms) != 2 {
t.Fatal("not equal")
} else if !equal(ms, "logging", "foo/*") {
t.Fatal("not equal")
}
if ms := m.Match("/foo/bar"); len(ms) != 2 {
t.Fatal("not equal")
} else if !equal(ms, "logging", "foo/bar") {
t.Fatal("not equal")
}
if ms := m.Match("/foo/bar/x"); len(ms) != 2 {
t.Fatal("not equal")
} else if !equal(ms, "logging", "foo/bar/*") {
t.Fatal("not equal")
}
}

@ -33,8 +33,8 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
h := func(ctx context.Context, req interface{}) (interface{}, error) { h := func(ctx context.Context, req interface{}) (interface{}, error) {
return handler(ctx, req) return handler(ctx, req)
} }
if len(s.middleware) > 0 { if next := s.middleware.Match(tr.Operation()); len(next) > 0 {
h = middleware.Chain(s.middleware...)(h) h = middleware.Chain(next...)(h)
} }
reply, err := h(ctx, req) reply, err := h(ctx, req)
if len(replyHeader) > 0 { if len(replyHeader) > 0 {

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/go-kratos/kratos/v2/internal/endpoint" "github.com/go-kratos/kratos/v2/internal/endpoint"
"github.com/go-kratos/kratos/v2/internal/matcher"
apimd "github.com/go-kratos/kratos/v2/api/metadata" apimd "github.com/go-kratos/kratos/v2/api/metadata"
@ -62,7 +63,7 @@ func Logger(logger log.Logger) ServerOption {
// Middleware with server middleware. // Middleware with server middleware.
func Middleware(m ...middleware.Middleware) ServerOption { func Middleware(m ...middleware.Middleware) ServerOption {
return func(s *Server) { return func(s *Server) {
s.middleware = m s.middleware.Use(m...)
} }
} }
@ -112,7 +113,7 @@ type Server struct {
address string address string
endpoint *url.URL endpoint *url.URL
timeout time.Duration timeout time.Duration
middleware []middleware.Middleware middleware matcher.Matcher
unaryInts []grpc.UnaryServerInterceptor unaryInts []grpc.UnaryServerInterceptor
streamInts []grpc.StreamServerInterceptor streamInts []grpc.StreamServerInterceptor
grpcOpts []grpc.ServerOption grpcOpts []grpc.ServerOption
@ -128,6 +129,7 @@ func NewServer(opts ...ServerOption) *Server {
address: ":0", address: ":0",
timeout: 1 * time.Second, timeout: 1 * time.Second,
health: health.NewServer(), health: health.NewServer(),
middleware: matcher.New(),
} }
for _, o := range opts { for _, o := range opts {
o(srv) o(srv)
@ -163,6 +165,15 @@ func NewServer(opts ...ServerOption) *Server {
return srv return srv
} }
// Use uses a service middleware with selector.
// selector:
// - '/*'
// - '/helloworld.v1.Greeter/*'
// - '/helloworld.v1.Greeter/SayHello'
func (s *Server) Use(selector string, m ...middleware.Middleware) {
s.middleware.Add(selector, m...)
}
// Endpoint return a real address to registry endpoint. // Endpoint return a real address to registry endpoint.
// examples: // examples:
// grpc://127.0.0.1:9000?isSecure=false // grpc://127.0.0.1:9000?isSecure=false

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/internal/matcher"
pb "github.com/go-kratos/kratos/v2/internal/testdata/helloworld" pb "github.com/go-kratos/kratos/v2/internal/testdata/helloworld"
"github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport"
@ -198,17 +199,6 @@ func TestTimeout(t *testing.T) {
} }
} }
func TestMiddleware(t *testing.T) {
o := &Server{}
v := []middleware.Middleware{
func(middleware.Handler) middleware.Handler { return nil },
}
Middleware(v...)(o)
if !reflect.DeepEqual(v, o.middleware) {
t.Errorf("expect %v, got %v", v, o.middleware)
}
}
func TestTLSConfig(t *testing.T) { func TestTLSConfig(t *testing.T) {
o := &Server{} o := &Server{}
v := &tls.Config{} v := &tls.Config{}
@ -273,9 +263,10 @@ func TestServer_unaryServerInterceptor(t *testing.T) {
srv := &Server{ srv := &Server{
baseCtx: context.Background(), baseCtx: context.Background(),
endpoint: u, endpoint: u,
middleware: []middleware.Middleware{EmptyMiddleware()},
timeout: time.Duration(10), timeout: time.Duration(10),
middleware: matcher.New(),
} }
srv.middleware.Use(EmptyMiddleware())
req := &struct{}{} req := &struct{}{}
rv, err := srv.unaryServerInterceptor()(context.TODO(), req, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (i interface{}, e error) { rv, err := srv.unaryServerInterceptor()(context.TODO(), req, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (i interface{}, e error) {
return &testResp{Data: "hi"}, nil return &testResp{Data: "hi"}, nil

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/http/binding" "github.com/go-kratos/kratos/v2/transport/http/binding"
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
@ -89,7 +90,10 @@ func (c *wrapper) Query() url.Values {
func (c *wrapper) Request() *http.Request { return c.req } func (c *wrapper) Request() *http.Request { return c.req }
func (c *wrapper) Response() http.ResponseWriter { return c.res } func (c *wrapper) Response() http.ResponseWriter { return c.res }
func (c *wrapper) Middleware(h middleware.Handler) middleware.Handler { func (c *wrapper) Middleware(h middleware.Handler) middleware.Handler {
return middleware.Chain(c.router.srv.ms...)(h) if tr, ok := transport.FromServerContext(c.req.Context()); ok {
return middleware.Chain(c.router.srv.middleware.Match(tr.Operation())...)(h)
}
return middleware.Chain(c.router.srv.middleware.Match(c.req.URL.Path)...)(h)
} }
func (c *wrapper) Bind(v interface{}) error { return c.router.srv.dec(c.req, v) } func (c *wrapper) Bind(v interface{}) error { return c.router.srv.dec(c.req, v) }
func (c *wrapper) BindVars(v interface{}) error { return binding.BindQuery(c.Vars(), v) } func (c *wrapper) BindVars(v interface{}) error { return binding.BindQuery(c.Vars(), v) }

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/go-kratos/kratos/v2/internal/endpoint" "github.com/go-kratos/kratos/v2/internal/endpoint"
"github.com/go-kratos/kratos/v2/internal/matcher"
"github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/internal/host"
"github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/log"
@ -58,7 +59,7 @@ func Logger(logger log.Logger) ServerOption {
// Middleware with service middleware option. // Middleware with service middleware option.
func Middleware(m ...middleware.Middleware) ServerOption { func Middleware(m ...middleware.Middleware) ServerOption {
return func(o *Server) { return func(o *Server) {
o.ms = m o.middleware.Use(m...)
} }
} }
@ -124,7 +125,7 @@ type Server struct {
address string address string
timeout time.Duration timeout time.Duration
filters []FilterFunc filters []FilterFunc
ms []middleware.Middleware middleware matcher.Matcher
dec DecodeRequestFunc dec DecodeRequestFunc
enc EncodeResponseFunc enc EncodeResponseFunc
ene EncodeErrorFunc ene EncodeErrorFunc
@ -138,6 +139,7 @@ func NewServer(opts ...ServerOption) *Server {
network: "tcp", network: "tcp",
address: ":0", address: ":0",
timeout: 1 * time.Second, timeout: 1 * time.Second,
middleware: matcher.New(),
dec: DefaultRequestDecoder, dec: DefaultRequestDecoder,
enc: DefaultResponseEncoder, enc: DefaultResponseEncoder,
ene: DefaultErrorEncoder, ene: DefaultErrorEncoder,
@ -157,6 +159,15 @@ func NewServer(opts ...ServerOption) *Server {
return srv return srv
} }
// Use uses a service middleware with selector.
// selector:
// - '/*'
// - '/helloworld.v1.Greeter/*'
// - '/helloworld.v1.Greeter/SayHello'
func (s *Server) Use(selector string, m ...middleware.Middleware) {
s.middleware.Add(selector, m...)
}
// WalkRoute walks the router and all its sub-routers, calling walkFn for each route in the tree. // WalkRoute walks the router and all its sub-routers, calling walkFn for each route in the tree.
func (s *Server) WalkRoute(fn WalkRouteFunc) error { func (s *Server) WalkRoute(fn WalkRouteFunc) error {
return s.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { return s.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
@ -229,10 +240,10 @@ func (s *Server) filter() mux.MiddlewareFunc {
tr := &Transport{ tr := &Transport{
operation: pathTemplate, operation: pathTemplate,
pathTemplate: pathTemplate,
reqHeader: headerCarrier(req.Header), reqHeader: headerCarrier(req.Header),
replyHeader: headerCarrier(w.Header()), replyHeader: headerCarrier(w.Header()),
request: req, request: req,
pathTemplate: pathTemplate,
} }
if s.endpoint != nil { if s.endpoint != nil {
tr.endpoint = s.endpoint.String() tr.endpoint = s.endpoint.String()

@ -14,7 +14,6 @@ import (
"time" "time"
"github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/internal/host"
) )
@ -313,17 +312,6 @@ func TestLogger(t *testing.T) {
// todo // todo
} }
func TestMiddleware(t *testing.T) {
o := &Server{}
v := []middleware.Middleware{
func(middleware.Handler) middleware.Handler { return nil },
}
Middleware(v...)(o)
if !reflect.DeepEqual(v, o.ms) {
t.Errorf("expected %v got %v", v, o.ms)
}
}
func TestRequestDecoder(t *testing.T) { func TestRequestDecoder(t *testing.T) {
o := &Server{} o := &Server{}
v := func(*http.Request, interface{}) error { return nil } v := func(*http.Request, interface{}) error { return nil }

Loading…
Cancel
Save