From e7ddc1ba1e9ba23d0880b76f4908300e885362cf Mon Sep 17 00:00:00 2001 From: Tony Chen Date: Fri, 28 May 2021 19:47:24 +0800 Subject: [PATCH] Change the default func to public (#966) --- transport/http/client.go | 43 ++++++++++++++++-------------------- transport/http/handle.go | 47 +++++++++++++++++++--------------------- transport/http/server.go | 6 ++--- 3 files changed, 44 insertions(+), 52 deletions(-) diff --git a/transport/http/client.go b/transport/http/client.go index 21bf5c852..8198fd3e9 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -20,7 +20,7 @@ import ( "github.com/go-kratos/kratos/v2/transport/http/balancer/random" ) -// Client is http client +// Client is an HTTP client. type Client struct { cc *http.Client r *resolver @@ -36,11 +36,6 @@ type Client struct { discovery registry.Discovery } -const ( - // errNodeNotFound represents service node not found. - errNodeNotFound = "NODE_NOT_FOUND" -) - // DecodeErrorFunc is decode error func. type DecodeErrorFunc func(ctx context.Context, res *http.Response) error @@ -132,7 +127,7 @@ func WithBalancer(b balancer.Balancer) ClientOption { } } -// Client is a HTTP transport client. +// Client is an HTTP transport client. type clientOptions struct { ctx context.Context transport http.RoundTripper @@ -154,11 +149,10 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { ctx: ctx, scheme: "http", timeout: 1 * time.Second, - encoder: defaultRequestEncoder, - decoder: defaultResponseDecoder, - errorDecoder: defaultErrorDecoder, + encoder: DefaultRequestEncoder, + decoder: DefaultResponseDecoder, + errorDecoder: DefaultErrorDecoder, transport: http.DefaultTransport, - discovery: nil, balancer: random.New(), } for _, o := range opts { @@ -259,18 +253,18 @@ func (client *Client) invoke(ctx context.Context, req *http.Request, args interf if client.r != nil { nodes := client.r.fetch(ctx) if len(nodes) == 0 { - return nil, errors.ServiceUnavailable(errNodeNotFound, "fetch error") + return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", "fetch error") } var node *registry.ServiceInstance var err error node, done, err = client.b.Pick(ctx, c.pathPattern, nodes) if err != nil { - return nil, errors.ServiceUnavailable(errNodeNotFound, err.Error()) + return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) } req = req.Clone(ctx) addr, err := parseEndpoint(client.scheme, node.Endpoints) if err != nil { - return nil, errors.ServiceUnavailable(errNodeNotFound, err.Error()) + return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) } req.URL.Host = addr } @@ -317,7 +311,8 @@ func (client *Client) do(ctx context.Context, req *http.Request, c callInfo) (*h return resp, nil } -func defaultRequestEncoder(ctx context.Context, in interface{}) (string, []byte, error) { +// DefaultRequestEncoder is an HTTP request encoder. +func DefaultRequestEncoder(ctx context.Context, in interface{}) (string, []byte, error) { body, err := encoding.GetCodec("json").Marshal(in) if err != nil { return "", nil, err @@ -325,16 +320,18 @@ func defaultRequestEncoder(ctx context.Context, in interface{}) (string, []byte, return "application/json", body, err } -func defaultResponseDecoder(ctx context.Context, res *http.Response, v interface{}) error { +// DefaultResponseDecoder is an HTTP response decoder. +func DefaultResponseDecoder(ctx context.Context, res *http.Response, v interface{}) error { defer res.Body.Close() data, err := ioutil.ReadAll(res.Body) if err != nil { return err } - return codecForResponse(res).Unmarshal(data, v) + return CodecForResponse(res).Unmarshal(data, v) } -func defaultErrorDecoder(ctx context.Context, res *http.Response) error { +// DefaultErrorDecoder is an HTTP error decoder. +func DefaultErrorDecoder(ctx context.Context, res *http.Response) error { if res.StatusCode >= 200 && res.StatusCode <= 299 { return nil } @@ -342,20 +339,18 @@ func defaultErrorDecoder(ctx context.Context, res *http.Response) error { data, err := ioutil.ReadAll(res.Body) if err == nil { e := new(errors.Error) - if err = codecForResponse(res).Unmarshal(data, e); err == nil { + if err = CodecForResponse(res).Unmarshal(data, e); err == nil { return e } } return errors.Errorf(res.StatusCode, errors.UnknownReason, err.Error()) } -func codecForResponse(r *http.Response) encoding.Codec { +// CodecForResponse get encoding.Codec via http.Response +func CodecForResponse(r *http.Response) encoding.Codec { codec := encoding.GetCodec(httputil.ContentSubtype("Content-Type")) if codec != nil { return codec } - if codec == nil { - codec = encoding.GetCodec("json") - } - return codec + return encoding.GetCodec("json") } diff --git a/transport/http/handle.go b/transport/http/handle.go index e5938b0e5..6296131ed 100644 --- a/transport/http/handle.go +++ b/transport/http/handle.go @@ -43,9 +43,9 @@ type HandleOptions struct { // Deprecated: use NewHandler instead. func DefaultHandleOptions() HandleOptions { return HandleOptions{ - Decode: defaultRequestDecoder, - Encode: defaultResponseEncoder, - Error: defaultErrorEncoder, + Decode: DefaultRequestDecoder, + Encode: DefaultResponseEncoder, + Error: DefaultErrorEncoder, Middleware: recovery.Recovery(), } } @@ -86,7 +86,7 @@ type Handler struct { opts HandleOptions } -// NewHandler new a HTTP handler. +// NewHandler new an HTTP handler. func NewHandler(handler interface{}, opts ...HandleOption) http.Handler { if err := validateHandler(handler); err != nil { panic(err) @@ -150,11 +150,11 @@ func validateHandler(handler interface{}) error { return nil } -// defaultRequestDecoder decodes the request body to object. -func defaultRequestDecoder(req *http.Request, v interface{}) error { - subtype := httputil.ContentSubtype(req.Header.Get("Content-Type")) - if codec := encoding.GetCodec(subtype); codec != nil { - data, err := ioutil.ReadAll(req.Body) +// DefaultRequestDecoder decodes the request body to object. +func DefaultRequestDecoder(r *http.Request, v interface{}) error { + codec, ok := CodecForRequest(r, "Content-Type") + if ok { + data, err := ioutil.ReadAll(r.Body) if err != nil { return errors.BadRequest("CODEC", err.Error()) } @@ -162,16 +162,16 @@ func defaultRequestDecoder(req *http.Request, v interface{}) error { return errors.BadRequest("CODEC", err.Error()) } } else { - if err := binding.BindForm(req, v); err != nil { + if err := binding.BindForm(r, v); err != nil { return errors.BadRequest("CODEC", err.Error()) } } return nil } -// defaultResponseEncoder encodes the object to the HTTP response. -func defaultResponseEncoder(w http.ResponseWriter, r *http.Request, v interface{}) error { - codec := CodecForRequest(r) +// DefaultResponseEncoder encodes the object to the HTTP response. +func DefaultResponseEncoder(w http.ResponseWriter, r *http.Request, v interface{}) error { + codec, _ := CodecForRequest(r, "Accept") data, err := codec.Marshal(v) if err != nil { return err @@ -186,9 +186,9 @@ func defaultResponseEncoder(w http.ResponseWriter, r *http.Request, v interface{ return nil } -// defaultErrorEncoder encodes the error to the HTTP response. -func defaultErrorEncoder(w http.ResponseWriter, r *http.Request, se error) { - codec := CodecForRequest(r) +// DefaultErrorEncoder encodes the error to the HTTP response. +func DefaultErrorEncoder(w http.ResponseWriter, r *http.Request, se error) { + codec, _ := CodecForRequest(r, "Accept") body, err := codec.Marshal(se) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -206,15 +206,12 @@ func defaultErrorEncoder(w http.ResponseWriter, r *http.Request, se error) { } // CodecForRequest get encoding.Codec via http.Request -func CodecForRequest(r *http.Request) encoding.Codec { - var codec encoding.Codec - for _, accept := range r.Header["Accept"] { - if codec = encoding.GetCodec(httputil.ContentSubtype(accept)); codec != nil { - break +func CodecForRequest(r *http.Request, name string) (encoding.Codec, bool) { + for _, accept := range r.Header[name] { + codec := encoding.GetCodec(httputil.ContentSubtype(accept)) + if codec != nil { + return codec, true } } - if codec == nil { - codec = encoding.GetCodec("json") - } - return codec + return encoding.GetCodec("json"), false } diff --git a/transport/http/server.go b/transport/http/server.go index 4c97a8033..ae6af783c 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -19,7 +19,7 @@ import ( var _ transport.Server = (*Server)(nil) var _ transport.Endpointer = (*Server)(nil) -// ServerOption is HTTP server option. +// ServerOption is an HTTP server option. type ServerOption func(*Server) // Network with server network. @@ -50,7 +50,7 @@ func Logger(logger log.Logger) ServerOption { } } -// Server is a HTTP server wrapper. +// Server is an HTTP server wrapper. type Server struct { *http.Server lis net.Listener @@ -61,7 +61,7 @@ type Server struct { log *log.Helper } -// NewServer creates a HTTP server by options. +// NewServer creates an HTTP server by options. func NewServer(opts ...ServerOption) *Server { srv := &Server{ network: "tcp",