From 81f96ee74d4628d4ced2bbd58b48b3775ebb2402 Mon Sep 17 00:00:00 2001 From: Tony Chen Date: Sat, 3 Jul 2021 15:22:13 +0800 Subject: [PATCH] fix(http): fix error encoder (#1141) * fix error encoder --- errors/errors.go | 7 +-- examples/errors/api/error_reason.proto | 3 +- examples/errors/api/errors.proto | 18 ------- examples/errors/client/main.go | 16 +++---- examples/errors/server/main.go | 2 +- examples/http/errors/main.go | 65 ++++++++++++++++++++++++++ transport/http/codec.go | 37 +++++---------- transport/http/codec_test.go | 29 +++++------- 8 files changed, 101 insertions(+), 76 deletions(-) delete mode 100644 examples/errors/api/errors.proto create mode 100644 examples/http/errors/main.go diff --git a/errors/errors.go b/errors/errors.go index d83b130df..01a492d7f 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -25,14 +25,9 @@ func (e *Error) Error() string { return fmt.Sprintf("error: code = %d reason = %s message = %s metadata = %v", e.Code, e.Reason, e.Message, e.Metadata) } -// StatusCode return an HTTP error code. -func (e *Error) StatusCode() int { - return int(e.Code) -} - // GRPCStatus returns the Status represented by se. func (e *Error) GRPCStatus() *status.Status { - s, _ := status.New(httputil.GRPCCodeFromStatus(e.StatusCode()), e.Message). + s, _ := status.New(httputil.GRPCCodeFromStatus(int(e.Code)), e.Message). WithDetails(&errdetails.ErrorInfo{ Reason: e.Reason, Metadata: e.Metadata, diff --git a/examples/errors/api/error_reason.proto b/examples/errors/api/error_reason.proto index 2b1a2c166..37f322e19 100644 --- a/examples/errors/api/error_reason.proto +++ b/examples/errors/api/error_reason.proto @@ -1,7 +1,8 @@ syntax = "proto3"; package errors; -import "errors.proto"; + +import "errors/errors.proto"; // 多语言特定包名,用于源代码引用 option go_package = "github.com/go-kratos/kratos/examples/blog/api/v1;v1"; diff --git a/examples/errors/api/errors.proto b/examples/errors/api/errors.proto deleted file mode 100644 index 3603729ac..000000000 --- a/examples/errors/api/errors.proto +++ /dev/null @@ -1,18 +0,0 @@ -syntax = "proto3"; - -package errors; - -option go_package = "github.com/go-kratos/kratos/v2/errors;v1"; -option java_multiple_files = true; -option java_package = "com.github.kratos.errors"; -option objc_class_prefix = "KratosErrors"; - -import "google/protobuf/descriptor.proto"; - -extend google.protobuf.EnumOptions { - int32 default_code = 1108; -} - -extend google.protobuf.EnumValueOptions { - int32 code = 1109; -} \ No newline at end of file diff --git a/examples/errors/client/main.go b/examples/errors/client/main.go index 7a5d9cdc7..b9b9807ec 100644 --- a/examples/errors/client/main.go +++ b/examples/errors/client/main.go @@ -2,13 +2,13 @@ package main import ( "context" - "github.com/go-kratos/kratos/examples/errors/api" - "github.com/go-kratos/kratos/v2/errors" "log" + "github.com/go-kratos/kratos/examples/errors/api" pb "github.com/go-kratos/kratos/examples/helloworld/helloworld" - transgrpc "github.com/go-kratos/kratos/v2/transport/grpc" - transhttp "github.com/go-kratos/kratos/v2/transport/http" + "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/transport/grpc" + "github.com/go-kratos/kratos/v2/transport/http" ) func main() { @@ -17,9 +17,9 @@ func main() { } func callHTTP() { - conn, err := transhttp.NewClient( + conn, err := http.NewClient( context.Background(), - transhttp.WithEndpoint("127.0.0.1:8000"), + http.WithEndpoint("127.0.0.1:8000"), ) if err != nil { panic(err) @@ -39,9 +39,9 @@ func callHTTP() { } func callGRPC() { - conn, err := transgrpc.DialInsecure( + conn, err := grpc.DialInsecure( context.Background(), - transgrpc.WithEndpoint("127.0.0.1:9000"), + grpc.WithEndpoint("127.0.0.1:9000"), ) if err != nil { panic(err) diff --git a/examples/errors/server/main.go b/examples/errors/server/main.go index 7319554e4..45f7ea6a3 100644 --- a/examples/errors/server/main.go +++ b/examples/errors/server/main.go @@ -3,12 +3,12 @@ package main import ( "context" "fmt" - "github.com/go-kratos/kratos/v2/errors" "log" "github.com/go-kratos/kratos/examples/errors/api" "github.com/go-kratos/kratos/examples/helloworld/helloworld" "github.com/go-kratos/kratos/v2" + "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/transport/grpc" "github.com/go-kratos/kratos/v2/transport/http" ) diff --git a/examples/http/errors/main.go b/examples/http/errors/main.go new file mode 100644 index 000000000..335bc7d91 --- /dev/null +++ b/examples/http/errors/main.go @@ -0,0 +1,65 @@ +package main + +import ( + "errors" + "fmt" + "log" + stdhttp "net/http" + + "github.com/go-kratos/kratos/v2" + "github.com/go-kratos/kratos/v2/transport/http" +) + +// HTTPError is an HTTP error. +type HTTPError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +func (e *HTTPError) Error() string { + return fmt.Sprintf("HTTPError code: %d message: %s", e.Code, e.Message) +} + +// FromError try to convert an error to *HTTPError. +func FromError(err error) *HTTPError { + if err == nil { + return nil + } + if se := new(HTTPError); errors.As(err, &se) { + return se + } + return &HTTPError{Code: 500} +} + +func errorEncoder(w stdhttp.ResponseWriter, r *stdhttp.Request, err error) { + se := FromError(err) + codec, _ := http.CodecForRequest(r, "Accept") + body, err := codec.Marshal(se) + if err != nil { + w.WriteHeader(500) + return + } + w.Header().Set("Content-Type", "application/"+codec.Name()) + w.WriteHeader(se.Code) + w.Write(body) +} + +func main() { + httpSrv := http.NewServer( + http.Address(":8000"), + http.ErrorEncoder(errorEncoder), + ) + router := httpSrv.Route("/") + router.GET("home", func(ctx http.Context) error { + return &HTTPError{Code: 400, Message: "request error"} + }) + app := kratos.New( + kratos.Name("mux"), + kratos.Server( + httpSrv, + ), + ) + if err := app.Run(); err != nil { + log.Fatal(err) + } +} diff --git a/transport/http/codec.go b/transport/http/codec.go index 55563c0b9..58db62e17 100644 --- a/transport/http/codec.go +++ b/transport/http/codec.go @@ -7,7 +7,6 @@ import ( "github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/internal/httputil" - "github.com/go-kratos/kratos/v2/transport/http/binding" ) // SupportPackageIsVersion1 These constants should not be referenced from any other code. @@ -24,17 +23,15 @@ type EncodeErrorFunc func(http.ResponseWriter, *http.Request, error) // DefaultRequestDecoder decodes the request body to object. func DefaultRequestDecoder(r *http.Request, v interface{}) error { - if codec, ok := CodecForRequest(r, "Content-Type"); ok { - data, err := ioutil.ReadAll(r.Body) - if err != nil { - return errors.BadRequest("CODEC", err.Error()) - } - if err := codec.Unmarshal(data, v); err != nil { - return errors.BadRequest("CODEC", err.Error()) - } - return nil + codec, ok := CodecForRequest(r, "Content-Type") + if !ok { + return errors.BadRequest("CODEC", r.Header.Get("Content-Type")) + } + data, err := ioutil.ReadAll(r.Body) + if err != nil { + return errors.BadRequest("CODEC", err.Error()) } - if err := binding.BindForm(r, v); err != nil { + if err = codec.Unmarshal(data, v); err != nil { return errors.BadRequest("CODEC", err.Error()) } return nil @@ -48,17 +45,13 @@ func DefaultResponseEncoder(w http.ResponseWriter, r *http.Request, v interface{ return err } w.Header().Set("Content-Type", httputil.ContentType(codec.Name())) - if sc, ok := v.(interface { - StatusCode() int - }); ok { - w.WriteHeader(sc.StatusCode()) - } - _, _ = w.Write(data) + w.Write(data) return nil } // DefaultErrorEncoder encodes the error to the HTTP response. -func DefaultErrorEncoder(w http.ResponseWriter, r *http.Request, se error) { +func DefaultErrorEncoder(w http.ResponseWriter, r *http.Request, err error) { + se := errors.FromError(err) codec, _ := CodecForRequest(r, "Accept") body, err := codec.Marshal(se) if err != nil { @@ -66,13 +59,7 @@ func DefaultErrorEncoder(w http.ResponseWriter, r *http.Request, se error) { return } w.Header().Set("Content-Type", httputil.ContentType(codec.Name())) - if sc, ok := se.(interface { - StatusCode() int - }); ok { - w.WriteHeader(sc.StatusCode()) - } else { - w.WriteHeader(http.StatusInternalServerError) - } + w.WriteHeader(int(se.Code)) w.Write(body) } diff --git a/transport/http/codec_test.go b/transport/http/codec_test.go index ca9a1ab5f..961e1269b 100644 --- a/transport/http/codec_test.go +++ b/transport/http/codec_test.go @@ -2,11 +2,12 @@ package http import ( "bytes" - "github.com/go-kratos/kratos/v2/errors" - "github.com/stretchr/testify/assert" "io/ioutil" nethttp "net/http" "testing" + + "github.com/go-kratos/kratos/v2/errors" + "github.com/stretchr/testify/assert" ) func TestDefaultRequestDecoder(t *testing.T) { @@ -46,40 +47,34 @@ func (w *mockResponseWriter) WriteHeader(statusCode int) { } type dataWithStatusCode struct { - statusCode int - A string `json:"a"` - B int64 `json:"b"` -} - -func (d *dataWithStatusCode) StatusCode() int { - return d.statusCode + A string `json:"a"` + B int64 `json:"b"` } func TestDefaultResponseEncoder(t *testing.T) { - w := &mockResponseWriter{header: make(nethttp.Header)} + w := &mockResponseWriter{StatusCode: 200, header: make(nethttp.Header)} req1 := &nethttp.Request{ Header: make(nethttp.Header), } req1.Header.Set("Content-Type", "application/json") - v1 := &dataWithStatusCode{statusCode: 201, A: "1", B: 2} + v1 := &dataWithStatusCode{A: "1", B: 2} err := DefaultResponseEncoder(w, req1, v1) assert.Nil(t, err) assert.Equal(t, "application/json", w.Header().Get("Content-Type")) - assert.Equal(t, 201, w.StatusCode) + assert.Equal(t, 200, w.StatusCode) assert.NotNil(t, w.Data) } func TestDefaultResponseEncoderWithError(t *testing.T) { w := &mockResponseWriter{header: make(nethttp.Header)} - req1 := &nethttp.Request{ + req := &nethttp.Request{ Header: make(nethttp.Header), } - req1.Header.Set("Content-Type", "application/json") + req.Header.Set("Content-Type", "application/json") - v1 := &errors.Error{Code: 511} - err := DefaultResponseEncoder(w, req1, v1) - assert.Nil(t, err) + se := &errors.Error{Code: 511} + DefaultErrorEncoder(w, req, se) assert.Equal(t, "application/json", w.Header().Get("Content-Type")) assert.Equal(t, 511, w.StatusCode) assert.NotNil(t, w.Data)