diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 9d7ede93f..61ca85d16 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -6,6 +6,7 @@ import ( "fmt" "time" + "github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/selector" @@ -87,6 +88,13 @@ func WithFilter(filters ...selector.Filter) ClientOption { } } +// WithLogger with logger +func WithLogger(log log.Logger) ClientOption { + return func(o *clientOptions) { + o.logger = log + } +} + // clientOptions is gRPC Client type clientOptions struct { endpoint string @@ -98,6 +106,7 @@ type clientOptions struct { grpcOpts []grpc.DialOption balancerName string filters []selector.Filter + logger log.Logger } // Dial returns a GRPC connection. @@ -114,6 +123,7 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien options := clientOptions{ timeout: 2000 * time.Millisecond, balancerName: wrr.Name, + logger: log.DefaultLogger, } for _, o := range opts { o(&options) @@ -129,7 +139,14 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien grpc.WithChainUnaryInterceptor(ints...), } if options.discovery != nil { - grpcOpts = append(grpcOpts, grpc.WithResolvers(discovery.NewBuilder(options.discovery, discovery.WithInsecure(insecure)))) + grpcOpts = append(grpcOpts, + grpc.WithResolvers( + discovery.NewBuilder( + options.discovery, + discovery.WithInsecure(insecure), + discovery.WithLogger(options.logger), + discovery.WithTimeout(options.timeout), + ))) } if insecure { grpcOpts = append(grpcOpts, grpc.WithInsecure()) diff --git a/transport/grpc/client_test.go b/transport/grpc/client_test.go index f618d9bf1..3c510afff 100644 --- a/transport/grpc/client_test.go +++ b/transport/grpc/client_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/registry" "github.com/stretchr/testify/assert" @@ -59,6 +60,13 @@ func TestWithTLSConfig(t *testing.T) { assert.Equal(t, v, o.tlsConf) } +func TestWithLogger(t *testing.T) { + o := &clientOptions{} + v := log.DefaultLogger + WithLogger(v)(o) + assert.Equal(t, v, o.logger) +} + func EmptyMiddleware() middleware.Middleware { return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (reply interface{}, err error) { @@ -112,3 +120,18 @@ func TestDial(t *testing.T) { WithOptions(v...)(o) assert.Equal(t, v, o.grpcOpts) } + +func TestDialConn(t *testing.T) { + _, err := dial( + context.Background(), + true, + WithDiscovery(&mockRegistry{}), + WithTimeout(10*time.Second), + WithLogger(log.DefaultLogger), + WithEndpoint("abc"), + WithMiddleware(EmptyMiddleware()), + ) + if err != nil { + t.Error(err) + } +}