feat:add stream interceptor use ctx encapsulation (#1770)
* feat:add stream interceptor use ctx encapsulation * add reply headerpull/1776/head
parent
0965bf8e22
commit
89583885e4
@ -0,0 +1,83 @@ |
|||||||
|
package grpc |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
|
||||||
|
ic "github.com/go-kratos/kratos/v2/internal/context" |
||||||
|
"github.com/go-kratos/kratos/v2/middleware" |
||||||
|
"github.com/go-kratos/kratos/v2/transport" |
||||||
|
"google.golang.org/grpc" |
||||||
|
grpcmd "google.golang.org/grpc/metadata" |
||||||
|
) |
||||||
|
|
||||||
|
// unaryServerInterceptor is a gRPC unary server interceptor
|
||||||
|
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { |
||||||
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { |
||||||
|
ctx, cancel := ic.Merge(ctx, s.baseCtx) |
||||||
|
defer cancel() |
||||||
|
md, _ := grpcmd.FromIncomingContext(ctx) |
||||||
|
replyHeader := grpcmd.MD{} |
||||||
|
ctx = transport.NewServerContext(ctx, &Transport{ |
||||||
|
endpoint: s.endpoint.String(), |
||||||
|
operation: info.FullMethod, |
||||||
|
reqHeader: headerCarrier(md), |
||||||
|
replyHeader: headerCarrier(replyHeader), |
||||||
|
}) |
||||||
|
if s.timeout > 0 { |
||||||
|
ctx, cancel = context.WithTimeout(ctx, s.timeout) |
||||||
|
defer cancel() |
||||||
|
} |
||||||
|
h := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||||
|
return handler(ctx, req) |
||||||
|
} |
||||||
|
if len(s.middleware) > 0 { |
||||||
|
h = middleware.Chain(s.middleware...)(h) |
||||||
|
} |
||||||
|
reply, err := h(ctx, req) |
||||||
|
if len(replyHeader) > 0 { |
||||||
|
_ = grpc.SetHeader(ctx, replyHeader) |
||||||
|
} |
||||||
|
return reply, err |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// wrappedStream is rewrite grpc stream's context
|
||||||
|
type wrappedStream struct { |
||||||
|
grpc.ServerStream |
||||||
|
ctx context.Context |
||||||
|
} |
||||||
|
|
||||||
|
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream) grpc.ServerStream { |
||||||
|
return &wrappedStream{ |
||||||
|
ServerStream: stream, |
||||||
|
ctx: ctx, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (w *wrappedStream) Context() context.Context { |
||||||
|
return w.ctx |
||||||
|
} |
||||||
|
|
||||||
|
// streamServerInterceptor is a gRPC stream server interceptor
|
||||||
|
func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor { |
||||||
|
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { |
||||||
|
ctx, cancel := ic.Merge(ss.Context(), s.baseCtx) |
||||||
|
defer cancel() |
||||||
|
md, _ := grpcmd.FromIncomingContext(ctx) |
||||||
|
replyHeader := grpcmd.MD{} |
||||||
|
ctx = transport.NewServerContext(ctx, &Transport{ |
||||||
|
endpoint: s.endpoint.String(), |
||||||
|
operation: info.FullMethod, |
||||||
|
reqHeader: headerCarrier(md), |
||||||
|
replyHeader: headerCarrier(replyHeader), |
||||||
|
}) |
||||||
|
|
||||||
|
ws := NewWrappedStream(ctx, ss) |
||||||
|
|
||||||
|
err := handler(srv, ws) |
||||||
|
if len(replyHeader) > 0 { |
||||||
|
_ = grpc.SetHeader(ctx, replyHeader) |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
Loading…
Reference in new issue