From ef6e52d1bab82fb43dc8edfe8dc05936faa6eafc Mon Sep 17 00:00:00 2001 From: Tony Chen Date: Thu, 20 May 2021 23:30:50 +0800 Subject: [PATCH] add multiple middlewares (#936) --- examples/blog/internal/server/grpc.go | 9 +++------ examples/blog/internal/server/http.go | 9 +++------ examples/helloworld/client/main.go | 4 +--- examples/helloworld/server/main.go | 13 ++++--------- middleware/middleware.go | 8 ++++---- transport/grpc/client.go | 6 ++++-- transport/grpc/server.go | 4 ++-- transport/http/client.go | 4 ++-- transport/http/handle.go | 4 ++-- 9 files changed, 25 insertions(+), 36 deletions(-) diff --git a/examples/blog/internal/server/grpc.go b/examples/blog/internal/server/grpc.go index 17a1ba6fb..9d69e0fb3 100644 --- a/examples/blog/internal/server/grpc.go +++ b/examples/blog/internal/server/grpc.go @@ -5,7 +5,6 @@ import ( "github.com/go-kratos/kratos/examples/blog/internal/conf" "github.com/go-kratos/kratos/examples/blog/internal/service" "github.com/go-kratos/kratos/v2/log" - "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware/logging" "github.com/go-kratos/kratos/v2/middleware/recovery" "github.com/go-kratos/kratos/v2/middleware/tracing" @@ -17,11 +16,9 @@ import ( func NewGRPCServer(c *conf.Server, tracer trace.TracerProvider, blog *service.BlogService) *grpc.Server { var opts = []grpc.ServerOption{ grpc.Middleware( - middleware.Chain( - tracing.Server(tracing.WithTracerProvider(tracer)), - logging.Server(log.DefaultLogger), - recovery.Recovery(), - ), + tracing.Server(tracing.WithTracerProvider(tracer)), + logging.Server(log.DefaultLogger), + recovery.Recovery(), ), } if c.Grpc.Network != "" { diff --git a/examples/blog/internal/server/http.go b/examples/blog/internal/server/http.go index c991bab1c..c0f44acec 100644 --- a/examples/blog/internal/server/http.go +++ b/examples/blog/internal/server/http.go @@ -5,7 +5,6 @@ import ( "github.com/go-kratos/kratos/examples/blog/internal/conf" "github.com/go-kratos/kratos/examples/blog/internal/service" "github.com/go-kratos/kratos/v2/log" - "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware/logging" "github.com/go-kratos/kratos/v2/middleware/recovery" "github.com/go-kratos/kratos/v2/middleware/tracing" @@ -26,11 +25,9 @@ func NewHTTPServer(c *conf.Server, tracer trace.TracerProvider, blog *service.Bl opts = append(opts, http.Timeout(c.Http.Timeout.AsDuration())) } m := http.Middleware( - middleware.Chain( - tracing.Server(tracing.WithTracerProvider(tracer)), - logging.Server(log.DefaultLogger), - recovery.Recovery(), - ), + tracing.Server(tracing.WithTracerProvider(tracer)), + logging.Server(log.DefaultLogger), + recovery.Recovery(), ) srv := http.NewServer(opts...) srv.HandlePrefix("/", v1.NewBlogServiceHandler(blog, m)) diff --git a/examples/helloworld/client/main.go b/examples/helloworld/client/main.go index 6b23eec44..a86219550 100644 --- a/examples/helloworld/client/main.go +++ b/examples/helloworld/client/main.go @@ -22,9 +22,7 @@ func callHTTP() { client, err := transhttp.NewClient( context.Background(), transhttp.WithMiddleware( - middleware.Chain( - recovery.Recovery(), - ), + recovery.Recovery(), ), ) if err != nil { diff --git a/examples/helloworld/server/main.go b/examples/helloworld/server/main.go index e957251e2..bd5fdca18 100644 --- a/examples/helloworld/server/main.go +++ b/examples/helloworld/server/main.go @@ -9,7 +9,6 @@ import ( "github.com/go-kratos/kratos/v2" "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/log" - "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware/logging" "github.com/go-kratos/kratos/v2/middleware/recovery" "github.com/go-kratos/kratos/v2/transport/grpc" @@ -48,10 +47,8 @@ func main() { grpcSrv := grpc.NewServer( grpc.Address(":9000"), grpc.Middleware( - middleware.Chain( - logging.Server(logger), - recovery.Recovery(), - ), + recovery.Recovery(), + logging.Server(logger), )) s := &server{} @@ -60,10 +57,8 @@ func main() { httpSrv := http.NewServer(http.Address(":8000")) httpSrv.HandlePrefix("/", pb.NewGreeterHandler(s, http.Middleware( - middleware.Chain( - logging.Server(logger), - recovery.Recovery(), - ), + recovery.Recovery(), + logging.Server(logger), )), ) diff --git a/middleware/middleware.go b/middleware/middleware.go index 8d687c0f7..8a514ad97 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -11,11 +11,11 @@ type Handler func(ctx context.Context, req interface{}) (interface{}, error) type Middleware func(Handler) Handler // Chain returns a Middleware that specifies the chained handler for endpoint. -func Chain(outer Middleware, others ...Middleware) Middleware { +func Chain(m ...Middleware) Middleware { return func(next Handler) Handler { - for i := len(others) - 1; i >= 0; i-- { - next = others[i](next) + for i := len(m) - 1; i >= 0; i-- { + next = m[i](next) } - return outer(next) + return next } } diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 3767a19fd..6b062cd38 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -9,6 +9,8 @@ import ( "github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport/grpc/resolver/discovery" + + // init resolver _ "github.com/go-kratos/kratos/v2/transport/grpc/resolver/direct" "google.golang.org/grpc" @@ -33,9 +35,9 @@ func WithTimeout(timeout time.Duration) ClientOption { } // WithMiddleware with client middleware. -func WithMiddleware(m middleware.Middleware) ClientOption { +func WithMiddleware(m ...middleware.Middleware) ClientOption { return func(o *clientOptions) { - o.middleware = m + o.middleware = middleware.Chain(m...) } } diff --git a/transport/grpc/server.go b/transport/grpc/server.go index a7ac5eedc..6f0b95631 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -52,9 +52,9 @@ func Logger(logger log.Logger) ServerOption { } // Middleware with server middleware. -func Middleware(m middleware.Middleware) ServerOption { +func Middleware(m ...middleware.Middleware) ServerOption { return func(s *Server) { - s.middleware = m + s.middleware = middleware.Chain(m...) } } diff --git a/transport/http/client.go b/transport/http/client.go index 519008402..e60c8374b 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -43,9 +43,9 @@ func WithTransport(trans http.RoundTripper) ClientOption { } // WithMiddleware with client middleware. -func WithMiddleware(m middleware.Middleware) ClientOption { +func WithMiddleware(m ...middleware.Middleware) ClientOption { return func(o *clientOptions) { - o.middleware = m + o.middleware = middleware.Chain(m...) } } diff --git a/transport/http/handle.go b/transport/http/handle.go index b02782c87..16c69b89e 100644 --- a/transport/http/handle.go +++ b/transport/http/handle.go @@ -67,9 +67,9 @@ func ErrorEncoder(en EncodeErrorFunc) HandleOption { } // Middleware with middleware option. -func Middleware(m middleware.Middleware) HandleOption { +func Middleware(m ...middleware.Middleware) HandleOption { return func(o *HandleOptions) { - o.Middleware = m + o.Middleware = middleware.Chain(m...) } }