diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 0d92f90f9..5062286e3 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -82,6 +82,13 @@ func WithUnaryInterceptor(in ...grpc.UnaryClientInterceptor) ClientOption { } } +// WithStreamInterceptor returns a DialOption that specifies the interceptor for streaming RPCs. +func WithStreamInterceptor(in ...grpc.StreamClientInterceptor) ClientOption { + return func(o *clientOptions) { + o.streamInts = in + } +} + // WithOptions with gRPC options. func WithOptions(opts ...grpc.DialOption) ClientOption { return func(o *clientOptions) { @@ -111,6 +118,7 @@ type clientOptions struct { discovery registry.Discovery middleware []middleware.Middleware ints []grpc.UnaryClientInterceptor + streamInts []grpc.StreamClientInterceptor grpcOpts []grpc.DialOption balancerName string filters []selector.NodeFilter @@ -145,6 +153,9 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, options.balancerName)), grpc.WithChainUnaryInterceptor(ints...), } + if len(options.streamInts) > 0 { + grpcOpts = append(grpcOpts, grpc.WithChainStreamInterceptor(options.streamInts...)) + } if options.discovery != nil { grpcOpts = append(grpcOpts, grpc.WithResolvers(