From c7827ff701538e00cade553d1351d714e9a2491f Mon Sep 17 00:00:00 2001 From: Tony Chen Date: Thu, 18 Feb 2021 16:53:59 +0800 Subject: [PATCH] fix http middleware (#710) * fix http middleware --- cmd/protoc-gen-go-http/go.mod | 2 +- cmd/protoc-gen-go-http/go.sum | 2 + cmd/protoc-gen-go-http/http.go | 9 ++- .../testproto/echo_service_grpc.pb.go | 8 ++- .../testproto/echo_service_http.pb.go | 59 ++++++------------- cmd/protoc-gen-go-http/template.go | 7 +-- transport/http/client.go | 2 +- transport/http/server.go | 10 +++- transport/http/service.go | 6 +- transport/http/service_test.go | 8 +-- 10 files changed, 46 insertions(+), 67 deletions(-) diff --git a/cmd/protoc-gen-go-http/go.mod b/cmd/protoc-gen-go-http/go.mod index 7e51ad5b4..3f2d1a909 100644 --- a/cmd/protoc-gen-go-http/go.mod +++ b/cmd/protoc-gen-go-http/go.mod @@ -3,7 +3,7 @@ module github.com/go-kratos/kratos/cmd/protoc-gen-go-http/v2 go 1.15 require ( - github.com/go-kratos/kratos/v2 v2.0.0-20210217083752-d86d233d93ce + github.com/go-kratos/kratos/v2 v2.0.0-20210218084408-cf599c68a65f github.com/golang/protobuf v1.4.3 google.golang.org/genproto v0.0.0-20210202153253-cf70463f6119 google.golang.org/grpc v1.35.0 diff --git a/cmd/protoc-gen-go-http/go.sum b/cmd/protoc-gen-go-http/go.sum index 9c8241d2b..9eecccc2e 100644 --- a/cmd/protoc-gen-go-http/go.sum +++ b/cmd/protoc-gen-go-http/go.sum @@ -11,6 +11,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7 github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/go-kratos/kratos/v2 v2.0.0-20210217083752-d86d233d93ce h1:LfOsLN9s8tAxR8xIZGWQvEVWxHfipTnBSE0dvG4h3k8= github.com/go-kratos/kratos/v2 v2.0.0-20210217083752-d86d233d93ce/go.mod h1:oLvFyDBJkkWN8TPqb+NmpvRrSy9uM/K+XQubVRc11a8= +github.com/go-kratos/kratos/v2 v2.0.0-20210218084408-cf599c68a65f h1:ocHPvNS53zBT6NiGmgOWV2SuAmcNyHbJZk1t2cqEjIU= +github.com/go-kratos/kratos/v2 v2.0.0-20210218084408-cf599c68a65f/go.mod h1:oLvFyDBJkkWN8TPqb+NmpvRrSy9uM/K+XQubVRc11a8= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/cmd/protoc-gen-go-http/http.go b/cmd/protoc-gen-go-http/http.go index d03f75306..22a4e69fd 100644 --- a/cmd/protoc-gen-go-http/http.go +++ b/cmd/protoc-gen-go-http/http.go @@ -11,10 +11,9 @@ import ( ) const ( - contextPackage = protogen.GoImportPath("context") - httpPackage = protogen.GoImportPath("net/http") - transportPackage = protogen.GoImportPath("github.com/go-kratos/kratos/v2/transport/http") - middlewarePackage = protogen.GoImportPath("github.com/go-kratos/kratos/v2/middleware") + contextPackage = protogen.GoImportPath("context") + httpPackage = protogen.GoImportPath("net/http") + transportPackage = protogen.GoImportPath("github.com/go-kratos/kratos/v2/transport/http") ) var methodSets = make(map[string]int) @@ -41,7 +40,7 @@ func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen. } g.P("// This is a compile-time assertion to ensure that this generated file") g.P("// is compatible with the kratos package it is being compiled against.") - g.P("// ", contextPackage.Ident(""), "/", httpPackage.Ident(""), "/", middlewarePackage.Ident("")) + g.P("// ", contextPackage.Ident(""), "/", httpPackage.Ident("")) g.P("const _ = ", transportPackage.Ident("SupportPackageIsVersion1")) g.P() diff --git a/cmd/protoc-gen-go-http/internal/testproto/echo_service_grpc.pb.go b/cmd/protoc-gen-go-http/internal/testproto/echo_service_grpc.pb.go index 51c231a60..ae03788d4 100644 --- a/cmd/protoc-gen-go-http/internal/testproto/echo_service_grpc.pb.go +++ b/cmd/protoc-gen-go-http/internal/testproto/echo_service_grpc.pb.go @@ -11,6 +11,7 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. const _ = grpc.SupportPackageIsVersion7 // EchoServiceClient is the client API for EchoService service. @@ -118,7 +119,7 @@ type UnsafeEchoServiceServer interface { } func RegisterEchoServiceServer(s grpc.ServiceRegistrar, srv EchoServiceServer) { - s.RegisterService(&_EchoService_serviceDesc, srv) + s.RegisterService(&EchoService_ServiceDesc, srv) } func _EchoService_Echo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { @@ -193,7 +194,10 @@ func _EchoService_EchoPatch_Handler(srv interface{}, ctx context.Context, dec fu return interceptor(ctx, in, info, handler) } -var _EchoService_serviceDesc = grpc.ServiceDesc{ +// EchoService_ServiceDesc is the grpc.ServiceDesc for EchoService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var EchoService_ServiceDesc = grpc.ServiceDesc{ ServiceName: "testproto.EchoService", HandlerType: (*EchoServiceServer)(nil), Methods: []grpc.MethodDesc{ diff --git a/cmd/protoc-gen-go-http/internal/testproto/echo_service_http.pb.go b/cmd/protoc-gen-go-http/internal/testproto/echo_service_http.pb.go index dfb5932f1..ccdae48b2 100644 --- a/cmd/protoc-gen-go-http/internal/testproto/echo_service_http.pb.go +++ b/cmd/protoc-gen-go-http/internal/testproto/echo_service_http.pb.go @@ -4,14 +4,13 @@ package testproto import ( context "context" - middleware "github.com/go-kratos/kratos/v2/middleware" http1 "github.com/go-kratos/kratos/v2/transport/http" http "net/http" ) // This is a compile-time assertion to ensure that this generated file // is compatible with the kratos package it is being compiled against. -// context./http./middleware. +// context./http. const _ = http1.SupportPackageIsVersion1 type EchoServiceHTTPServer interface { @@ -28,7 +27,7 @@ func RegisterEchoServiceHTTPServer(s http1.ServiceRegistrar, srv EchoServiceHTTP s.RegisterService(&_HTTP_EchoService_serviceDesc, srv) } -func _HTTP_EchoService_Echo_0(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error, m middleware.Middleware) (interface{}, error) { +func _HTTP_EchoService_Echo_0(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { var in SimpleMessage if err := http1.BindVars(req, &in); err != nil { @@ -39,17 +38,14 @@ func _HTTP_EchoService_Echo_0(srv interface{}, ctx context.Context, req *http.Re return nil, err } - h := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(EchoServiceServer).Echo(ctx, &in) - } - out, err := m(h)(ctx, &in) + out, err := srv.(EchoServiceServer).Echo(ctx, &in) if err != nil { return nil, err } return out, nil } -func _HTTP_EchoService_Echo_1(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error, m middleware.Middleware) (interface{}, error) { +func _HTTP_EchoService_Echo_1(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { var in SimpleMessage if err := http1.BindVars(req, &in); err != nil { @@ -60,17 +56,14 @@ func _HTTP_EchoService_Echo_1(srv interface{}, ctx context.Context, req *http.Re return nil, err } - h := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(EchoServiceServer).Echo(ctx, &in) - } - out, err := m(h)(ctx, &in) + out, err := srv.(EchoServiceServer).Echo(ctx, &in) if err != nil { return nil, err } return out, nil } -func _HTTP_EchoService_Echo_2(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error, m middleware.Middleware) (interface{}, error) { +func _HTTP_EchoService_Echo_2(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { var in SimpleMessage if err := http1.BindVars(req, &in); err != nil { @@ -81,17 +74,14 @@ func _HTTP_EchoService_Echo_2(srv interface{}, ctx context.Context, req *http.Re return nil, err } - h := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(EchoServiceServer).Echo(ctx, &in) - } - out, err := m(h)(ctx, &in) + out, err := srv.(EchoServiceServer).Echo(ctx, &in) if err != nil { return nil, err } return out, nil } -func _HTTP_EchoService_Echo_3(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error, m middleware.Middleware) (interface{}, error) { +func _HTTP_EchoService_Echo_3(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { var in SimpleMessage if err := http1.BindVars(req, &in); err != nil { @@ -102,17 +92,14 @@ func _HTTP_EchoService_Echo_3(srv interface{}, ctx context.Context, req *http.Re return nil, err } - h := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(EchoServiceServer).Echo(ctx, &in) - } - out, err := m(h)(ctx, &in) + out, err := srv.(EchoServiceServer).Echo(ctx, &in) if err != nil { return nil, err } return out, nil } -func _HTTP_EchoService_Echo_4(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error, m middleware.Middleware) (interface{}, error) { +func _HTTP_EchoService_Echo_4(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { var in SimpleMessage if err := http1.BindVars(req, &in); err != nil { @@ -123,61 +110,49 @@ func _HTTP_EchoService_Echo_4(srv interface{}, ctx context.Context, req *http.Re return nil, err } - h := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(EchoServiceServer).Echo(ctx, &in) - } - out, err := m(h)(ctx, &in) + out, err := srv.(EchoServiceServer).Echo(ctx, &in) if err != nil { return nil, err } return out, nil } -func _HTTP_EchoService_EchoBody_0(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error, m middleware.Middleware) (interface{}, error) { +func _HTTP_EchoService_EchoBody_0(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { var in SimpleMessage if err := dec(&in); err != nil { return nil, err } - h := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(EchoServiceServer).EchoBody(ctx, &in) - } - out, err := m(h)(ctx, &in) + out, err := srv.(EchoServiceServer).EchoBody(ctx, &in) if err != nil { return nil, err } return out, nil } -func _HTTP_EchoService_EchoDelete_0(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error, m middleware.Middleware) (interface{}, error) { +func _HTTP_EchoService_EchoDelete_0(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { var in SimpleMessage if err := http1.BindForm(req, &in); err != nil { return nil, err } - h := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(EchoServiceServer).EchoDelete(ctx, &in) - } - out, err := m(h)(ctx, &in) + out, err := srv.(EchoServiceServer).EchoDelete(ctx, &in) if err != nil { return nil, err } return out, nil } -func _HTTP_EchoService_EchoPatch_0(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error, m middleware.Middleware) (interface{}, error) { +func _HTTP_EchoService_EchoPatch_0(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { var in DynamicMessageUpdate if err := dec(in.Body); err != nil { return nil, err } - h := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(EchoServiceServer).EchoPatch(ctx, &in) - } - out, err := m(h)(ctx, &in) + out, err := srv.(EchoServiceServer).EchoPatch(ctx, &in) if err != nil { return nil, err } diff --git a/cmd/protoc-gen-go-http/template.go b/cmd/protoc-gen-go-http/template.go index 3537df4f9..618319563 100644 --- a/cmd/protoc-gen-go-http/template.go +++ b/cmd/protoc-gen-go-http/template.go @@ -16,7 +16,7 @@ func Register{{.ServiceType}}HTTPServer(s http1.ServiceRegistrar, srv {{.Service s.RegisterService(&_HTTP_{{.ServiceType}}_serviceDesc, srv) } {{range .Methods}} -func _HTTP_{{$.ServiceType}}_{{.Name}}_{{.Num}}(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error, m middleware.Middleware) (interface{}, error) { +func _HTTP_{{$.ServiceType}}_{{.Name}}_{{.Num}}(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { var in {{.Request}} {{if ne (len .Vars) 0}} if err := http1.BindVars(req, &in); err != nil { @@ -36,10 +36,7 @@ func _HTTP_{{$.ServiceType}}_{{.Name}}_{{.Num}}(srv interface{}, ctx context.Con return nil, err } {{end}} - h := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.({{$.ServiceType}}Server).{{.Name}}(ctx, &in) - } - out, err := m(h)(ctx, &in) + out, err := srv.({{$.ServiceType}}Server).{{.Name}}(ctx, &in) if err != nil { return nil, err } diff --git a/transport/http/client.go b/transport/http/client.go index 358bd3b0b..35ba74d6d 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -96,7 +96,7 @@ func (t *baseTransport) RoundTrip(req *http.Request) (*http.Response, error) { defer cancel() h := func(ctx context.Context, in interface{}) (interface{}, error) { - return t.base.RoundTrip(req) + return t.base.RoundTrip(in.(*http.Request)) } if t.middleware != nil { h = t.middleware(h) diff --git a/transport/http/server.go b/transport/http/server.go index 80f99935d..a4dd39ce7 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -131,7 +131,15 @@ func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) { defer cancel() ctx = transport.NewContext(ctx, transport.Transport{Kind: "HTTP"}) ctx = NewServerContext(ctx, ServerInfo{Request: req, Response: res}) - s.router.ServeHTTP(res, req.WithContext(ctx)) + + h := func(ctx context.Context, req interface{}) (interface{}, error) { + s.router.ServeHTTP(res, req.(*http.Request)) + return res, nil + } + if s.middleware != nil { + h = s.middleware(h) + } + h(ctx, req.WithContext(ctx)) } // Endpoint return a real address to registry endpoint. diff --git a/transport/http/service.go b/transport/http/service.go index 9a09847ff..3fa3595f6 100644 --- a/transport/http/service.go +++ b/transport/http/service.go @@ -3,14 +3,12 @@ package http import ( "context" "net/http" - - "github.com/go-kratos/kratos/v2/middleware" ) // SupportPackageIsVersion1 These constants should not be referenced from any other code. const SupportPackageIsVersion1 = true -type methodHandler func(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error, m middleware.Middleware) (out interface{}, err error) +type methodHandler func(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (out interface{}, err error) // MethodDesc represents a Proto service's method specification. type MethodDesc struct { @@ -38,7 +36,7 @@ func (s *Server) RegisterService(desc *ServiceDesc, impl interface{}) { s.router.HandleFunc(m.Path, func(res http.ResponseWriter, req *http.Request) { out, err := h(impl, req.Context(), req, func(v interface{}) error { return s.requestDecoder(req, v) - }, s.middleware) + }) if err != nil { s.errorEncoder(res, req, err) return diff --git a/transport/http/service_test.go b/transport/http/service_test.go index 6d2440e20..39e66d19c 100644 --- a/transport/http/service_test.go +++ b/transport/http/service_test.go @@ -10,7 +10,6 @@ import ( "time" "github.com/go-kratos/kratos/v2/internal/host" - "github.com/go-kratos/kratos/v2/middleware" ) type testRequest struct { @@ -26,15 +25,12 @@ func (s *testService) SayHello(ctx context.Context, req *testRequest) (*testRepl } func TestService(t *testing.T) { - h := func(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error, m middleware.Middleware) (interface{}, error) { + h := func(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { var in testRequest if err := dec(&in); err != nil { return nil, err } - h := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(*testService).SayHello(ctx, &in) - } - out, err := m(h)(ctx, &in) + out, err := srv.(*testService).SayHello(ctx, &in) if err != nil { return nil, err }