From cc0221b5cefbf7d36d36b2e63c03a250e9ecacd6 Mon Sep 17 00:00:00 2001 From: Tony Chen Date: Tue, 25 May 2021 00:56:31 +0800 Subject: [PATCH] errors: add errors coder (#946) * add errors coder * rename internal http to httputil * add errors proto --- errors/errors.go | 63 +++----- errors/errors.pb.go | 194 +++++++++++++++++++++++ errors/errors.proto | 17 ++ examples/blog/internal/server/grpc.go | 4 +- examples/blog/internal/server/http.go | 4 +- examples/helloworld/client/main.go | 1 + internal/{http => httputil}/http.go | 11 +- internal/{http => httputil}/http_test.go | 2 +- middleware/recovery/recovery.go | 2 +- middleware/validate/validate.go | 4 +- transport/http/client.go | 81 +++++----- transport/http/handle.go | 62 ++++---- 12 files changed, 326 insertions(+), 119 deletions(-) create mode 100644 errors/errors.pb.go create mode 100644 errors/errors.proto rename internal/{http => httputil}/http.go (90%) rename internal/{http => httputil}/http_test.go (97%) diff --git a/errors/errors.go b/errors/errors.go index a9b3cb0e7..a6e40486c 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -4,64 +4,57 @@ import ( "errors" "fmt" + "github.com/go-kratos/kratos/v2/internal/httputil" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" ) -const ( - // SupportPackageIsVersion1 this constant should not be referenced by any other code. - SupportPackageIsVersion1 = true -) - -// Error is describes the cause of the error with structured details. -// For more details see https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto. -type Error struct { - s *status.Status - - Domain string `json:"domain"` - Reason string `json:"reason"` - Metadata map[string]string `json:"metadata"` -} +//go:generate protoc -I. --go_out=paths=source_relative:. errors.proto func (e *Error) Error() string { return fmt.Sprintf("error: domain = %s reason = %s metadata = %v", e.Domain, e.Reason, e.Metadata) } +// HTTPStatus return an HTTP error code. +func (e *Error) HTTPStatus() int { + return httputil.StatusFromGRPCCode(codes.Code(e.Code)) +} + // GRPCStatus returns the Status represented by se. func (e *Error) GRPCStatus() *status.Status { - s, err := e.s.WithDetails(&errdetails.ErrorInfo{ - Domain: e.Domain, - Reason: e.Reason, - Metadata: e.Metadata, - }) - if err != nil { - return e.s - } + s, _ := status.New(codes.Code(e.Code), e.Message). + WithDetails(&errdetails.ErrorInfo{ + Domain: e.Domain, + Reason: e.Reason, + Metadata: e.Metadata, + }) return s } // Is matches each error in the chain with the target value. func (e *Error) Is(err error) bool { - if target := new(Error); errors.As(err, &target) { - return target.Domain == e.Domain && target.Reason == e.Reason + if se := new(Error); errors.As(err, &se) { + return se.Domain == e.Domain && se.Reason == e.Reason } return false } // WithMetadata with an MD formed by the mapping of key, value. func (e *Error) WithMetadata(md map[string]string) *Error { - err := *e + err := proto.Clone(e).(*Error) err.Metadata = md - return &err + return err } // New returns an error object for the code, message. func New(code codes.Code, domain, reason, message string) *Error { return &Error{ - s: status.New(code, message), - Domain: domain, - Reason: reason, + Code: int32(code), + Message: message, + Domain: domain, + Reason: reason, } } @@ -72,11 +65,7 @@ func Newf(code codes.Code, domain, reason, format string, a ...interface{}) *Err // Errorf returns an error object for the code, message and error info. func Errorf(code codes.Code, domain, reason, format string, a ...interface{}) error { - return &Error{ - s: status.New(code, fmt.Sprintf(format, a...)), - Domain: domain, - Reason: reason, - } + return New(code, domain, reason, fmt.Sprintf(format, a...)) } // Code returns the code for a particular error. @@ -86,7 +75,7 @@ func Code(err error) codes.Code { return codes.OK } if se := FromError(err); err != nil { - return se.s.Code() + return codes.Code(se.Code) } return codes.Unknown } @@ -115,8 +104,8 @@ func FromError(err error) *Error { if err == nil { return nil } - if target := new(Error); errors.As(err, &target) { - return target + if se := new(Error); errors.As(err, &se) { + return se } gs, ok := status.FromError(err) if ok { diff --git a/errors/errors.pb.go b/errors/errors.pb.go new file mode 100644 index 000000000..17dc0c195 --- /dev/null +++ b/errors/errors.pb.go @@ -0,0 +1,194 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.26.0 +// protoc v3.14.0 +// source: errors.proto + +package errors + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Error struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + // error details + Reason string `protobuf:"bytes,3,opt,name=reason,proto3" json:"reason,omitempty"` + Domain string `protobuf:"bytes,4,opt,name=domain,proto3" json:"domain,omitempty"` + Metadata map[string]string `protobuf:"bytes,5,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` +} + +func (x *Error) Reset() { + *x = Error{} + if protoimpl.UnsafeEnabled { + mi := &file_errors_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Error) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Error) ProtoMessage() {} + +func (x *Error) ProtoReflect() protoreflect.Message { + mi := &file_errors_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Error.ProtoReflect.Descriptor instead. +func (*Error) Descriptor() ([]byte, []int) { + return file_errors_proto_rawDescGZIP(), []int{0} +} + +func (x *Error) GetCode() int32 { + if x != nil { + return x.Code + } + return 0 +} + +func (x *Error) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *Error) GetReason() string { + if x != nil { + return x.Reason + } + return "" +} + +func (x *Error) GetDomain() string { + if x != nil { + return x.Domain + } + return "" +} + +func (x *Error) GetMetadata() map[string]string { + if x != nil { + return x.Metadata + } + return nil +} + +var File_errors_proto protoreflect.FileDescriptor + +var file_errors_proto_rawDesc = []byte{ + 0x0a, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0d, + 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, 0x2e, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x22, 0xe2, 0x01, + 0x0a, 0x05, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x12, 0x16, 0x0a, + 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x3e, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, + 0x61, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x22, 0x2e, 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, + 0x2e, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x2e, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x2e, 0x4d, 0x65, + 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, + 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, + 0x38, 0x01, 0x42, 0x59, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, 0x2e, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x50, 0x01, + 0x5a, 0x2c, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, + 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, 0x2f, 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, 0x2f, 0x76, 0x32, + 0x2f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x3b, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x73, 0xa2, 0x02, + 0x0c, 0x4b, 0x72, 0x61, 0x74, 0x6f, 0x73, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_errors_proto_rawDescOnce sync.Once + file_errors_proto_rawDescData = file_errors_proto_rawDesc +) + +func file_errors_proto_rawDescGZIP() []byte { + file_errors_proto_rawDescOnce.Do(func() { + file_errors_proto_rawDescData = protoimpl.X.CompressGZIP(file_errors_proto_rawDescData) + }) + return file_errors_proto_rawDescData +} + +var file_errors_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_errors_proto_goTypes = []interface{}{ + (*Error)(nil), // 0: kratos.errors.Error + nil, // 1: kratos.errors.Error.MetadataEntry +} +var file_errors_proto_depIdxs = []int32{ + 1, // 0: kratos.errors.Error.metadata:type_name -> kratos.errors.Error.MetadataEntry + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_errors_proto_init() } +func file_errors_proto_init() { + if File_errors_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_errors_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Error); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_errors_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_errors_proto_goTypes, + DependencyIndexes: file_errors_proto_depIdxs, + MessageInfos: file_errors_proto_msgTypes, + }.Build() + File_errors_proto = out.File + file_errors_proto_rawDesc = nil + file_errors_proto_goTypes = nil + file_errors_proto_depIdxs = nil +} diff --git a/errors/errors.proto b/errors/errors.proto new file mode 100644 index 000000000..2553932bb --- /dev/null +++ b/errors/errors.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package kratos.errors; + +option go_package = "github.com/go-kratos/kratos/v2/errors;errors"; +option java_multiple_files = true; +option java_package = "com.github.kratos.errors"; +option objc_class_prefix = "KratosErrors"; + +message Error { + int32 code = 1; + string message = 2; + // error details + string reason = 3; + string domain = 4; + map metadata = 5; +}; diff --git a/examples/blog/internal/server/grpc.go b/examples/blog/internal/server/grpc.go index f26c97059..fbee36592 100644 --- a/examples/blog/internal/server/grpc.go +++ b/examples/blog/internal/server/grpc.go @@ -17,10 +17,10 @@ import ( func NewGRPCServer(c *conf.Server, tracer trace.TracerProvider, blog *service.BlogService) *grpc.Server { var opts = []grpc.ServerOption{ grpc.Middleware( + recovery.Recovery(), tracing.Server(tracing.WithTracerProvider(tracer)), logging.Server(log.DefaultLogger), - recovery.Recovery(), - validate.Validator(v1.BlogService_ServiceDesc.ServiceName+".grpc"), + validate.Validator(), ), } if c.Grpc.Network != "" { diff --git a/examples/blog/internal/server/http.go b/examples/blog/internal/server/http.go index 0893d8bb2..59cb4a6d5 100644 --- a/examples/blog/internal/server/http.go +++ b/examples/blog/internal/server/http.go @@ -26,10 +26,10 @@ func NewHTTPServer(c *conf.Server, tracer trace.TracerProvider, blog *service.Bl opts = append(opts, http.Timeout(c.Http.Timeout.AsDuration())) } m := http.Middleware( + recovery.Recovery(), tracing.Server(tracing.WithTracerProvider(tracer)), logging.Server(log.DefaultLogger), - recovery.Recovery(), - validate.Validator(v1.BlogService_ServiceDesc.ServiceName+".http"), + validate.Validator(), ) srv := http.NewServer(opts...) srv.HandlePrefix("/", v1.NewBlogServiceHandler(blog, m)) diff --git a/examples/helloworld/client/main.go b/examples/helloworld/client/main.go index 77186ef5e..77c1950eb 100644 --- a/examples/helloworld/client/main.go +++ b/examples/helloworld/client/main.go @@ -36,6 +36,7 @@ func callHTTP() { } log.Printf("[http] SayHello %s\n", reply.Message) + // returns error reply, err = client.SayHello(context.Background(), &pb.HelloRequest{Name: "error"}) if err != nil { log.Printf("[http] SayHello error: %v\n", err) diff --git a/internal/http/http.go b/internal/httputil/http.go similarity index 90% rename from internal/http/http.go rename to internal/httputil/http.go index 893b11f7a..068941f96 100644 --- a/internal/http/http.go +++ b/internal/httputil/http.go @@ -1,4 +1,4 @@ -package http +package httputil import ( "net/http" @@ -11,15 +11,6 @@ const ( baseContentType = "application" ) -var ( - // HeaderAccept is accept header. - HeaderAccept = http.CanonicalHeaderKey("Accept") - // HeaderContentType is content-type header. - HeaderContentType = http.CanonicalHeaderKey("Content-Type") - // HeaderAcceptLanguage is accept-language header. - HeaderAcceptLanguage = http.CanonicalHeaderKey("Accept-Language") -) - // ContentType returns the content-type with base prefix. func ContentType(subtype string) string { return strings.Join([]string{baseContentType, subtype}, "/") diff --git a/internal/http/http_test.go b/internal/httputil/http_test.go similarity index 97% rename from internal/http/http_test.go rename to internal/httputil/http_test.go index 82d85104e..e8058b2c9 100644 --- a/internal/http/http_test.go +++ b/internal/httputil/http_test.go @@ -1,4 +1,4 @@ -package http +package httputil import "testing" diff --git a/middleware/recovery/recovery.go b/middleware/recovery/recovery.go index 0e8622ee3..f8a771297 100644 --- a/middleware/recovery/recovery.go +++ b/middleware/recovery/recovery.go @@ -40,7 +40,7 @@ func Recovery(opts ...Option) middleware.Middleware { options := options{ logger: log.DefaultLogger, handler: func(ctx context.Context, req, err interface{}) error { - return errors.InternalServer("", "recovery", fmt.Sprintf("panic triggered: %v", err)) + return errors.InternalServer("global", "recovery", fmt.Sprintf("panic triggered: %v", err)) }, } for _, o := range opts { diff --git a/middleware/validate/validate.go b/middleware/validate/validate.go index bc823bfee..a872b486f 100644 --- a/middleware/validate/validate.go +++ b/middleware/validate/validate.go @@ -12,12 +12,12 @@ type validator interface { } // Validator is a validator middleware. -func Validator(domain string) middleware.Middleware { +func Validator() middleware.Middleware { return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (reply interface{}, err error) { if v, ok := req.(validator); ok { if err := v.Validate(); err != nil { - return nil, errors.BadRequest(domain, "validator", err.Error()) + return nil, errors.BadRequest("global", "validator", err.Error()) } } return handler(ctx, req) diff --git a/transport/http/client.go b/transport/http/client.go index 48ddaae73..dc348cdb5 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -10,13 +10,11 @@ import ( "time" "github.com/go-kratos/kratos/v2/encoding" - xhttp "github.com/go-kratos/kratos/v2/internal/http" + "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/internal/httputil" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport/http/binding" - spb "google.golang.org/genproto/googleapis/rpc/status" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" ) @@ -28,19 +26,19 @@ type Client struct { endpoint string userAgent string middleware middleware.Middleware - encoder RequestEncodeFunc - decoder ResponseDecodeFunc + encoder EncodeRequestFunc + decoder DecodeResponseFunc errorDecoder DecodeErrorFunc } // DecodeErrorFunc is decode error func. type DecodeErrorFunc func(ctx context.Context, res *http.Response) error -// RequestEncodeFunc is request encode func. -type RequestEncodeFunc func(ctx context.Context, in interface{}) (contentType string, body []byte, err error) +// EncodeRequestFunc is request encode func. +type EncodeRequestFunc func(ctx context.Context, in interface{}) (contentType string, body []byte, err error) -// ResponseDecodeFunc is response decode func. -type ResponseDecodeFunc func(ctx context.Context, res *http.Response, out interface{}) error +// DecodeResponseFunc is response decode func. +type DecodeResponseFunc func(ctx context.Context, res *http.Response, out interface{}) error // ClientOption is HTTP client option. type ClientOption func(*clientOptions) @@ -87,20 +85,27 @@ func WithEndpoint(endpoint string) ClientOption { } } -// WithEncoder with client request encoder. -func WithEncoder(encoder RequestEncodeFunc) ClientOption { +// WithRequestEncoder with client request encoder. +func WithRequestEncoder(encoder EncodeRequestFunc) ClientOption { return func(o *clientOptions) { o.encoder = encoder } } -// WithDecoder with client response decoder. -func WithDecoder(decoder ResponseDecodeFunc) ClientOption { +// WithResponseDecoder with client response decoder. +func WithResponseDecoder(decoder DecodeResponseFunc) ClientOption { return func(o *clientOptions) { o.decoder = decoder } } +// WithErrorDecoder with client error decoder. +func WithErrorDecoder(errorDecoder DecodeErrorFunc) ClientOption { + return func(o *clientOptions) { + o.errorDecoder = errorDecoder + } +} + // Client is a HTTP transport client. type clientOptions struct { ctx context.Context @@ -110,8 +115,8 @@ type clientOptions struct { schema string endpoint string userAgent string - encoder RequestEncodeFunc - decoder ResponseDecodeFunc + encoder EncodeRequestFunc + decoder DecodeResponseFunc errorDecoder DecodeErrorFunc } @@ -164,16 +169,14 @@ func (client *Client) Invoke(ctx context.Context, pathPattern string, args inter if args != nil && c.bodyPattern != "" { // TODO: only encode the target field of args var ( - content []byte - err error + body []byte + err error ) - switch c.bodyPattern { - } - contentType, content, err = client.encoder(ctx, args) + contentType, body, err = client.encoder(ctx, args) if err != nil { return err } - reqBody = bytes.NewReader(content) + reqBody = bytes.NewReader(body) } req, err := http.NewRequest(c.method, url, reqBody) if err != nil { @@ -232,31 +235,26 @@ func (client *Client) do(ctx context.Context, req *http.Request, c callInfo) (*h return nil, err } if err := client.errorDecoder(ctx, resp); err != nil { - resp.Body.Close() return nil, err } return resp, nil } -func defaultRequestEncoder(ctx context.Context, in interface{}) (contentType string, body []byte, err error) { - content, err := encoding.GetCodec("json").Marshal(in) +func defaultRequestEncoder(ctx context.Context, in interface{}) (string, []byte, error) { + body, err := encoding.GetCodec("json").Marshal(in) if err != nil { return "", nil, err } - return "application/json", content, err + return "application/json", body, err } func defaultResponseDecoder(ctx context.Context, res *http.Response, v interface{}) error { - subtype := xhttp.ContentSubtype(res.Header.Get(xhttp.HeaderContentType)) - codec := encoding.GetCodec(subtype) - if codec == nil { - codec = encoding.GetCodec("json") - } + defer res.Body.Close() data, err := ioutil.ReadAll(res.Body) if err != nil { return err } - return codec.Unmarshal(data, v) + return codecForResponse(res).Unmarshal(data, v) } func defaultErrorDecoder(ctx context.Context, res *http.Response) error { @@ -265,10 +263,21 @@ func defaultErrorDecoder(ctx context.Context, res *http.Response) error { } defer res.Body.Close() if data, err := ioutil.ReadAll(res.Body); err == nil { - st := new(spb.Status) - if err = protojson.Unmarshal(data, st); err == nil { - return status.ErrorProto(st) + e := new(errors.Error) + if err := codecForResponse(res).Unmarshal(data, e); err == nil { + return e } } - return status.Error(xhttp.GRPCCodeFromStatus(res.StatusCode), res.Status) + return errors.Errorf(httputil.GRPCCodeFromStatus(res.StatusCode), "", "", "") +} + +func codecForResponse(r *http.Response) encoding.Codec { + codec := encoding.GetCodec(httputil.ContentSubtype("Content-Type")) + if codec != nil { + return codec + } + if codec == nil { + codec = encoding.GetCodec("json") + } + return codec } diff --git a/transport/http/handle.go b/transport/http/handle.go index c70f35f19..5f64f7030 100644 --- a/transport/http/handle.go +++ b/transport/http/handle.go @@ -8,15 +8,12 @@ import ( "reflect" "github.com/go-kratos/kratos/v2/encoding" - "github.com/go-kratos/kratos/v2/encoding/json" - xhttp "github.com/go-kratos/kratos/v2/internal/http" + "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/internal/httputil" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware/recovery" "github.com/go-kratos/kratos/v2/transport/http/binding" "github.com/gorilla/mux" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/encoding/protojson" ) // SupportPackageIsVersion1 These constants should not be referenced from any other code. @@ -47,9 +44,9 @@ type HandleOptions struct { // Deprecated: use NewHandler instead. func DefaultHandleOptions() HandleOptions { return HandleOptions{ - Decode: decodeRequest, - Encode: encodeResponse, - Error: encodeError, + Decode: defaultRequestDecoder, + Encode: defaultResponseEncoder, + Error: defaultErrorEncoder, Middleware: recovery.Recovery(), } } @@ -154,63 +151,72 @@ func validateHandler(handler interface{}) error { return nil } -// decodeRequest decodes the request body to object. -func decodeRequest(req *http.Request, v interface{}) error { - subtype := xhttp.ContentSubtype(req.Header.Get(xhttp.HeaderContentType)) +// defaultRequestDecoder decodes the request body to object. +func defaultRequestDecoder(req *http.Request, v interface{}) error { + subtype := httputil.ContentSubtype(req.Header.Get("Content-Type")) if codec := encoding.GetCodec(subtype); codec != nil { data, err := ioutil.ReadAll(req.Body) if err != nil { - return status.Error(codes.InvalidArgument, err.Error()) + return errors.BadRequest("global", "codec", err.Error()) } if err := codec.Unmarshal(data, v); err != nil { - return status.Error(codes.InvalidArgument, err.Error()) + return errors.BadRequest("global", "codec", err.Error()) } } else { if err := binding.BindForm(req, v); err != nil { - return status.Error(codes.InvalidArgument, err.Error()) + return errors.BadRequest("global", "codec", err.Error()) } } if err := binding.BindVars(mux.Vars(req), v); err != nil { - return status.Error(codes.InvalidArgument, err.Error()) + return errors.BadRequest("global", "codec", err.Error()) } return nil } -// encodeResponse encodes the object to the HTTP response. -func encodeResponse(w http.ResponseWriter, r *http.Request, v interface{}) error { +// defaultResponseEncoder encodes the object to the HTTP response. +func defaultResponseEncoder(w http.ResponseWriter, r *http.Request, v interface{}) error { codec := codecForRequest(r) data, err := codec.Marshal(v) if err != nil { return err } - w.Header().Set(xhttp.HeaderContentType, xhttp.ContentType(codec.Name())) + w.Header().Set("Content-Type", httputil.ContentType(codec.Name())) + if sc, ok := v.(interface { + HTTPStatus() int + }); ok { + w.WriteHeader(sc.HTTPStatus()) + } _, _ = w.Write(data) return nil } -// encodeError encodes the error to the HTTP response. -func encodeError(w http.ResponseWriter, r *http.Request, err error) { - st, _ := status.FromError(err) - data, err := protojson.Marshal(st.Proto()) +// defaultErrorEncoder encodes the error to the HTTP response. +func defaultErrorEncoder(w http.ResponseWriter, r *http.Request, se error) { + codec := codecForRequest(r) + body, err := codec.Marshal(se) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } - w.Header().Set(xhttp.HeaderContentType, "application/json; charset=utf-8") - w.WriteHeader(xhttp.StatusFromGRPCCode(st.Code())) - w.Write(data) + w.Header().Set("Content-Type", httputil.ContentType(codec.Name())) + if sc, ok := se.(interface { + HTTPStatus() int + }); ok { + w.WriteHeader(sc.HTTPStatus()) + } + w.Write(body) } // codecForRequest get encoding.Codec via http.Request func codecForRequest(r *http.Request) encoding.Codec { var codec encoding.Codec - for _, accept := range r.Header[xhttp.HeaderAccept] { - if codec = encoding.GetCodec(xhttp.ContentSubtype(accept)); codec != nil { + for _, accept := range r.Header["Accept"] { + if codec = encoding.GetCodec(httputil.ContentSubtype(accept)); codec != nil { break } } if codec == nil { - codec = encoding.GetCodec(json.Name) + codec = encoding.GetCodec("json") } return codec }