transport/http: uses gRPC status to the HTTP error. (#870)

* uses gRPC status to the HTTP error.
pull/873/head
Tony Chen 4 years ago committed by GitHub
parent b03c810dce
commit 7c3212c306
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      examples/helloworld/client/main.go
  2. 1
      examples/helloworld/server/main.go
  3. 121
      internal/http/http.go
  4. 55
      middleware/status/status.go
  5. 4
      middleware/status/status_test.go
  6. 82
      transport/http/client.go
  7. 63
      transport/http/handle.go

@ -25,6 +25,7 @@ func callHTTP() {
transhttp.WithMiddleware(
middleware.Chain(
recovery.Recovery(),
status.Client(),
),
),
)

@ -63,6 +63,7 @@ func main() {
httpSrv.HandlePrefix("/", pb.NewGreeterHandler(s,
http.Middleware(
middleware.Chain(
status.Server(),
logging.Server(logger),
recovery.Recovery(),
),

@ -0,0 +1,121 @@
package http
import (
"net/http"
"strings"
"google.golang.org/grpc/codes"
)
const (
baseContentType = "application"
)
var (
// HeaderAccept is accept header.
HeaderAccept = http.CanonicalHeaderKey("Accept")
// HeaderContentType is content-type header.
HeaderContentType = http.CanonicalHeaderKey("Content-Type")
// HeaderAcceptLanguage is accept-language header.
HeaderAcceptLanguage = http.CanonicalHeaderKey("Accept-Language")
)
// ContentType returns the content-type with base prefix.
func ContentType(subtype string) string {
return strings.Join([]string{baseContentType, subtype}, "/")
}
// ContentSubtype returns the content-subtype for the given content-type. The
// given content-type must be a valid content-type that starts with
// but no content-subtype will be returned.
//
// contentType is assumed to be lowercase already.
func ContentSubtype(contentType string) string {
if contentType == baseContentType {
return ""
}
if !strings.HasPrefix(contentType, baseContentType) {
return ""
}
switch contentType[len(baseContentType)] {
case '/', ';':
if i := strings.Index(contentType, ";"); i != -1 {
return contentType[len(baseContentType)+1 : i]
}
return contentType[len(baseContentType)+1:]
default:
return ""
}
}
// GRPCCodeFromStatus converts a HTTP error code into the corresponding gRPC response status.
// See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto
func GRPCCodeFromStatus(code int) codes.Code {
switch code {
case http.StatusOK:
return codes.OK
case http.StatusBadRequest:
return codes.InvalidArgument
case http.StatusUnauthorized:
return codes.Unauthenticated
case http.StatusForbidden:
return codes.PermissionDenied
case http.StatusNotFound:
return codes.NotFound
case http.StatusConflict:
return codes.Aborted
case http.StatusTooManyRequests:
return codes.ResourceExhausted
case http.StatusInternalServerError:
return codes.Internal
case http.StatusNotImplemented:
return codes.Unimplemented
case http.StatusServiceUnavailable:
return codes.Unavailable
case http.StatusGatewayTimeout:
return codes.DeadlineExceeded
}
return codes.Unknown
}
// StatusFromGRPCCode converts a gRPC error code into the corresponding HTTP response status.
// See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto
func StatusFromGRPCCode(code codes.Code) int {
switch code {
case codes.OK:
return http.StatusOK
case codes.Canceled:
return http.StatusRequestTimeout
case codes.Unknown:
return http.StatusInternalServerError
case codes.InvalidArgument:
return http.StatusBadRequest
case codes.DeadlineExceeded:
return http.StatusGatewayTimeout
case codes.NotFound:
return http.StatusNotFound
case codes.AlreadyExists:
return http.StatusConflict
case codes.PermissionDenied:
return http.StatusForbidden
case codes.Unauthenticated:
return http.StatusUnauthorized
case codes.ResourceExhausted:
return http.StatusTooManyRequests
case codes.FailedPrecondition:
return http.StatusBadRequest
case codes.Aborted:
return http.StatusConflict
case codes.OutOfRange:
return http.StatusBadRequest
case codes.Unimplemented:
return http.StatusNotImplemented
case codes.Internal:
return http.StatusInternalServerError
case codes.Unavailable:
return http.StatusServiceUnavailable
case codes.DataLoss:
return http.StatusInternalServerError
}
return http.StatusInternalServerError
}

@ -2,15 +2,14 @@ package status
import (
"context"
"net/http"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/internal/http"
"github.com/go-kratos/kratos/v2/middleware"
//lint:ignore SA1019 grpc
"github.com/golang/protobuf/proto"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
@ -34,7 +33,7 @@ func WithHandler(h HandlerFunc) Option {
// Server is an error middleware.
func Server(opts ...Option) middleware.Middleware {
options := options{
handler: encodeErr,
handler: encodeError,
}
for _, o := range opts {
o(&options)
@ -53,7 +52,7 @@ func Server(opts ...Option) middleware.Middleware {
// Client is an error middleware.
func Client(opts ...Option) middleware.Middleware {
options := options{
handler: decodeErr,
handler: decodeError,
}
for _, o := range opts {
o(&options)
@ -69,7 +68,7 @@ func Client(opts ...Option) middleware.Middleware {
}
}
func encodeErr(ctx context.Context, err error) error {
func encodeError(ctx context.Context, err error) error {
var details []proto.Message
if target := new(errors.ErrorInfo); errors.As(err, &target) {
details = append(details, &errdetails.ErrorInfo{
@ -79,7 +78,7 @@ func encodeErr(ctx context.Context, err error) error {
})
}
es := errors.FromError(err)
gs := status.New(httpToGRPCCode(es.Code), es.Message)
gs := status.New(http.GRPCCodeFromStatus(es.Code), es.Message)
gs, err = gs.WithDetails(details...)
if err != nil {
return err
@ -87,9 +86,9 @@ func encodeErr(ctx context.Context, err error) error {
return gs.Err()
}
func decodeErr(ctx context.Context, err error) error {
func decodeError(ctx context.Context, err error) error {
gs := status.Convert(err)
code := grpcToHTTPCode(gs.Code())
code := http.StatusFromGRPCCode(gs.Code())
message := gs.Message()
for _, detail := range gs.Details() {
switch d := detail.(type) {
@ -104,43 +103,3 @@ func decodeErr(ctx context.Context, err error) error {
}
return errors.New(code, message)
}
func httpToGRPCCode(code int) codes.Code {
switch code {
case http.StatusBadRequest:
return codes.InvalidArgument
case http.StatusUnauthorized:
return codes.Unauthenticated
case http.StatusForbidden:
return codes.PermissionDenied
case http.StatusNotFound:
return codes.NotFound
case http.StatusConflict:
return codes.Aborted
case http.StatusInternalServerError:
return codes.Internal
case http.StatusServiceUnavailable:
return codes.Unavailable
}
return codes.Unknown
}
func grpcToHTTPCode(code codes.Code) int {
switch code {
case codes.InvalidArgument:
return http.StatusBadRequest
case codes.Unauthenticated:
return http.StatusUnauthorized
case codes.PermissionDenied:
return http.StatusForbidden
case codes.NotFound:
return http.StatusNotFound
case codes.Aborted:
return http.StatusConflict
case codes.Internal:
return http.StatusInternalServerError
case codes.Unavailable:
return http.StatusServiceUnavailable
}
return http.StatusInternalServerError
}

@ -9,8 +9,8 @@ import (
func TestErrEncoder(t *testing.T) {
err := errors.BadRequest("test", "invalid_argument", "format")
en := encodeErr(context.Background(), err)
de := decodeErr(context.Background(), en)
en := encodeError(context.Background(), err)
de := decodeError(context.Background(), en)
if !errors.IsBadRequest(de) {
t.Errorf("expected %v got %v", err, de)
}

@ -8,10 +8,17 @@ import (
"github.com/go-kratos/kratos/v2/encoding"
"github.com/go-kratos/kratos/v2/errors"
xhttp "github.com/go-kratos/kratos/v2/internal/http"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
)
// DecodeErrorFunc is decode error func.
type DecodeErrorFunc func(ctx context.Context, w *http.Response) error
// ClientOption is HTTP client option.
type ClientOption func(*clientOptions)
@ -45,11 +52,12 @@ func WithMiddleware(m middleware.Middleware) ClientOption {
// Client is a HTTP transport client.
type clientOptions struct {
ctx context.Context
timeout time.Duration
userAgent string
transport http.RoundTripper
middleware middleware.Middleware
ctx context.Context
timeout time.Duration
userAgent string
transport http.RoundTripper
errorDecoder DecodeErrorFunc
middleware middleware.Middleware
}
// NewClient returns an HTTP client.
@ -64,26 +72,29 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*http.Client, error)
// NewTransport creates an http.RoundTripper.
func NewTransport(ctx context.Context, opts ...ClientOption) (http.RoundTripper, error) {
options := &clientOptions{
ctx: ctx,
timeout: 500 * time.Millisecond,
transport: http.DefaultTransport,
ctx: ctx,
timeout: 500 * time.Millisecond,
transport: http.DefaultTransport,
errorDecoder: checkResponse,
}
for _, o := range opts {
o(options)
}
return &baseTransport{
middleware: options.middleware,
userAgent: options.userAgent,
timeout: options.timeout,
base: options.transport,
errorDecoder: options.errorDecoder,
middleware: options.middleware,
userAgent: options.userAgent,
timeout: options.timeout,
base: options.transport,
}, nil
}
type baseTransport struct {
userAgent string
timeout time.Duration
base http.RoundTripper
middleware middleware.Middleware
userAgent string
timeout time.Duration
base http.RoundTripper
errorDecoder DecodeErrorFunc
middleware middleware.Middleware
}
func (t *baseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
@ -96,7 +107,14 @@ func (t *baseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
defer cancel()
h := func(ctx context.Context, in interface{}) (interface{}, error) {
return t.base.RoundTrip(in.(*http.Request))
res, err := t.base.RoundTrip(in.(*http.Request))
if err != nil {
return nil, err
}
if err := t.errorDecoder(ctx, res); err != nil {
return nil, err
}
return res, nil
}
if t.middleware != nil {
h = t.middleware(h)
@ -115,19 +133,7 @@ func Do(client *http.Client, req *http.Request, target interface{}) error {
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode > 299 {
se := &errors.Error{Code: 500}
if err := decodeResponse(res, se); err != nil {
return err
}
return se
}
return decodeResponse(res, target)
}
func decodeResponse(res *http.Response, target interface{}) error {
subtype := contentSubtype(res.Header.Get(contentTypeHeader))
subtype := xhttp.ContentSubtype(res.Header.Get(xhttp.HeaderContentType))
codec := encoding.GetCodec(subtype)
if codec == nil {
codec = encoding.GetCodec("json")
@ -138,3 +144,19 @@ func decodeResponse(res *http.Response, target interface{}) error {
}
return codec.Unmarshal(data, target)
}
// checkResponse returns an error (of type *Error) if the response
// status code is not 2xx.
func checkResponse(ctx context.Context, res *http.Response) error {
if res.StatusCode >= 200 && res.StatusCode <= 299 {
return nil
}
defer res.Body.Close()
if data, err := ioutil.ReadAll(res.Body); err == nil {
st := new(spb.Status)
if err = protojson.Unmarshal(data, st); err == nil {
return status.ErrorProto(st)
}
}
return errors.New(res.StatusCode, "")
}

@ -3,26 +3,18 @@ package http
import (
"io/ioutil"
"net/http"
"strings"
"github.com/go-kratos/kratos/v2/encoding"
"github.com/go-kratos/kratos/v2/encoding/json"
"github.com/go-kratos/kratos/v2/errors"
xhttp "github.com/go-kratos/kratos/v2/internal/http"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport/http/binding"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
)
const (
// SupportPackageIsVersion1 These constants should not be referenced from any other code.
SupportPackageIsVersion1 = true
baseContentType = "application"
)
var (
acceptHeader = http.CanonicalHeaderKey("Accept")
contentTypeHeader = http.CanonicalHeaderKey("Content-Type")
)
// SupportPackageIsVersion1 These constants should not be referenced from any other code.
const SupportPackageIsVersion1 = true
// DecodeRequestFunc is decode request func.
type DecodeRequestFunc func(*http.Request, interface{}) error
@ -83,7 +75,7 @@ func Middleware(m middleware.Middleware) HandleOption {
// decodeRequest decodes the request body to object.
func decodeRequest(req *http.Request, v interface{}) error {
subtype := contentSubtype(req.Header.Get(contentTypeHeader))
subtype := xhttp.ContentSubtype(req.Header.Get(xhttp.HeaderContentType))
if codec := encoding.GetCodec(subtype); codec != nil {
data, err := ioutil.ReadAll(req.Body)
if err != nil {
@ -101,26 +93,29 @@ func encodeResponse(w http.ResponseWriter, r *http.Request, v interface{}) error
if err != nil {
return err
}
w.Header().Set(contentTypeHeader, contentType(codec.Name()))
w.Header().Set(xhttp.HeaderContentType, xhttp.ContentType(codec.Name()))
_, _ = w.Write(data)
return nil
}
// encodeError encodes the error to the HTTP response.
func encodeError(w http.ResponseWriter, r *http.Request, err error) {
se := errors.FromError(err)
codec := codecForRequest(r)
data, _ := codec.Marshal(se)
w.Header().Set(contentTypeHeader, contentType(codec.Name()))
w.WriteHeader(se.Code)
_, _ = w.Write(data)
st, _ := status.FromError(err)
data, err := protojson.Marshal(st.Proto())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Set(xhttp.HeaderContentType, "application/json; charset=utf-8")
w.WriteHeader(xhttp.StatusFromGRPCCode(st.Code()))
w.Write(data)
}
// codecForRequest get encoding.Codec via http.Request
func codecForRequest(r *http.Request) encoding.Codec {
var codec encoding.Codec
for _, accept := range r.Header[acceptHeader] {
if codec = encoding.GetCodec(contentSubtype(accept)); codec != nil {
for _, accept := range r.Header[xhttp.HeaderAccept] {
if codec = encoding.GetCodec(xhttp.ContentSubtype(accept)); codec != nil {
break
}
}
@ -129,25 +124,3 @@ func codecForRequest(r *http.Request) encoding.Codec {
}
return codec
}
func contentType(subtype string) string {
return strings.Join([]string{baseContentType, subtype}, "/")
}
func contentSubtype(contentType string) string {
if contentType == baseContentType {
return ""
}
if !strings.HasPrefix(contentType, baseContentType) {
return ""
}
switch contentType[len(baseContentType)] {
case '/', ';':
if i := strings.Index(contentType, ";"); i != -1 {
return contentType[len(baseContentType)+1 : i]
}
return contentType[len(baseContentType)+1:]
default:
return ""
}
}

Loading…
Cancel
Save