diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 5062286e3..b2b5c0c2b 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -146,16 +146,22 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien ints := []grpc.UnaryClientInterceptor{ unaryClientInterceptor(options.middleware, options.timeout, options.filters), } + sints := []grpc.StreamClientInterceptor{ + streamClientInterceptor(options.filters), + } + if len(options.ints) > 0 { ints = append(ints, options.ints...) } + if len(options.streamInts) > 0 { + sints = append(sints, options.streamInts...) + } grpcOpts := []grpc.DialOption{ grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, options.balancerName)), grpc.WithChainUnaryInterceptor(ints...), + grpc.WithChainStreamInterceptor(sints...), } - if len(options.streamInts) > 0 { - grpcOpts = append(grpcOpts, grpc.WithChainStreamInterceptor(options.streamInts...)) - } + if options.discovery != nil { grpcOpts = append(grpcOpts, grpc.WithResolvers( @@ -211,3 +217,17 @@ func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, f return err } } + +func streamClientInterceptor(filters []selector.NodeFilter) grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { // nolint + ctx = transport.NewClientContext(ctx, &Transport{ + endpoint: cc.Target(), + operation: method, + reqHeader: headerCarrier{}, + nodeFilters: filters, + }) + var p selector.Peer + ctx = selector.NewPeerContext(ctx, &p) + return streamer(ctx, desc, cc, method, opts...) + } +}