diff --git a/cmd/protoc-gen-go-http/go.mod b/cmd/protoc-gen-go-http/go.mod index bf656c694..a7d44ac4c 100644 --- a/cmd/protoc-gen-go-http/go.mod +++ b/cmd/protoc-gen-go-http/go.mod @@ -3,10 +3,9 @@ module github.com/go-kratos/kratos/cmd/protoc-gen-go-http/v2 go 1.15 require ( - github.com/go-kratos/kratos/v2 v2.0.0-alpha4 + github.com/go-kratos/kratos/v2 v2.0.0-20210303135906-ec6ddeeacb79 github.com/golang/protobuf v1.4.3 - github.com/google/uuid v1.2.0 // indirect - golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect + github.com/gorilla/mux v1.8.0 google.golang.org/genproto v0.0.0-20210202153253-cf70463f6119 google.golang.org/grpc v1.36.0 google.golang.org/protobuf v1.25.0 diff --git a/cmd/protoc-gen-go-http/go.sum b/cmd/protoc-gen-go-http/go.sum index f125f33b4..1c5cf17e2 100644 --- a/cmd/protoc-gen-go-http/go.sum +++ b/cmd/protoc-gen-go-http/go.sum @@ -9,12 +9,8 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/go-kratos/kratos/v2 v2.0.0-20210217083752-d86d233d93ce h1:LfOsLN9s8tAxR8xIZGWQvEVWxHfipTnBSE0dvG4h3k8= -github.com/go-kratos/kratos/v2 v2.0.0-20210217083752-d86d233d93ce/go.mod h1:oLvFyDBJkkWN8TPqb+NmpvRrSy9uM/K+XQubVRc11a8= -github.com/go-kratos/kratos/v2 v2.0.0-20210218084408-cf599c68a65f h1:ocHPvNS53zBT6NiGmgOWV2SuAmcNyHbJZk1t2cqEjIU= -github.com/go-kratos/kratos/v2 v2.0.0-20210218084408-cf599c68a65f/go.mod h1:oLvFyDBJkkWN8TPqb+NmpvRrSy9uM/K+XQubVRc11a8= -github.com/go-kratos/kratos/v2 v2.0.0-alpha4 h1:MKkkSZigSMg7Kx8HzrobZ93zlgmi0tAKWM9bMf6YTpU= -github.com/go-kratos/kratos/v2 v2.0.0-alpha4/go.mod h1:oLvFyDBJkkWN8TPqb+NmpvRrSy9uM/K+XQubVRc11a8= +github.com/go-kratos/kratos/v2 v2.0.0-20210303135906-ec6ddeeacb79 h1:30gIrPVUDzCZyHLh9gYbYyDykQAcVIoAWBty1UpwpzU= +github.com/go-kratos/kratos/v2 v2.0.0-20210303135906-ec6ddeeacb79/go.mod h1:oLvFyDBJkkWN8TPqb+NmpvRrSy9uM/K+XQubVRc11a8= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -35,8 +31,6 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs= -github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/imdario/mergo v0.3.6/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= @@ -60,8 +54,6 @@ golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAG golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9 h1:L2auWcuQIvxz9xSEqzESnV/QN/gNRXNApHi3fYwl2w0= @@ -86,7 +78,6 @@ google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZi google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.35.0 h1:TwIQcH3es+MojMVojxxfQ3l3OF2KzlRxML2xZq0kRo8= google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.0 h1:o1bcQ6imQMIOpdrO3SWf2z5RV72WbDwdXuK0MDlc8As= google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= diff --git a/cmd/protoc-gen-go-http/http.go b/cmd/protoc-gen-go-http/http.go index 22a4e69fd..97147f7d4 100644 --- a/cmd/protoc-gen-go-http/http.go +++ b/cmd/protoc-gen-go-http/http.go @@ -13,7 +13,9 @@ import ( const ( contextPackage = protogen.GoImportPath("context") httpPackage = protogen.GoImportPath("net/http") + muxPackage = protogen.GoImportPath("github.com/gorilla/mux") transportPackage = protogen.GoImportPath("github.com/go-kratos/kratos/v2/transport/http") + bindingPackage = protogen.GoImportPath("github.com/go-kratos/kratos/v2/transport/http/binding") ) var methodSets = make(map[string]int) @@ -40,7 +42,10 @@ func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen. } g.P("// This is a compile-time assertion to ensure that this generated file") g.P("// is compatible with the kratos package it is being compiled against.") - g.P("// ", contextPackage.Ident(""), "/", httpPackage.Ident("")) + g.P("var _ = new(", httpPackage.Ident("Request"), ")") + g.P("var _ = new(", contextPackage.Ident("Context"), ")") + g.P("var _ = ", bindingPackage.Ident("MapProto")) + g.P("var _ = ", muxPackage.Ident("NewRouter")) g.P("const _ = ", transportPackage.Ident("SupportPackageIsVersion1")) g.P() @@ -61,22 +66,24 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated Metadata: file.Desc.Path(), } for _, method := range service.Methods { + if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { + continue + } rule, ok := proto.GetExtension(method.Desc.Options(), annotations.E_Http).(*annotations.HttpRule) if rule != nil && ok { for _, bind := range rule.AdditionalBindings { - sd.Methods = append(sd.Methods, buildHTTPRule(method, bind)) + sd.Methods = append(sd.Methods, buildHTTPRule(g, method, bind)) } - sd.Methods = append(sd.Methods, buildHTTPRule(method, rule)) + sd.Methods = append(sd.Methods, buildHTTPRule(g, method, rule)) } else { path := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name()) - sd.Methods = append(sd.Methods, buildMethodDesc(method, "POST", path)) - + sd.Methods = append(sd.Methods, buildMethodDesc(g, method, "POST", path)) } } g.P(sd.execute()) } -func buildHTTPRule(m *protogen.Method, rule *annotations.HttpRule) *methodDesc { +func buildHTTPRule(g *protogen.GeneratedFile, m *protogen.Method, rule *annotations.HttpRule) *methodDesc { var ( path string method string @@ -105,23 +112,29 @@ func buildHTTPRule(m *protogen.Method, rule *annotations.HttpRule) *methodDesc { } body = rule.Body responseBody = rule.ResponseBody - md := buildMethodDesc(m, method, path) + md := buildMethodDesc(g, m, method, path) + if body == "*" { + body = "" + } if body != "" { md.Body = "." + camelCaseVars(body) } + if responseBody == "*" { + responseBody = "" + } if responseBody != "" { md.ResponseBody = "." + camelCaseVars(responseBody) } return md } -func buildMethodDesc(m *protogen.Method, method, path string) *methodDesc { +func buildMethodDesc(g *protogen.GeneratedFile, m *protogen.Method, method, path string) *methodDesc { defer func() { methodSets[m.GoName]++ }() return &methodDesc{ Name: m.GoName, Num: methodSets[m.GoName], - Request: m.Input.GoIdent.GoName, - Reply: m.Output.GoIdent.GoName, + Request: g.QualifiedGoIdent(m.Input.GoIdent), + Reply: g.QualifiedGoIdent(m.Output.GoIdent), Path: path, Method: method, Vars: buildPathVars(m, path), diff --git a/cmd/protoc-gen-go-http/internal/testproto/echo_service_errors.pb.go b/cmd/protoc-gen-go-http/internal/testproto/echo_service_errors.pb.go new file mode 100644 index 000000000..715d8d0de --- /dev/null +++ b/cmd/protoc-gen-go-http/internal/testproto/echo_service_errors.pb.go @@ -0,0 +1,11 @@ +// Code generated by protoc-gen-go-errors. DO NOT EDIT. + +package testproto + +import ( + errors "github.com/go-kratos/kratos/v2/errors" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the kratos package it is being compiled against. +const _ = errors.SupportPackageIsVersion1 diff --git a/cmd/protoc-gen-go-http/internal/testproto/echo_service_http.pb.go b/cmd/protoc-gen-go-http/internal/testproto/echo_service_http.pb.go index ccdae48b2..27e3c610e 100644 --- a/cmd/protoc-gen-go-http/internal/testproto/echo_service_http.pb.go +++ b/cmd/protoc-gen-go-http/internal/testproto/echo_service_http.pb.go @@ -5,15 +5,21 @@ package testproto import ( context "context" http1 "github.com/go-kratos/kratos/v2/transport/http" + binding "github.com/go-kratos/kratos/v2/transport/http/binding" + mux "github.com/gorilla/mux" http "net/http" ) // This is a compile-time assertion to ensure that this generated file // is compatible with the kratos package it is being compiled against. -// context./http. +var _ = new(http.Request) +var _ = new(context.Context) +var _ = binding.MapProto +var _ = mux.NewRouter + const _ = http1.SupportPackageIsVersion1 -type EchoServiceHTTPServer interface { +type EchoServiceHandler interface { Echo(context.Context, *SimpleMessage) (*SimpleMessage, error) EchoBody(context.Context, *SimpleMessage) (*SimpleMessage, error) @@ -23,193 +29,221 @@ type EchoServiceHTTPServer interface { EchoPatch(context.Context, *DynamicMessageUpdate) (*DynamicMessageUpdate, error) } -func RegisterEchoServiceHTTPServer(s http1.ServiceRegistrar, srv EchoServiceHTTPServer) { - s.RegisterService(&_HTTP_EchoService_serviceDesc, srv) -} - -func _HTTP_EchoService_Echo_0(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { - var in SimpleMessage - - if err := http1.BindVars(req, &in); err != nil { - return nil, err - } - - if err := http1.BindForm(req, &in); err != nil { - return nil, err - } - - out, err := srv.(EchoServiceServer).Echo(ctx, &in) - if err != nil { - return nil, err - } - return out, nil -} - -func _HTTP_EchoService_Echo_1(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { - var in SimpleMessage - - if err := http1.BindVars(req, &in); err != nil { - return nil, err - } - - if err := http1.BindForm(req, &in); err != nil { - return nil, err - } - - out, err := srv.(EchoServiceServer).Echo(ctx, &in) - if err != nil { - return nil, err - } - return out, nil -} - -func _HTTP_EchoService_Echo_2(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { - var in SimpleMessage - - if err := http1.BindVars(req, &in); err != nil { - return nil, err - } - - if err := http1.BindForm(req, &in); err != nil { - return nil, err - } - - out, err := srv.(EchoServiceServer).Echo(ctx, &in) - if err != nil { - return nil, err - } - return out, nil -} - -func _HTTP_EchoService_Echo_3(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { - var in SimpleMessage - - if err := http1.BindVars(req, &in); err != nil { - return nil, err - } - - if err := http1.BindForm(req, &in); err != nil { - return nil, err - } - - out, err := srv.(EchoServiceServer).Echo(ctx, &in) - if err != nil { - return nil, err - } - return out, nil -} - -func _HTTP_EchoService_Echo_4(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { - var in SimpleMessage - - if err := http1.BindVars(req, &in); err != nil { - return nil, err - } - - if err := http1.BindForm(req, &in); err != nil { - return nil, err - } - - out, err := srv.(EchoServiceServer).Echo(ctx, &in) - if err != nil { - return nil, err - } - return out, nil -} - -func _HTTP_EchoService_EchoBody_0(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { - var in SimpleMessage - - if err := dec(&in); err != nil { - return nil, err - } - - out, err := srv.(EchoServiceServer).EchoBody(ctx, &in) - if err != nil { - return nil, err - } - return out, nil -} - -func _HTTP_EchoService_EchoDelete_0(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { - var in SimpleMessage - - if err := http1.BindForm(req, &in); err != nil { - return nil, err - } - - out, err := srv.(EchoServiceServer).EchoDelete(ctx, &in) - if err != nil { - return nil, err - } - return out, nil -} - -func _HTTP_EchoService_EchoPatch_0(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { - var in DynamicMessageUpdate - - if err := dec(in.Body); err != nil { - return nil, err - } - - out, err := srv.(EchoServiceServer).EchoPatch(ctx, &in) - if err != nil { - return nil, err - } - return out, nil -} - -var _HTTP_EchoService_serviceDesc = http1.ServiceDesc{ - ServiceName: "testproto.EchoService", - Methods: []http1.MethodDesc{ - - { - Path: "/v1/example/echo/{id}/{num}", - Method: "GET", - Handler: _HTTP_EchoService_Echo_0, - }, - - { - Path: "/v1/example/echo/{id}/{num}/{lang}", - Method: "GET", - Handler: _HTTP_EchoService_Echo_1, - }, - - { - Path: "/v1/example/echo1/{id}/{line_num}/{status.note}", - Method: "GET", - Handler: _HTTP_EchoService_Echo_2, - }, - - { - Path: "/v1/example/echo2/{no.note}", - Method: "GET", - Handler: _HTTP_EchoService_Echo_3, - }, - - { - Path: "/v1/example/echo/{id}", - Method: "POST", - Handler: _HTTP_EchoService_Echo_4, - }, - - { - Path: "/v1/example/echo_body", - Method: "POST", - Handler: _HTTP_EchoService_EchoBody_0, - }, - - { - Path: "/v1/example/echo_delete", - Method: "DELETE", - Handler: _HTTP_EchoService_EchoDelete_0, - }, - - { - Path: "/v1/example/echo_patch", - Method: "PATCH", - Handler: _HTTP_EchoService_EchoPatch_0, - }, - }, - Metadata: "echo_service.proto", +func NewEchoServiceHandler(srv EchoServiceHandler, opts ...http1.HandleOption) http.Handler { + h := http1.DefaultHandleOptions() + for _, o := range opts { + o(&h) + } + r := mux.NewRouter() + + r.HandleFunc("/v1/example/echo/{id}/{num}", func(w http.ResponseWriter, r *http.Request) { + var in SimpleMessage + + if err := binding.MapProto(&in, mux.Vars(r)); err != nil { + h.Error(w, r, err) + return + } + + if err := h.Decode(r, &in); err != nil { + h.Error(w, r, err) + return + } + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.Echo(ctx, req.(*SimpleMessage)) + } + if h.Middleware != nil { + next = h.Middleware(next) + } + out, err := next(r.Context(), &in) + if err != nil { + h.Error(w, r, err) + return + } + if err := h.Encode(w, r, out); err != nil { + h.Error(w, r, err) + } + }).Methods("GET") + + r.HandleFunc("/v1/example/echo/{id}/{num}/{lang}", func(w http.ResponseWriter, r *http.Request) { + var in SimpleMessage + + if err := binding.MapProto(&in, mux.Vars(r)); err != nil { + h.Error(w, r, err) + return + } + + if err := h.Decode(r, &in); err != nil { + h.Error(w, r, err) + return + } + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.Echo(ctx, req.(*SimpleMessage)) + } + if h.Middleware != nil { + next = h.Middleware(next) + } + out, err := next(r.Context(), &in) + if err != nil { + h.Error(w, r, err) + return + } + if err := h.Encode(w, r, out); err != nil { + h.Error(w, r, err) + } + }).Methods("GET") + + r.HandleFunc("/v1/example/echo1/{id}/{line_num}/{status.note}", func(w http.ResponseWriter, r *http.Request) { + var in SimpleMessage + + if err := binding.MapProto(&in, mux.Vars(r)); err != nil { + h.Error(w, r, err) + return + } + + if err := h.Decode(r, &in); err != nil { + h.Error(w, r, err) + return + } + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.Echo(ctx, req.(*SimpleMessage)) + } + if h.Middleware != nil { + next = h.Middleware(next) + } + out, err := next(r.Context(), &in) + if err != nil { + h.Error(w, r, err) + return + } + if err := h.Encode(w, r, out); err != nil { + h.Error(w, r, err) + } + }).Methods("GET") + + r.HandleFunc("/v1/example/echo2/{no.note}", func(w http.ResponseWriter, r *http.Request) { + var in SimpleMessage + + if err := binding.MapProto(&in, mux.Vars(r)); err != nil { + h.Error(w, r, err) + return + } + + if err := h.Decode(r, &in); err != nil { + h.Error(w, r, err) + return + } + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.Echo(ctx, req.(*SimpleMessage)) + } + if h.Middleware != nil { + next = h.Middleware(next) + } + out, err := next(r.Context(), &in) + if err != nil { + h.Error(w, r, err) + return + } + if err := h.Encode(w, r, out); err != nil { + h.Error(w, r, err) + } + }).Methods("GET") + + r.HandleFunc("/v1/example/echo/{id}", func(w http.ResponseWriter, r *http.Request) { + var in SimpleMessage + + if err := binding.MapProto(&in, mux.Vars(r)); err != nil { + h.Error(w, r, err) + return + } + + if err := h.Decode(r, &in); err != nil { + h.Error(w, r, err) + return + } + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.Echo(ctx, req.(*SimpleMessage)) + } + if h.Middleware != nil { + next = h.Middleware(next) + } + out, err := next(r.Context(), &in) + if err != nil { + h.Error(w, r, err) + return + } + if err := h.Encode(w, r, out); err != nil { + h.Error(w, r, err) + } + }).Methods("POST") + + r.HandleFunc("/v1/example/echo_body", func(w http.ResponseWriter, r *http.Request) { + var in SimpleMessage + + if err := h.Decode(r, &in); err != nil { + h.Error(w, r, err) + return + } + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.EchoBody(ctx, req.(*SimpleMessage)) + } + if h.Middleware != nil { + next = h.Middleware(next) + } + out, err := next(r.Context(), &in) + if err != nil { + h.Error(w, r, err) + return + } + if err := h.Encode(w, r, out); err != nil { + h.Error(w, r, err) + } + }).Methods("POST") + + r.HandleFunc("/v1/example/echo_delete", func(w http.ResponseWriter, r *http.Request) { + var in SimpleMessage + + if err := h.Decode(r, &in); err != nil { + h.Error(w, r, err) + return + } + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.EchoDelete(ctx, req.(*SimpleMessage)) + } + if h.Middleware != nil { + next = h.Middleware(next) + } + out, err := next(r.Context(), &in) + if err != nil { + h.Error(w, r, err) + return + } + if err := h.Encode(w, r, out); err != nil { + h.Error(w, r, err) + } + }).Methods("DELETE") + + r.HandleFunc("/v1/example/echo_patch", func(w http.ResponseWriter, r *http.Request) { + var in DynamicMessageUpdate + + if err := h.Decode(r, &in.Body); err != nil { + h.Error(w, r, err) + return + } + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.EchoPatch(ctx, req.(*DynamicMessageUpdate)) + } + if h.Middleware != nil { + next = h.Middleware(next) + } + out, err := next(r.Context(), &in) + if err != nil { + h.Error(w, r, err) + return + } + if err := h.Encode(w, r, out); err != nil { + h.Error(w, r, err) + } + }).Methods("PATCH") + + return r } diff --git a/cmd/protoc-gen-go-http/internal/testproto/stream.pb.go b/cmd/protoc-gen-go-http/internal/testproto/stream.pb.go new file mode 100644 index 000000000..a2fed91d8 --- /dev/null +++ b/cmd/protoc-gen-go-http/internal/testproto/stream.pb.go @@ -0,0 +1,91 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.25.0 +// protoc v3.13.0 +// source: stream.proto + +package testproto + +import ( + proto "github.com/golang/protobuf/proto" + empty "github.com/golang/protobuf/ptypes/empty" + _ "google.golang.org/genproto/googleapis/api/annotations" + httpbody "google.golang.org/genproto/googleapis/api/httpbody" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +var File_stream_proto protoreflect.FileDescriptor + +var file_stream_proto_rawDesc = []byte{ + 0x0a, 0x0c, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x09, + 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x6e, 0x6e, 0x6f, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x19, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, + 0x61, 0x70, 0x69, 0x2f, 0x68, 0x74, 0x74, 0x70, 0x62, 0x6f, 0x64, 0x79, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x32, + 0x69, 0x0a, 0x0d, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x12, 0x58, 0x0a, 0x08, 0x44, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x16, 0x2e, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x61, 0x70, + 0x69, 0x2e, 0x48, 0x74, 0x74, 0x70, 0x42, 0x6f, 0x64, 0x79, 0x22, 0x1c, 0x82, 0xd3, 0xe4, 0x93, + 0x02, 0x16, 0x12, 0x14, 0x2f, 0x76, 0x31, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, + 0x64, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x30, 0x01, 0x42, 0x51, 0x5a, 0x4f, 0x67, 0x69, + 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x6b, 0x72, 0x61, 0x74, + 0x6f, 0x73, 0x2f, 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, 0x2f, 0x63, 0x6d, 0x64, 0x2f, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x2d, 0x67, 0x65, 0x6e, 0x2d, 0x67, 0x6f, 0x2d, 0x68, 0x74, 0x74, 0x70, + 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x3b, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var file_stream_proto_goTypes = []interface{}{ + (*empty.Empty)(nil), // 0: google.protobuf.Empty + (*httpbody.HttpBody)(nil), // 1: google.api.HttpBody +} +var file_stream_proto_depIdxs = []int32{ + 0, // 0: testproto.StreamService.Download:input_type -> google.protobuf.Empty + 1, // 1: testproto.StreamService.Download:output_type -> google.api.HttpBody + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_stream_proto_init() } +func file_stream_proto_init() { + if File_stream_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_stream_proto_rawDesc, + NumEnums: 0, + NumMessages: 0, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_stream_proto_goTypes, + DependencyIndexes: file_stream_proto_depIdxs, + }.Build() + File_stream_proto = out.File + file_stream_proto_rawDesc = nil + file_stream_proto_goTypes = nil + file_stream_proto_depIdxs = nil +} diff --git a/cmd/protoc-gen-go-http/internal/testproto/stream.proto b/cmd/protoc-gen-go-http/internal/testproto/stream.proto new file mode 100644 index 000000000..f343c8f40 --- /dev/null +++ b/cmd/protoc-gen-go-http/internal/testproto/stream.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; +option go_package = "github.com/go-kratos/kratos/cmd/protoc-gen-go-http/internal/testproto;testproto"; +package testproto; + +import "google/api/annotations.proto"; +import "google/api/httpbody.proto"; +import "google/protobuf/empty.proto"; + +service StreamService { + rpc Download(google.protobuf.Empty) returns (stream google.api.HttpBody) { + option (google.api.http) = { + get : "/v1/example/download" + }; + } +} diff --git a/cmd/protoc-gen-go-http/internal/testproto/stream_grpc.pb.go b/cmd/protoc-gen-go-http/internal/testproto/stream_grpc.pb.go new file mode 100644 index 000000000..ed744d6bd --- /dev/null +++ b/cmd/protoc-gen-go-http/internal/testproto/stream_grpc.pb.go @@ -0,0 +1,130 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. + +package testproto + +import ( + context "context" + empty "github.com/golang/protobuf/ptypes/empty" + httpbody "google.golang.org/genproto/googleapis/api/httpbody" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// StreamServiceClient is the client API for StreamService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type StreamServiceClient interface { + Download(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (StreamService_DownloadClient, error) +} + +type streamServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewStreamServiceClient(cc grpc.ClientConnInterface) StreamServiceClient { + return &streamServiceClient{cc} +} + +func (c *streamServiceClient) Download(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (StreamService_DownloadClient, error) { + stream, err := c.cc.NewStream(ctx, &StreamService_ServiceDesc.Streams[0], "/testproto.StreamService/Download", opts...) + if err != nil { + return nil, err + } + x := &streamServiceDownloadClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type StreamService_DownloadClient interface { + Recv() (*httpbody.HttpBody, error) + grpc.ClientStream +} + +type streamServiceDownloadClient struct { + grpc.ClientStream +} + +func (x *streamServiceDownloadClient) Recv() (*httpbody.HttpBody, error) { + m := new(httpbody.HttpBody) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// StreamServiceServer is the server API for StreamService service. +// All implementations must embed UnimplementedStreamServiceServer +// for forward compatibility +type StreamServiceServer interface { + Download(*empty.Empty, StreamService_DownloadServer) error + mustEmbedUnimplementedStreamServiceServer() +} + +// UnimplementedStreamServiceServer must be embedded to have forward compatible implementations. +type UnimplementedStreamServiceServer struct { +} + +func (UnimplementedStreamServiceServer) Download(*empty.Empty, StreamService_DownloadServer) error { + return status.Errorf(codes.Unimplemented, "method Download not implemented") +} +func (UnimplementedStreamServiceServer) mustEmbedUnimplementedStreamServiceServer() {} + +// UnsafeStreamServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to StreamServiceServer will +// result in compilation errors. +type UnsafeStreamServiceServer interface { + mustEmbedUnimplementedStreamServiceServer() +} + +func RegisterStreamServiceServer(s grpc.ServiceRegistrar, srv StreamServiceServer) { + s.RegisterService(&StreamService_ServiceDesc, srv) +} + +func _StreamService_Download_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(empty.Empty) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(StreamServiceServer).Download(m, &streamServiceDownloadServer{stream}) +} + +type StreamService_DownloadServer interface { + Send(*httpbody.HttpBody) error + grpc.ServerStream +} + +type streamServiceDownloadServer struct { + grpc.ServerStream +} + +func (x *streamServiceDownloadServer) Send(m *httpbody.HttpBody) error { + return x.ServerStream.SendMsg(m) +} + +// StreamService_ServiceDesc is the grpc.ServiceDesc for StreamService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var StreamService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "testproto.StreamService", + HandlerType: (*StreamServiceServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Download", + Handler: _StreamService_Download_Handler, + ServerStreams: true, + }, + }, + Metadata: "stream.proto", +} diff --git a/cmd/protoc-gen-go-http/internal/testproto/stream_http.pb.go b/cmd/protoc-gen-go-http/internal/testproto/stream_http.pb.go new file mode 100644 index 000000000..4112a4588 --- /dev/null +++ b/cmd/protoc-gen-go-http/internal/testproto/stream_http.pb.go @@ -0,0 +1,33 @@ +// Code generated by protoc-gen-go-http. DO NOT EDIT. + +package testproto + +import ( + context "context" + http1 "github.com/go-kratos/kratos/v2/transport/http" + binding "github.com/go-kratos/kratos/v2/transport/http/binding" + mux "github.com/gorilla/mux" + http "net/http" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the kratos package it is being compiled against. +var _ = new(http.Request) +var _ = new(context.Context) +var _ = binding.MapProto +var _ = mux.NewRouter + +const _ = http1.SupportPackageIsVersion1 + +type StreamServiceHandler interface { +} + +func NewStreamServiceHandler(srv StreamServiceHandler, opts ...http1.HandleOption) http.Handler { + h := http1.DefaultHandleOptions() + for _, o := range opts { + o(&h) + } + r := mux.NewRouter() + + return r +} diff --git a/cmd/protoc-gen-go-http/template.go b/cmd/protoc-gen-go-http/template.go index c4f5cc44b..6d21337bd 100644 --- a/cmd/protoc-gen-go-http/template.go +++ b/cmd/protoc-gen-go-http/template.go @@ -7,54 +7,48 @@ import ( ) var httpTemplate = ` -type {{.ServiceType}}HTTPServer interface { +type {{.ServiceType}}Handler interface { {{range .MethodSets}} {{.Name}}(context.Context, *{{.Request}}) (*{{.Reply}}, error) {{end}} } -func Register{{.ServiceType}}HTTPServer(s http1.ServiceRegistrar, srv {{.ServiceType}}HTTPServer) { - s.RegisterService(&_HTTP_{{.ServiceType}}_serviceDesc, srv) -} -{{range .Methods}} -func _HTTP_{{$.ServiceType}}_{{.Name}}_{{.Num}}(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { - var in {{.Request}} -{{if eq .Body ""}} - if err := http1.BindForm(req, &in); err != nil { - return nil, err - } -{{else if eq .Body ".*"}} - if err := dec(&in); err != nil { - return nil, err - } -{{else}} - if err := dec(in{{.Body}}); err != nil { - return nil, err - } -{{end}} -{{if ne (len .Vars) 0}} - if err := http1.BindVars(req, &in); err != nil { - return nil, err - } -{{end}} - out, err := srv.({{$.ServiceType}}Server).{{.Name}}(ctx, &in) - if err != nil { - return nil, err + +func New{{.ServiceType}}Handler(srv {{.ServiceType}}Handler, opts ...http1.HandleOption) http.Handler { + h := http1.DefaultHandleOptions() + for _, o := range opts { + o(&h) } - return out{{.ResponseBody}}, nil -} -{{end}} -var _HTTP_{{.ServiceType}}_serviceDesc = http1.ServiceDesc{ - ServiceName: "{{.ServiceName}}", - Methods: []http1.MethodDesc{ -{{range .Methods}} - { - Path: "{{.Path}}", - Method: "{{.Method}}", - Handler: _HTTP_{{$.ServiceType}}_{{.Name}}_{{.Num}}, - }, -{{end}} - }, - Metadata: "{{.Metadata}}", + r := mux.NewRouter() + {{range .Methods}} + r.HandleFunc("{{.Path}}", func(w http.ResponseWriter, r *http.Request) { + var in {{.Request}} + {{if ne (len .Vars) 0}} + if err := binding.MapProto(&in, mux.Vars(r)); err != nil { + h.Error(w, r, err) + return + } + {{end}} + if err := h.Decode(r, &in{{.Body}}); err != nil { + h.Error(w, r, err) + return + } + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.{{.Name}}(ctx, req.(*{{.Request}})) + } + if h.Middleware != nil { + next = h.Middleware(next) + } + out, err := next(r.Context(), &in) + if err != nil { + h.Error(w, r, err) + return + } + if err := h.Encode(w, r, out{{.ResponseBody}}); err != nil { + h.Error(w, r, err) + } + }).Methods("{{.Method}}") + {{end}} + return r } ` diff --git a/encoding/encoding.go b/encoding/encoding.go index 0e9141a7a..f9defe602 100644 --- a/encoding/encoding.go +++ b/encoding/encoding.go @@ -1,6 +1,8 @@ package encoding -import "strings" +import ( + "strings" +) // Codec defines the interface Transport uses to encode and decode messages. Note // that implementations of this interface must be thread safe; a Codec's diff --git a/encoding/proto/proto.go b/encoding/proto/proto.go index 26a291441..96051a969 100644 --- a/encoding/proto/proto.go +++ b/encoding/proto/proto.go @@ -3,7 +3,8 @@ package proto import ( - "google.golang.org/grpc/encoding" + "github.com/go-kratos/kratos/v2/encoding" + "google.golang.org/protobuf/proto" ) diff --git a/errors/errors.go b/errors/errors.go index 69f7b8a64..ad991020c 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -3,6 +3,7 @@ package errors import ( "errors" "fmt" + "net/http" ) const ( @@ -26,6 +27,48 @@ func (e *StatusError) Is(target error) bool { return false } +// HTTPStatus returns the Status represented by se. +func (e *StatusError) HTTPStatus() int { + switch e.Code { + case 0: + return http.StatusOK + case 1: + return http.StatusInternalServerError + case 2: + return http.StatusInternalServerError + case 3: + return http.StatusBadRequest + case 4: + return http.StatusRequestTimeout + case 5: + return http.StatusNotFound + case 6: + return http.StatusConflict + case 7: + return http.StatusForbidden + case 8: + return http.StatusTooManyRequests + case 9: + return http.StatusPreconditionFailed + case 10: + return http.StatusConflict + case 11: + return http.StatusBadRequest + case 12: + return http.StatusNotImplemented + case 13: + return http.StatusInternalServerError + case 14: + return http.StatusServiceUnavailable + case 15: + return http.StatusInternalServerError + case 16: + return http.StatusUnauthorized + default: + return http.StatusInternalServerError + } +} + func (e *StatusError) Error() string { return fmt.Sprintf("error: code = %d reason = %s message = %s details = %+v", e.Code, e.Reason, e.Message, e.Details) } diff --git a/go.sum b/go.sum index f4c905468..067f82f0f 100644 --- a/go.sum +++ b/go.sum @@ -20,7 +20,6 @@ github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:x github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1 h1:ZFgWrT+bLgsYPirOnRfKLYJLvssAegOj/hgyMFdJZe0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= @@ -41,7 +40,6 @@ github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+ github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -63,7 +61,6 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9 h1:L2auWcuQIvxz9xSEqzESnV/QN/gNRXNApHi3fYwl2w0= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -79,7 +76,6 @@ google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9Ywl google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20210114201628-6edceaf6022f h1:izedQ6yVIc5mZsRuXzmSreCOlzI0lCU1HpG8yEdMiKw= google.golang.org/genproto v0.0.0-20210114201628-6edceaf6022f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= diff --git a/transport/http/binding/bind.go b/transport/http/binding/bind.go new file mode 100644 index 000000000..7549c3637 --- /dev/null +++ b/transport/http/binding/bind.go @@ -0,0 +1,18 @@ +package binding + +import ( + "net/http" + + "google.golang.org/protobuf/proto" +) + +// BindForm bind form parameters to target. +func BindForm(req *http.Request, target interface{}) error { + if err := req.ParseForm(); err != nil { + return err + } + if msg, ok := target.(proto.Message); ok { + return mapProto(msg, req.Form) + } + return mapForm(target, req.Form) +} diff --git a/transport/http/binding/form.go b/transport/http/binding/form.go new file mode 100644 index 000000000..5c83bc124 --- /dev/null +++ b/transport/http/binding/form.go @@ -0,0 +1,385 @@ +package binding + +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +var ( + errUnknownType = errors.New("unknown type") + emptyField = reflect.StructField{} +) + +func mapForm(ptr interface{}, form map[string][]string) error { + return mapFormByTag(ptr, form, "json") +} + +func mapFormByTag(ptr interface{}, form map[string][]string, tag string) error { + ptrVal := reflect.ValueOf(ptr) + var pointed interface{} + if ptrVal.Kind() == reflect.Ptr { + ptrVal = ptrVal.Elem() + pointed = ptrVal.Interface() + } + if ptrVal.Kind() == reflect.Map && + ptrVal.Type().Key().Kind() == reflect.String { + if pointed != nil { + ptr = pointed + } + return setFormMap(ptr, form) + } + return mappingByPtr(ptr, formSource(form), tag) +} + +// setter tries to set value on a walking by fields of a struct +type setter interface { + TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (isSetted bool, err error) +} + +type formSource map[string][]string + +var _ setter = formSource(nil) + +// TrySet tries to set a value by request's form source (like map[string][]string) +func (form formSource) TrySet(value reflect.Value, field reflect.StructField, tagValue string, opt setOptions) (isSetted bool, err error) { + return setByForm(value, field, form, tagValue, opt) +} + +func mappingByPtr(ptr interface{}, setter setter, tag string) error { + _, err := mapping(reflect.ValueOf(ptr), emptyField, setter, tag) + return err +} + +func mapping(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) { + if field.Tag.Get(tag) == "-" { // just ignoring this field + return false, nil + } + + var vKind = value.Kind() + + if vKind == reflect.Ptr { + var isNew bool + vPtr := value + if value.IsNil() { + isNew = true + vPtr = reflect.New(value.Type().Elem()) + } + isSetted, err := mapping(vPtr.Elem(), field, setter, tag) + if err != nil { + return false, err + } + if isNew && isSetted { + value.Set(vPtr) + } + return isSetted, nil + } + + if vKind != reflect.Struct || !field.Anonymous { + ok, err := tryToSetValue(value, field, setter, tag) + if err != nil { + return false, err + } + if ok { + return true, nil + } + } + + if vKind == reflect.Struct { + tValue := value.Type() + + var isSetted bool + for i := 0; i < value.NumField(); i++ { + sf := tValue.Field(i) + if sf.PkgPath != "" && !sf.Anonymous { // unexported + continue + } + ok, err := mapping(value.Field(i), tValue.Field(i), setter, tag) + if err != nil { + return false, err + } + isSetted = isSetted || ok + } + return isSetted, nil + } + return false, nil +} + +type setOptions struct { + isDefaultExists bool + defaultValue string +} + +func tryToSetValue(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) { + var tagValue string + var setOpt setOptions + + tagValue = field.Tag.Get(tag) + tagValue, opts := head(tagValue, ",") + + if tagValue == "" { // default value is FieldName + tagValue = field.Name + } + if tagValue == "" { // when field is "emptyField" variable + return false, nil + } + + var opt string + for len(opts) > 0 { + opt, opts = head(opts, ",") + + if k, v := head(opt, "="); k == "default" { + setOpt.isDefaultExists = true + setOpt.defaultValue = v + } + } + + return setter.TrySet(value, field, tagValue, setOpt) +} + +func setByForm(value reflect.Value, field reflect.StructField, form map[string][]string, tagValue string, opt setOptions) (isSetted bool, err error) { + vs, ok := form[tagValue] + if !ok && !opt.isDefaultExists { + return false, nil + } + + switch value.Kind() { + case reflect.Slice: + if !ok { + vs = []string{opt.defaultValue} + } + return true, setSlice(vs, value, field) + case reflect.Array: + if !ok { + vs = []string{opt.defaultValue} + } + if len(vs) != value.Len() { + return false, fmt.Errorf("%q is not valid value for %s", vs, value.Type().String()) + } + return true, setArray(vs, value, field) + default: + var val string + if !ok { + val = opt.defaultValue + } + + if len(vs) > 0 { + val = vs[0] + } + return true, setWithProperType(val, value, field) + } +} + +func setWithProperType(val string, value reflect.Value, field reflect.StructField) error { + switch value.Kind() { + case reflect.Int: + return setIntField(val, 0, value) + case reflect.Int8: + return setIntField(val, 8, value) + case reflect.Int16: + return setIntField(val, 16, value) + case reflect.Int32: + return setIntField(val, 32, value) + case reflect.Int64: + switch value.Interface().(type) { + case time.Duration: + return setTimeDuration(val, value, field) + } + return setIntField(val, 64, value) + case reflect.Uint: + return setUintField(val, 0, value) + case reflect.Uint8: + return setUintField(val, 8, value) + case reflect.Uint16: + return setUintField(val, 16, value) + case reflect.Uint32: + return setUintField(val, 32, value) + case reflect.Uint64: + return setUintField(val, 64, value) + case reflect.Bool: + return setBoolField(val, value) + case reflect.Float32: + return setFloatField(val, 32, value) + case reflect.Float64: + return setFloatField(val, 64, value) + case reflect.String: + value.SetString(val) + case reflect.Struct: + switch value.Interface().(type) { + case time.Time: + return setTimeField(val, field, value) + } + return json.Unmarshal([]byte(val), value.Addr().Interface()) + case reflect.Map: + return json.Unmarshal([]byte(val), value.Addr().Interface()) + default: + return errUnknownType + } + return nil +} + +func setIntField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0" + } + intVal, err := strconv.ParseInt(val, 10, bitSize) + if err == nil { + field.SetInt(intVal) + } + return err +} + +func setUintField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0" + } + uintVal, err := strconv.ParseUint(val, 10, bitSize) + if err == nil { + field.SetUint(uintVal) + } + return err +} + +func setBoolField(val string, field reflect.Value) error { + if val == "" { + val = "false" + } + boolVal, err := strconv.ParseBool(val) + if err == nil { + field.SetBool(boolVal) + } + return err +} + +func setFloatField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0.0" + } + floatVal, err := strconv.ParseFloat(val, bitSize) + if err == nil { + field.SetFloat(floatVal) + } + return err +} + +func setTimeField(val string, structField reflect.StructField, value reflect.Value) error { + timeFormat := structField.Tag.Get("time_format") + if timeFormat == "" { + timeFormat = time.RFC3339 + } + + switch tf := strings.ToLower(timeFormat); tf { + case "unix", "unixnano": + tv, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return err + } + + d := time.Duration(1) + if tf == "unixnano" { + d = time.Second + } + + t := time.Unix(tv/int64(d), tv%int64(d)) + value.Set(reflect.ValueOf(t)) + return nil + + } + + if val == "" { + value.Set(reflect.ValueOf(time.Time{})) + return nil + } + + l := time.Local + if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC { + l = time.UTC + } + + if locTag := structField.Tag.Get("time_location"); locTag != "" { + loc, err := time.LoadLocation(locTag) + if err != nil { + return err + } + l = loc + } + + t, err := time.ParseInLocation(timeFormat, val, l) + if err != nil { + return err + } + + value.Set(reflect.ValueOf(t)) + return nil +} + +func setArray(vals []string, value reflect.Value, field reflect.StructField) error { + for i, s := range vals { + err := setWithProperType(s, value.Index(i), field) + if err != nil { + return err + } + } + return nil +} + +func setSlice(vals []string, value reflect.Value, field reflect.StructField) error { + slice := reflect.MakeSlice(value.Type(), len(vals), len(vals)) + err := setArray(vals, slice, field) + if err != nil { + return err + } + value.Set(slice) + return nil +} + +func setTimeDuration(val string, value reflect.Value, field reflect.StructField) error { + d, err := time.ParseDuration(val) + if err != nil { + return err + } + value.Set(reflect.ValueOf(d)) + return nil +} + +func head(str, sep string) (head string, tail string) { + idx := strings.Index(str, sep) + if idx < 0 { + return str, "" + } + return str[:idx], str[idx+len(sep):] +} + +func setFormMap(ptr interface{}, form map[string][]string) error { + el := reflect.TypeOf(ptr).Elem() + + if el.Kind() == reflect.Slice { + ptrMap, ok := ptr.(map[string][]string) + if !ok { + return errors.New("cannot convert to map slices of strings") + } + for k, v := range form { + ptrMap[k] = v + } + + return nil + } + + ptrMap, ok := ptr.(map[string]string) + if !ok { + return errors.New("cannot convert to map of strings") + } + for k, v := range form { + ptrMap[k] = v[len(v)-1] // pick last + } + + return nil +} diff --git a/transport/http/bind.go b/transport/http/binding/proto.go similarity index 89% rename from transport/http/bind.go rename to transport/http/binding/proto.go index d0a1d7fed..4a9dd9947 100644 --- a/transport/http/bind.go +++ b/transport/http/binding/proto.go @@ -1,26 +1,26 @@ -package http +package binding import ( "encoding/base64" "errors" "fmt" "log" - "net/http" "strconv" "strings" "time" - "github.com/golang/protobuf/ptypes" - "github.com/golang/protobuf/ptypes/wrappers" "google.golang.org/genproto/protobuf/field_mask" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + "google.golang.org/protobuf/types/known/wrapperspb" ) -// BindVars parses url parameters. -func BindVars(req *http.Request, msg proto.Message) error { - for key, value := range Vars(req) { +// MapProto sets a value in a nested Protobuf structure. +func MapProto(msg proto.Message, values map[string]string) error { + for key, value := range values { if err := populateFieldValues(msg.ProtoReflect(), strings.Split(key, "."), []string{value}); err != nil { return err } @@ -28,12 +28,8 @@ func BindVars(req *http.Request, msg proto.Message) error { return nil } -// BindForm parses form parameters. -func BindForm(req *http.Request, msg proto.Message) error { - if err := req.ParseForm(); err != nil { - return err - } - for key, values := range req.Form { +func mapProto(msg proto.Message, values map[string][]string) error { + for key, values := range values { if err := populateFieldValues(msg.ProtoReflect(), strings.Split(key, "."), values); err != nil { return err } @@ -213,10 +209,7 @@ func parseMessage(md protoreflect.MessageDescriptor, value string) (protoreflect if err != nil { return protoreflect.Value{}, err } - msg, err = ptypes.TimestampProto(t) - if err != nil { - return protoreflect.Value{}, err - } + msg = timestamppb.New(t) case "google.protobuf.Duration": if value == "null" { break @@ -225,57 +218,57 @@ func parseMessage(md protoreflect.MessageDescriptor, value string) (protoreflect if err != nil { return protoreflect.Value{}, err } - msg = ptypes.DurationProto(d) + msg = durationpb.New(d) case "google.protobuf.DoubleValue": v, err := strconv.ParseFloat(value, 64) if err != nil { return protoreflect.Value{}, err } - msg = &wrappers.DoubleValue{Value: v} + msg = wrapperspb.Double(v) case "google.protobuf.FloatValue": v, err := strconv.ParseFloat(value, 32) if err != nil { return protoreflect.Value{}, err } - msg = &wrappers.FloatValue{Value: float32(v)} + msg = wrapperspb.Float(float32(v)) case "google.protobuf.Int64Value": v, err := strconv.ParseInt(value, 10, 64) if err != nil { return protoreflect.Value{}, err } - msg = &wrappers.Int64Value{Value: v} + msg = wrapperspb.Int64(v) case "google.protobuf.Int32Value": v, err := strconv.ParseInt(value, 10, 32) if err != nil { return protoreflect.Value{}, err } - msg = &wrappers.Int32Value{Value: int32(v)} + msg = wrapperspb.Int32(int32(v)) case "google.protobuf.UInt64Value": v, err := strconv.ParseUint(value, 10, 64) if err != nil { return protoreflect.Value{}, err } - msg = &wrappers.UInt64Value{Value: v} + msg = wrapperspb.UInt64(v) case "google.protobuf.UInt32Value": v, err := strconv.ParseUint(value, 10, 32) if err != nil { return protoreflect.Value{}, err } - msg = &wrappers.UInt32Value{Value: uint32(v)} + msg = wrapperspb.UInt32(uint32(v)) case "google.protobuf.BoolValue": v, err := strconv.ParseBool(value) if err != nil { return protoreflect.Value{}, err } - msg = &wrappers.BoolValue{Value: v} + msg = wrapperspb.Bool(v) case "google.protobuf.StringValue": - msg = &wrappers.StringValue{Value: value} + msg = wrapperspb.String(value) case "google.protobuf.BytesValue": v, err := base64.StdEncoding.DecodeString(value) if err != nil { return protoreflect.Value{}, err } - msg = &wrappers.BytesValue{Value: v} + msg = wrapperspb.Bytes(v) case "google.protobuf.FieldMask": fm := &field_mask.FieldMask{} fm.Paths = append(fm.Paths, strings.Split(value, ",")...) diff --git a/transport/http/client.go b/transport/http/client.go index 35ba74d6d..35cf11f98 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -2,11 +2,10 @@ package http import ( "context" - "io/ioutil" + "encoding/json" "net/http" "time" - "github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" @@ -115,22 +114,13 @@ func Do(client *http.Client, req *http.Request, target interface{}) error { if err != nil { return err } - data, err := ioutil.ReadAll(res.Body) - if err != nil { - return err - } defer res.Body.Close() - subtype := contentSubtype(res.Header.Get("content-type")) - codec := encoding.GetCodec(subtype) - if codec == nil { - codec = encoding.GetCodec("json") - } if res.StatusCode < 200 || res.StatusCode > 299 { se := &errors.StatusError{} - if err := codec.Unmarshal(data, se); err != nil { + if err := json.NewDecoder(req.Body).Decode(se); err != nil { return err } return se } - return codec.Unmarshal(data, target) + return json.NewDecoder(req.Body).Decode(target) } diff --git a/transport/http/context.go b/transport/http/context.go index d12856e34..f7200e912 100644 --- a/transport/http/context.go +++ b/transport/http/context.go @@ -3,8 +3,6 @@ package http import ( "context" "net/http" - - "github.com/gorilla/mux" ) // ServerInfo is HTTP server infomation. @@ -43,8 +41,3 @@ func FromClientContext(ctx context.Context) (info ClientInfo, ok bool) { info, ok = ctx.Value(clientKey{}).(ClientInfo) return } - -// Vars returns the route variables for the current request, if any. -func Vars(req *http.Request) map[string]string { - return mux.Vars(req) -} diff --git a/transport/http/default.go b/transport/http/default.go deleted file mode 100644 index d1a59e95f..000000000 --- a/transport/http/default.go +++ /dev/null @@ -1,81 +0,0 @@ -package http - -import ( - "fmt" - "io/ioutil" - "net/http" - "strings" - - "github.com/go-kratos/kratos/v2/encoding" -) - -const baseContentType = "application" - -func contentType(subtype string) string { - return strings.Join([]string{baseContentType, subtype}, "/") -} - -func contentSubtype(contentType string) string { - if contentType == baseContentType { - return "" - } - if !strings.HasPrefix(contentType, baseContentType) { - return "" - } - // guaranteed since != baseContentType and has baseContentType prefix - switch contentType[len(baseContentType)] { - case '/', ';': - if i := strings.Index(contentType, ";"); i != -1 { - return contentType[len(baseContentType)+1 : i] - } - return contentType[len(baseContentType)+1:] - default: - return "" - } -} - -func defaultRequestDecoder(req *http.Request, v interface{}) error { - data, err := ioutil.ReadAll(req.Body) - if err != nil { - return err - } - defer req.Body.Close() - subtype := contentSubtype(req.Header.Get("content-type")) - codec := encoding.GetCodec(subtype) - if codec == nil { - return fmt.Errorf("decoding request failed unknown content-type: %s", subtype) - } - return codec.Unmarshal(data, v) -} - -func defaultResponseEncoder(res http.ResponseWriter, req *http.Request, v interface{}) error { - subtype := contentSubtype(req.Header.Get("accept")) - codec := encoding.GetCodec(subtype) - if codec == nil { - codec = encoding.GetCodec("json") - } - data, err := codec.Marshal(v) - if err != nil { - return err - } - res.Header().Set("content-type", contentType(codec.Name())) - res.Write(data) - return nil -} - -func defaultErrorEncoder(res http.ResponseWriter, req *http.Request, err error) { - se, code := StatusError(err) - subtype := contentSubtype(req.Header.Get("accept")) - codec := encoding.GetCodec(subtype) - if codec == nil { - codec = encoding.GetCodec("json") - } - data, err := codec.Marshal(se) - if err != nil { - res.WriteHeader(http.StatusInternalServerError) - return - } - res.Header().Set("content-type", contentType(codec.Name())) - res.WriteHeader(code) - res.Write(data) -} diff --git a/transport/http/default_test.go b/transport/http/default_test.go deleted file mode 100644 index 044f5be2e..000000000 --- a/transport/http/default_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package http - -import "testing" - -func TestSubtype(t *testing.T) { - tests := []struct { - input string - expected string - }{ - {"application/json", "json"}, - {"application/json;", "json"}, - {"application/json; charset=utf-8", "json"}, - {"application/", ""}, - {"application", ""}, - {"foo", ""}, - {"", ""}, - } - for _, test := range tests { - if contentSubtype(test.input) != test.expected { - t.Errorf("expected %s got %s", test.expected, test.input) - } - } -} diff --git a/transport/http/errors.go b/transport/http/errors.go deleted file mode 100644 index a0c016530..000000000 --- a/transport/http/errors.go +++ /dev/null @@ -1,59 +0,0 @@ -package http - -import ( - "net/http" - - "github.com/go-kratos/kratos/v2/errors" -) - -var ( - // References: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto - codesMapping = map[int32]int{ - 0: http.StatusOK, - 1: http.StatusInternalServerError, - 2: http.StatusInternalServerError, - 3: http.StatusBadRequest, - 4: http.StatusRequestTimeout, - 5: http.StatusNotFound, - 6: http.StatusConflict, - 7: http.StatusForbidden, - 8: http.StatusTooManyRequests, - 9: http.StatusPreconditionFailed, - 10: http.StatusConflict, - 11: http.StatusBadRequest, - 12: http.StatusNotImplemented, - 13: http.StatusInternalServerError, - 14: http.StatusServiceUnavailable, - 15: http.StatusInternalServerError, - 16: http.StatusUnauthorized, - } - statusMapping = map[int]int32{ - http.StatusOK: 0, - http.StatusBadRequest: 3, - http.StatusRequestTimeout: 4, - http.StatusNotFound: 5, - http.StatusConflict: 6, - http.StatusForbidden: 7, - http.StatusUnauthorized: 16, - http.StatusPreconditionFailed: 9, - http.StatusNotImplemented: 12, - http.StatusInternalServerError: 13, - http.StatusServiceUnavailable: 14, - } -) - -// StatusError converts error to status error. -func StatusError(err error) (*errors.StatusError, int) { - se, ok := errors.FromError(err) - if !ok { - se = &errors.StatusError{ - Code: 2, - Reason: "Unknown", - Message: "Unknown: " + err.Error(), - } - } - if status, ok := codesMapping[se.Code]; ok { - return se, status - } - return se, http.StatusInternalServerError -} diff --git a/transport/http/handle.go b/transport/http/handle.go new file mode 100644 index 000000000..a517acfb4 --- /dev/null +++ b/transport/http/handle.go @@ -0,0 +1,147 @@ +package http + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "strings" + + "github.com/go-kratos/kratos/v2/encoding" + "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport/http/binding" +) + +const ( + // SupportPackageIsVersion1 These constants should not be referenced from any other code. + SupportPackageIsVersion1 = true + + baseContentType = "application" +) + +var ( + acceptHeader = http.CanonicalHeaderKey("Accept") + contentTypeHeader = http.CanonicalHeaderKey("Content-Type") +) + +// DecodeRequestFunc is decode request func. +type DecodeRequestFunc func(*http.Request, interface{}) error + +// EncodeResponseFunc is encode response func. +type EncodeResponseFunc func(http.ResponseWriter, *http.Request, interface{}) error + +// EncodeErrorFunc is encode error func. +type EncodeErrorFunc func(http.ResponseWriter, *http.Request, error) + +// HandleOption is handle option. +type HandleOption func(*HandleOptions) + +// HandleOptions is handle options. +type HandleOptions struct { + Decode DecodeRequestFunc + Encode EncodeResponseFunc + Error EncodeErrorFunc + Middleware middleware.Middleware +} + +// DefaultHandleOptions returns a default handle options. +func DefaultHandleOptions() HandleOptions { + return HandleOptions{ + Decode: decodeRequest, + Encode: encodeResponse, + Error: encodeError, + } +} + +// RequestDecoder with request decoder. +func RequestDecoder(dec DecodeRequestFunc) HandleOption { + return func(o *HandleOptions) { + o.Decode = dec + } +} + +// ResponseEncoder with response encoder. +func ResponseEncoder(en EncodeResponseFunc) HandleOption { + return func(o *HandleOptions) { + o.Encode = en + } +} + +// ErrorEncoder with error encoder. +func ErrorEncoder(en EncodeErrorFunc) HandleOption { + return func(o *HandleOptions) { + o.Error = en + } +} + +// Middleware with middleware option. +func Middleware(m middleware.Middleware) HandleOption { + return func(o *HandleOptions) { + o.Middleware = m + } +} + +// decodeRequest decodes the request body to object. +func decodeRequest(req *http.Request, v interface{}) error { + subtype := contentSubtype(req.Header.Get(contentTypeHeader)) + if codec := encoding.GetCodec(subtype); codec != nil { + data, err := ioutil.ReadAll(req.Body) + if err != nil { + return err + } + return codec.Unmarshal(data, v) + } + return binding.BindForm(req, v) +} + +// encodeResponse encodes the object to the HTTP response. +func encodeResponse(w http.ResponseWriter, r *http.Request, v interface{}) error { + for _, accept := range r.Header[acceptHeader] { + if codec := encoding.GetCodec(contentSubtype(accept)); codec != nil { + data, err := codec.Marshal(v) + if err != nil { + return err + } + w.Header().Set(contentTypeHeader, contentType(codec.Name())) + w.Write(data) + return nil + } + } + return json.NewEncoder(w).Encode(v) +} + +// encodeError encodes the erorr to the HTTP response. +func encodeError(w http.ResponseWriter, r *http.Request, err error) { + se, ok := errors.FromError(err) + if !ok { + se = &errors.StatusError{ + Code: 2, + Reason: "", + Message: err.Error(), + } + } + w.WriteHeader(se.HTTPStatus()) + encodeResponse(w, r, se) +} + +func contentType(subtype string) string { + return strings.Join([]string{baseContentType, subtype}, "/") +} + +func contentSubtype(contentType string) string { + if contentType == baseContentType { + return "" + } + if !strings.HasPrefix(contentType, baseContentType) { + return "" + } + switch contentType[len(baseContentType)] { + case '/', ';': + if i := strings.Index(contentType, ";"); i != -1 { + return contentType[len(baseContentType)+1 : i] + } + return contentType[len(baseContentType)+1:] + default: + return "" + } +} diff --git a/transport/http/handle_test.go b/transport/http/handle_test.go new file mode 100644 index 000000000..e5049ffee --- /dev/null +++ b/transport/http/handle_test.go @@ -0,0 +1,57 @@ +package http + +import ( + "context" + "net/http" + "testing" + + "github.com/gorilla/mux" +) + +type HelloRequest struct { + Name string `json:"name"` +} +type HelloReply struct { + Message string `json:"message"` +} +type GreeterService struct { +} + +func (s *GreeterService) SayHello(ctx context.Context, req *HelloRequest) (*HelloReply, error) { + return &HelloReply{Message: "hello " + req.Name}, nil +} + +func newGreeterHandler(srv *GreeterService, opts ...HandleOption) http.Handler { + h := DefaultHandleOptions() + for _, o := range opts { + o(&h) + } + r := mux.NewRouter() + r.HandleFunc("/helloworld", func(w http.ResponseWriter, r *http.Request) { + var in HelloRequest + if err := h.Decode(r, &in); err != nil { + h.Error(w, r, err) + return + } + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.SayHello(ctx, &in) + } + if h.Middleware != nil { + next = h.Middleware(next) + } + out, err := next(r.Context(), &in) + if err != nil { + h.Error(w, r, err) + return + } + if err := h.Encode(w, r, out); err != nil { + h.Error(w, r, err) + } + }).Methods("POST") + return r +} + +func TestHandler(t *testing.T) { + s := &GreeterService{} + _ = newGreeterHandler(s) +} diff --git a/transport/http/server.go b/transport/http/server.go index 4cfd470da..5d829a0bb 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -10,8 +10,6 @@ import ( "github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/log" - "github.com/go-kratos/kratos/v2/middleware" - "github.com/go-kratos/kratos/v2/middleware/recovery" "github.com/go-kratos/kratos/v2/transport" "github.com/gorilla/mux" @@ -21,15 +19,6 @@ const loggerName = "transport/http" var _ transport.Server = (*Server)(nil) -// DecodeRequestFunc deocder request func. -type DecodeRequestFunc func(req *http.Request, v interface{}) error - -// EncodeResponseFunc is encode response func. -type EncodeResponseFunc func(res http.ResponseWriter, req *http.Request, v interface{}) error - -// EncodeErrorFunc is encode error func. -type EncodeErrorFunc func(res http.ResponseWriter, req *http.Request, err error) - // ServerOption is HTTP server option. type ServerOption func(*Server) @@ -61,46 +50,24 @@ func Logger(logger log.Logger) ServerOption { } } -// Middleware with server middleware option. -func Middleware(m middleware.Middleware) ServerOption { - return func(s *Server) { - s.middleware = m - } -} - -// ErrorEncoder with error handler option. -func ErrorEncoder(fn EncodeErrorFunc) ServerOption { - return func(s *Server) { - s.errorEncoder = fn - } -} - // Server is a HTTP server wrapper. type Server struct { *http.Server - lis net.Listener - network string - address string - timeout time.Duration - middleware middleware.Middleware - requestDecoder DecodeRequestFunc - responseEncoder EncodeResponseFunc - errorEncoder EncodeErrorFunc - router *mux.Router - log *log.Helper + lis net.Listener + network string + address string + timeout time.Duration + router *mux.Router + log *log.Helper } // NewServer creates a HTTP server by options. func NewServer(opts ...ServerOption) *Server { srv := &Server{ - network: "tcp", - address: ":0", - timeout: time.Second, - requestDecoder: defaultRequestDecoder, - responseEncoder: defaultResponseEncoder, - errorEncoder: defaultErrorEncoder, - middleware: recovery.Recovery(), - log: log.NewHelper(loggerName, log.DefaultLogger), + network: "tcp", + address: ":0", + timeout: time.Second, + log: log.NewHelper(loggerName, log.DefaultLogger), } for _, o := range opts { o(srv) @@ -120,33 +87,23 @@ func (s *Server) Handle(path string, h http.Handler) { s.router.Handle(path, h) } +// HanldePrefix registers a new route with a matcher for the URL path prefix. +func (s *Server) HanldePrefix(prefix string, h http.Handler) { + s.router.PathPrefix(prefix).Handler(h) +} + // HandleFunc registers a new route with a matcher for the URL path. func (s *Server) HandleFunc(path string, h http.HandlerFunc) { s.router.HandleFunc(path, h) } -// PrefixHanlde registers a new route with a matcher for the URL path prefix. -func (s *Server) PrefixHanlde(prefix string, h http.Handler) { - s.router.PathPrefix(prefix).Handler(h) -} - // ServeHTTP should write reply headers and data to the ResponseWriter and then return. func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithTimeout(req.Context(), s.timeout) defer cancel() ctx = transport.NewContext(ctx, transport.Transport{Kind: "HTTP"}) ctx = NewServerContext(ctx, ServerInfo{Request: req, Response: res}) - - h := func(ctx context.Context, req interface{}) (interface{}, error) { - s.router.ServeHTTP(res, req.(*http.Request)) - return res, nil - } - if s.middleware != nil { - h = s.middleware(h) - } - if _, err := h(ctx, req.WithContext(ctx)); err != nil { - s.errorEncoder(res, req, err) - } + s.router.ServeHTTP(res, req) } // Endpoint return a real address to registry endpoint. diff --git a/transport/http/service.go b/transport/http/service.go deleted file mode 100644 index 3fa3595f6..000000000 --- a/transport/http/service.go +++ /dev/null @@ -1,49 +0,0 @@ -package http - -import ( - "context" - "net/http" -) - -// SupportPackageIsVersion1 These constants should not be referenced from any other code. -const SupportPackageIsVersion1 = true - -type methodHandler func(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (out interface{}, err error) - -// MethodDesc represents a Proto service's method specification. -type MethodDesc struct { - Path string - Method string - Handler methodHandler -} - -// ServiceDesc represents a Proto service's specification. -type ServiceDesc struct { - ServiceName string - Methods []MethodDesc - Metadata interface{} -} - -// ServiceRegistrar wraps a single method that supports service registration. -type ServiceRegistrar interface { - RegisterService(desc *ServiceDesc, impl interface{}) -} - -// RegisterService . -func (s *Server) RegisterService(desc *ServiceDesc, impl interface{}) { - for _, m := range desc.Methods { - h := m.Handler - s.router.HandleFunc(m.Path, func(res http.ResponseWriter, req *http.Request) { - out, err := h(impl, req.Context(), req, func(v interface{}) error { - return s.requestDecoder(req, v) - }) - if err != nil { - s.errorEncoder(res, req, err) - return - } - if err := s.responseEncoder(res, req, out); err != nil { - s.errorEncoder(res, req, err) - } - }).Methods(m.Method) - } -} diff --git a/transport/http/service_test.go b/transport/http/service_test.go deleted file mode 100644 index 39e66d19c..000000000 --- a/transport/http/service_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package http - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "testing" - "time" - - "github.com/go-kratos/kratos/v2/internal/host" -) - -type testRequest struct { - Name string `json:"name"` -} -type testReply struct { - Result string `json:"result"` -} -type testService struct{} - -func (s *testService) SayHello(ctx context.Context, req *testRequest) (*testReply, error) { - return &testReply{Result: req.Name}, nil -} - -func TestService(t *testing.T) { - h := func(srv interface{}, ctx context.Context, req *http.Request, dec func(interface{}) error) (interface{}, error) { - var in testRequest - if err := dec(&in); err != nil { - return nil, err - } - out, err := srv.(*testService).SayHello(ctx, &in) - if err != nil { - return nil, err - } - return out, nil - } - sd := &ServiceDesc{ - ServiceName: "helloworld.Greeter", - Methods: []MethodDesc{ - { - Path: "/helloworld", - Method: "POST", - Handler: h, - }, - }, - } - - svc := &testService{} - srv := NewServer() - srv.RegisterService(sd, svc) - - time.AfterFunc(time.Second, func() { - defer srv.Stop() - testServiceClient(t, srv) - }) - - if err := srv.Start(); err != nil { - t.Fatal(err) - } -} - -func testServiceClient(t *testing.T, srv *Server) { - client, err := NewClient(context.Background()) - if err != nil { - t.Fatal(err) - } - port, ok := host.Port(srv.lis) - if !ok { - t.Fatalf("extract port error: %v", srv.lis) - } - var ( - in = testRequest{Name: "hello"} - out = testReply{} - url = fmt.Sprintf("http://127.0.0.1:%d/helloworld", port) - ) - data, err := json.Marshal(in) - if err != nil { - t.Fatal(err) - } - req, err := http.NewRequest("POST", url, bytes.NewReader(data)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("content-type", "application/json") - if err := Do(client, req, &out); err != nil { - t.Fatal(err) - } - if out.Result != in.Name { - t.Fatalf("expected %s got %s", in.Name, out.Result) - } -}