From 834b781ee2119ec405472a4d617d4725c45fb4e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8C=85=E5=AD=90?= Date: Mon, 27 Feb 2023 15:43:21 +0800 Subject: [PATCH] feat: support load balance for streaming connection creation (#2669) * feat: support load balance for streaming connection creation * fix lint --- transport/grpc/client.go | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 5062286e3..b2b5c0c2b 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -146,16 +146,22 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien ints := []grpc.UnaryClientInterceptor{ unaryClientInterceptor(options.middleware, options.timeout, options.filters), } + sints := []grpc.StreamClientInterceptor{ + streamClientInterceptor(options.filters), + } + if len(options.ints) > 0 { ints = append(ints, options.ints...) } + if len(options.streamInts) > 0 { + sints = append(sints, options.streamInts...) + } grpcOpts := []grpc.DialOption{ grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, options.balancerName)), grpc.WithChainUnaryInterceptor(ints...), + grpc.WithChainStreamInterceptor(sints...), } - if len(options.streamInts) > 0 { - grpcOpts = append(grpcOpts, grpc.WithChainStreamInterceptor(options.streamInts...)) - } + if options.discovery != nil { grpcOpts = append(grpcOpts, grpc.WithResolvers( @@ -211,3 +217,17 @@ func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, f return err } } + +func streamClientInterceptor(filters []selector.NodeFilter) grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { // nolint + ctx = transport.NewClientContext(ctx, &Transport{ + endpoint: cc.Target(), + operation: method, + reqHeader: headerCarrier{}, + nodeFilters: filters, + }) + var p selector.Peer + ctx = selector.NewPeerContext(ctx, &p) + return streamer(ctx, desc, cc, method, opts...) + } +}