From 271b6c29243435d498146ff9e72c39bf6760e5a5 Mon Sep 17 00:00:00 2001 From: Remember <36129334+wuqinqiang@users.noreply.github.com> Date: Sun, 1 Jan 2023 20:15:28 +0800 Subject: [PATCH] fix(consul):return err if ctx is done (#2550) --- contrib/registry/consul/registry.go | 56 +++++++++++++++--------- contrib/registry/consul/registry_test.go | 29 ++++++++++-- contrib/registry/consul/watcher.go | 1 + 3 files changed, 63 insertions(+), 23 deletions(-) diff --git a/contrib/registry/consul/registry.go b/contrib/registry/consul/registry.go index b85f929ff..e5ac1819c 100644 --- a/contrib/registry/consul/registry.go +++ b/contrib/registry/consul/registry.go @@ -27,6 +27,13 @@ func WithHealthCheck(enable bool) Option { } } +// WithTimeout with get services timeout option. +func WithTimeout(timeout time.Duration) Option { + return func(o *Registry) { + o.timeout = timeout + } +} + // WithDatacenter with registry datacenter option func WithDatacenter(dc Datacenter) Option { return func(o *Registry) { @@ -90,6 +97,7 @@ type Registry struct { enableHealthCheck bool registry map[string]*serviceSet lock sync.RWMutex + timeout time.Duration dc Datacenter } @@ -99,6 +107,7 @@ func New(apiClient *api.Client, opts ...Option) *Registry { dc: SingleDatacenter, registry: make(map[string]*serviceSet), enableHealthCheck: true, + timeout: 10 * time.Second, } for _, o := range opts { o(r) @@ -178,11 +187,11 @@ func (r *Registry) Watch(ctx context.Context, name string) (registry.Watcher, er r.registry[name] = set } - // 初始化watcher + // init watcher w := &watcher{ event: make(chan struct{}, 1), } - w.ctx, w.cancel = context.WithCancel(context.Background()) + w.ctx, w.cancel = context.WithCancel(ctx) w.set = set set.lock.Lock() set.watcher[w] = struct{}{} @@ -195,7 +204,7 @@ func (r *Registry) Watch(ctx context.Context, name string) (registry.Watcher, er } if !ok { - err := r.resolve(set) + err := r.resolve(ctx, set) if err != nil { return nil, err } @@ -203,32 +212,39 @@ func (r *Registry) Watch(ctx context.Context, name string) (registry.Watcher, er return w, nil } -func (r *Registry) resolve(ss *serviceSet) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - services, idx, err := r.cli.Service(ctx, ss.serviceName, 0, true) - cancel() +func (r *Registry) resolve(ctx context.Context, ss *serviceSet) error { + timeoutCtx, cancel := context.WithTimeout(ctx, r.timeout) + defer cancel() + + services, idx, err := r.cli.Service(timeoutCtx, ss.serviceName, 0, true) if err != nil { return err - } else if len(services) > 0 { + } + if len(services) > 0 { ss.broadcast(services) } + go func() { ticker := time.NewTicker(time.Second) defer ticker.Stop() for { - <-ticker.C - ctx, cancel := context.WithTimeout(context.Background(), time.Second*120) - tmpService, tmpIdx, err := r.cli.Service(ctx, ss.serviceName, idx, true) - cancel() - if err != nil { - time.Sleep(time.Second) - continue - } - if len(tmpService) != 0 && tmpIdx != idx { - services = tmpService - ss.broadcast(services) + select { + case <-ticker.C: + timeoutCtx, cancel := context.WithTimeout(context.Background(), r.timeout) + tmpService, tmpIdx, err := r.cli.Service(timeoutCtx, ss.serviceName, idx, true) + cancel() + if err != nil { + time.Sleep(time.Second) + continue + } + if len(tmpService) != 0 && tmpIdx != idx { + services = tmpService + ss.broadcast(services) + } + idx = tmpIdx + case <-ctx.Done(): + return } - idx = tmpIdx } }() diff --git a/contrib/registry/consul/registry_test.go b/contrib/registry/consul/registry_test.go index 6a1523c48..a60251236 100644 --- a/contrib/registry/consul/registry_test.go +++ b/contrib/registry/consul/registry_test.go @@ -13,7 +13,7 @@ import ( "github.com/go-kratos/kratos/v2/registry" ) -func tcpServer(t *testing.T, lis net.Listener) { +func tcpServer(lis net.Listener) { for { conn, err := lis.Accept() if err != nil { @@ -148,7 +148,7 @@ func TestRegistry_GetService(t *testing.T) { t.Fail() } defer lis.Close() - go tcpServer(t, lis) + go tcpServer(lis) time.Sleep(time.Millisecond * 100) cli, err := api.NewClient(&api.Config{Address: "127.0.0.1:8500"}) if err != nil { @@ -284,9 +284,12 @@ func TestRegistry_Watch(t *testing.T) { type args struct { ctx context.Context + cancel func() opts []Option instance *registry.ServiceInstance } + canceledCtx, cancel := context.WithCancel(context.Background()) + tests := []struct { name string args args @@ -308,6 +311,21 @@ func TestRegistry_Watch(t *testing.T) { preFunc: func(t *testing.T) { }, }, + { + name: "ctx has been cancelled", + args: args{ + ctx: canceledCtx, + cancel: cancel, + instance: instance1, + opts: []Option{ + WithHealthCheck(false), + }, + }, + want: nil, + wantErr: true, + preFunc: func(t *testing.T) { + }, + }, { name: "register with healthCheck", args: args{ @@ -325,8 +343,9 @@ func TestRegistry_Watch(t *testing.T) { lis, err := net.Listen("tcp", addr) if err != nil { t.Errorf("listen tcp %s failed!", addr) + return } - go tcpServer(t, lis) + go tcpServer(lis) }, }, } @@ -355,6 +374,10 @@ func TestRegistry_Watch(t *testing.T) { t.Error(err) } + if tt.args.cancel != nil { + tt.args.cancel() + } + service, err := watch.Next() if (err != nil) != tt.wantErr { diff --git a/contrib/registry/consul/watcher.go b/contrib/registry/consul/watcher.go index 3b4bfece8..0a5d35751 100644 --- a/contrib/registry/consul/watcher.go +++ b/contrib/registry/consul/watcher.go @@ -19,6 +19,7 @@ func (w *watcher) Next() (services []*registry.ServiceInstance, err error) { select { case <-w.ctx.Done(): err = w.ctx.Err() + return case <-w.event: }