package grpc

import (
	"context"

	ic "github.com/go-kratos/kratos/v2/internal/context"
	"github.com/go-kratos/kratos/v2/middleware"
	"github.com/go-kratos/kratos/v2/transport"
	"google.golang.org/grpc"
	grpcmd "google.golang.org/grpc/metadata"
)

// unaryServerInterceptor returns a gRPC unary server interceptor that:
//   - merges the incoming request context with the server's base context,
//   - attaches a Transport (operation, request headers, reply headers and,
//     when configured, the endpoint) to the context,
//   - applies the server timeout when it is greater than zero,
//   - runs the middleware matched for the operation around the handler,
//   - sets any reply headers collected during the call on the response.
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		ctx, cancel := ic.Merge(ctx, s.baseCtx)
		defer cancel()

		md, _ := grpcmd.FromIncomingContext(ctx)
		replyHeader := grpcmd.MD{}
		tr := &Transport{
			operation:   info.FullMethod,
			reqHeader:   headerCarrier(md),
			replyHeader: headerCarrier(replyHeader),
		}
		// The endpoint is optional; only render it when one is configured.
		if s.endpoint != nil {
			tr.endpoint = s.endpoint.String()
		}
		ctx = transport.NewServerContext(ctx, tr)

		if s.timeout > 0 {
			ctx, cancel = context.WithTimeout(ctx, s.timeout)
			defer cancel()
		}

		h := func(ctx context.Context, req interface{}) (interface{}, error) {
			return handler(ctx, req)
		}
		if next := s.middleware.Match(tr.Operation()); len(next) > 0 {
			h = middleware.Chain(next...)(h)
		}

		reply, err := h(ctx, req)
		if len(replyHeader) > 0 {
			// Best effort: a failure to set headers must not replace the
			// handler's own result, so the error is deliberately discarded.
			_ = grpc.SetHeader(ctx, replyHeader)
		}
		return reply, err
	}
}

// wrappedStream wraps a grpc.ServerStream and overrides the context it
// reports, so handlers observe the context prepared by the interceptor.
type wrappedStream struct {
	grpc.ServerStream
	ctx context.Context
}

// NewWrappedStream returns a grpc.ServerStream that delegates to stream but
// whose Context method returns ctx.
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream) grpc.ServerStream {
	return &wrappedStream{
		ServerStream: stream,
		ctx:          ctx,
	}
}

// Context returns the context supplied when the stream was wrapped.
func (w *wrappedStream) Context() context.Context {
	return w.ctx
}

// streamServerInterceptor returns a gRPC stream server interceptor that
// merges the stream context with the server's base context, attaches a
// Transport to it, and invokes the handler with a stream carrying that
// context. Reply headers collected during the call are set afterwards.
func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		ctx, cancel := ic.Merge(ss.Context(), s.baseCtx)
		defer cancel()

		md, _ := grpcmd.FromIncomingContext(ctx)
		replyHeader := grpcmd.MD{}
		tr := &Transport{
			operation:   info.FullMethod,
			reqHeader:   headerCarrier(md),
			replyHeader: headerCarrier(replyHeader),
		}
		// Guard against a server with no endpoint configured, matching the
		// unary interceptor; calling String on a nil endpoint would panic.
		if s.endpoint != nil {
			tr.endpoint = s.endpoint.String()
		}
		ctx = transport.NewServerContext(ctx, tr)

		err := handler(srv, NewWrappedStream(ctx, ss))
		if len(replyHeader) > 0 {
			// Best effort: a failure to set headers must not replace the
			// handler's own result, so the error is deliberately discarded.
			_ = grpc.SetHeader(ctx, replyHeader)
		}
		return err
	}
}