diff --git a/pkg/net/rpc/warden/CHANGELOG.md b/pkg/net/rpc/warden/CHANGELOG.md index 8918b7f5f..7bd96bbc7 100644 --- a/pkg/net/rpc/warden/CHANGELOG.md +++ b/pkg/net/rpc/warden/CHANGELOG.md @@ -1,5 +1,8 @@ ### net/rpc/warden +##### Version 1.1.20 +1. client增加timeoutCallOpt强制覆盖每次请求的timeout + ##### Version 1.1.19 1. 升级grpc至1.22.0 2. client增加keepAlive选项 diff --git a/pkg/net/rpc/warden/client.go b/pkg/net/rpc/warden/client.go index ec5ec4bca..93b9262eb 100644 --- a/pkg/net/rpc/warden/client.go +++ b/pkg/net/rpc/warden/client.go @@ -77,6 +77,16 @@ type Client struct { handlers []grpc.UnaryClientInterceptor } +type TimeOutCallOption struct { + *grpc.EmptyCallOption + Timeout time.Duration +} + +// WithTimeoutCallOption can override the timeout in ctx and the timeout in the configuration file +func WithTimeoutCallOption(timeout time.Duration) *TimeOutCallOption { + return &TimeOutCallOption{&grpc.EmptyCallOption{}, timeout} +} + // handle returns a new unary client interceptor for OpenTracing\Logging\LinkTimeout. func (c *Client) handle() grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) { @@ -110,7 +120,20 @@ func (c *Client) handle() grpc.UnaryClientInterceptor { return } defer onBreaker(brk, &err) - _, ctx, cancel = conf.Timeout.Shrink(ctx) + var timeOpt *TimeOutCallOption + for _, opt := range opts { + var tok bool + timeOpt, tok = opt.(*TimeOutCallOption) + if tok { + break + } + } + if timeOpt != nil && timeOpt.Timeout > 0 { + ctx, cancel = context.WithTimeout(nmd.WithContext(ctx), timeOpt.Timeout) + } else { + _, ctx, cancel = conf.Timeout.Shrink(ctx) + } + defer cancel() nmd.Range(ctx, func(key string, value interface{}) { diff --git a/pkg/net/rpc/warden/server_test.go b/pkg/net/rpc/warden/server_test.go index a13f5bfd8..72f479ac7 100644 --- a/pkg/net/rpc/warden/server_test.go +++ b/pkg/net/rpc/warden/server_test.go @@ -130,7 +130,12 @@ func (s *helloServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.He } reply := &pb.HelloReply{Message: "status", Success: true} return reply, nil + } else if in.Name == "time_opt" { + time.Sleep(time.Second) + reply := &pb.HelloReply{Message: "status", Success: true} + return reply, nil } + return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil } @@ -201,6 +206,7 @@ func Test_Warden(t *testing.T) { testValidation(t) testServerRecovery(t) testClientRecovery(t) + testTimeoutOpt(t) testErrorDetail(t) testECodeStatus(t) testColorPass(t) @@ -219,6 +225,26 @@ func testValidation(t *testing.T) { } } +func testTimeoutOpt(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + client := NewClient(&clientConfig) + conn, err := client.Dial(ctx, "127.0.0.1:8080") + if err != nil { + t.Fatalf("did not connect: %v", err) + } + defer conn.Close() + c := pb.NewGreeterClient(conn) + start := time.Now() + _, err = c.SayHello(ctx, &pb.HelloRequest{Name: "time_opt", Age: 0}, WithTimeoutCallOption(time.Millisecond*500)) + if err == nil { + t.Fatalf("recovery must return error") + } + if time.Since(start) < time.Millisecond*400 { + t.Fatalf("client timeout must be greater than 400 Milliseconds;err:=%v", err) + } +} + func testAllErrorCase(t *testing.T) { // } else if in.Name == "general_error" { // return nil, fmt.Errorf("haha is error")