package status import ( "context" "net/http" "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/middleware" //lint:ignore SA1019 grpc "github.com/golang/protobuf/proto" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) type domainKey struct{} // HandlerFunc is middleware error handler. type HandlerFunc func(context.Context, error) error // Option is recovery option. type Option func(*options) type options struct { domain string handler HandlerFunc } // WithDomain with service domain. func WithDomain(domain string) Option { return func(o *options) { o.domain = domain } } // WithHandler with status handler. func WithHandler(h HandlerFunc) Option { return func(o *options) { o.handler = h } } // Server is an error middleware. func Server(opts ...Option) middleware.Middleware { options := options{ handler: encodeErr, } for _, o := range opts { o(&options) } return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (interface{}, error) { reply, err := handler(ctx, req) if err != nil { ctx = context.WithValue(ctx, domainKey{}, options.domain) return nil, options.handler(ctx, err) } return reply, nil } } } // Client is an error middleware. func Client(opts ...Option) middleware.Middleware { options := options{ handler: decodeErr, } for _, o := range opts { o(&options) } return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (interface{}, error) { reply, err := handler(ctx, req) if err != nil { return nil, options.handler(ctx, err) } return reply, nil } } } func encodeErr(ctx context.Context, err error) error { se := errors.FromError(err) if se.Domain == "" { se.Domain, _ = ctx.Value(domainKey{}).(string) } gs := status.Newf(httpToGRPCCode(se.Code), "%s: %s", se.Reason, se.Message) details := []proto.Message{ &errdetails.ErrorInfo{ Domain: se.Domain, Reason: se.Reason, Metadata: se.Metadata, }, } gs, err = gs.WithDetails(details...) if err != nil { return err } return gs.Err() } func decodeErr(ctx context.Context, err error) error { gs := status.Convert(err) se := &errors.Error{ Code: grpcToHTTPCode(gs.Code()), Message: gs.Message(), } for _, detail := range gs.Details() { switch d := detail.(type) { case *errdetails.ErrorInfo: se.Domain = d.Domain se.Reason = d.Reason se.Metadata = d.Metadata return se } } return se } func httpToGRPCCode(code int) codes.Code { switch code { case http.StatusBadRequest: return codes.InvalidArgument case http.StatusUnauthorized: return codes.Unauthenticated case http.StatusForbidden: return codes.PermissionDenied case http.StatusNotFound: return codes.NotFound case http.StatusConflict: return codes.Aborted case http.StatusInternalServerError: return codes.Internal case http.StatusServiceUnavailable: return codes.Unavailable } return codes.Unknown } func grpcToHTTPCode(code codes.Code) int { switch code { case codes.InvalidArgument: return http.StatusBadRequest case codes.Unauthenticated: return http.StatusUnauthorized case codes.PermissionDenied: return http.StatusForbidden case codes.NotFound: return http.StatusNotFound case codes.Aborted: return http.StatusConflict case codes.Internal: return http.StatusInternalServerError case codes.Unavailable: return http.StatusServiceUnavailable } return http.StatusInternalServerError }