feat:add stream interceptor use ctx encapsulation (#1770)
* feat:add stream interceptor use ctx encapsulation * add reply headerpull/1776/head
parent
0965bf8e22
commit
89583885e4
@ -0,0 +1,83 @@ |
||||
package grpc |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
ic "github.com/go-kratos/kratos/v2/internal/context" |
||||
"github.com/go-kratos/kratos/v2/middleware" |
||||
"github.com/go-kratos/kratos/v2/transport" |
||||
"google.golang.org/grpc" |
||||
grpcmd "google.golang.org/grpc/metadata" |
||||
) |
||||
|
||||
// unaryServerInterceptor is a gRPC unary server interceptor
|
||||
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { |
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { |
||||
ctx, cancel := ic.Merge(ctx, s.baseCtx) |
||||
defer cancel() |
||||
md, _ := grpcmd.FromIncomingContext(ctx) |
||||
replyHeader := grpcmd.MD{} |
||||
ctx = transport.NewServerContext(ctx, &Transport{ |
||||
endpoint: s.endpoint.String(), |
||||
operation: info.FullMethod, |
||||
reqHeader: headerCarrier(md), |
||||
replyHeader: headerCarrier(replyHeader), |
||||
}) |
||||
if s.timeout > 0 { |
||||
ctx, cancel = context.WithTimeout(ctx, s.timeout) |
||||
defer cancel() |
||||
} |
||||
h := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
return handler(ctx, req) |
||||
} |
||||
if len(s.middleware) > 0 { |
||||
h = middleware.Chain(s.middleware...)(h) |
||||
} |
||||
reply, err := h(ctx, req) |
||||
if len(replyHeader) > 0 { |
||||
_ = grpc.SetHeader(ctx, replyHeader) |
||||
} |
||||
return reply, err |
||||
} |
||||
} |
||||
|
||||
// wrappedStream is rewrite grpc stream's context
|
||||
type wrappedStream struct { |
||||
grpc.ServerStream |
||||
ctx context.Context |
||||
} |
||||
|
||||
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream) grpc.ServerStream { |
||||
return &wrappedStream{ |
||||
ServerStream: stream, |
||||
ctx: ctx, |
||||
} |
||||
} |
||||
|
||||
func (w *wrappedStream) Context() context.Context { |
||||
return w.ctx |
||||
} |
||||
|
||||
// streamServerInterceptor is a gRPC stream server interceptor
|
||||
func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor { |
||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { |
||||
ctx, cancel := ic.Merge(ss.Context(), s.baseCtx) |
||||
defer cancel() |
||||
md, _ := grpcmd.FromIncomingContext(ctx) |
||||
replyHeader := grpcmd.MD{} |
||||
ctx = transport.NewServerContext(ctx, &Transport{ |
||||
endpoint: s.endpoint.String(), |
||||
operation: info.FullMethod, |
||||
reqHeader: headerCarrier(md), |
||||
replyHeader: headerCarrier(replyHeader), |
||||
}) |
||||
|
||||
ws := NewWrappedStream(ctx, ss) |
||||
|
||||
err := handler(srv, ws) |
||||
if len(replyHeader) > 0 { |
||||
_ = grpc.SetHeader(ctx, replyHeader) |
||||
} |
||||
return err |
||||
} |
||||
} |
Loading…
Reference in new issue