fix(consul):return err if ctx is done (#2550)

pull/2598/head
Remember 2 years ago committed by GitHub
parent c442a320a0
commit 271b6c2924
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 38
      contrib/registry/consul/registry.go
  2. 29
      contrib/registry/consul/registry_test.go
  3. 1
      contrib/registry/consul/watcher.go

@ -27,6 +27,13 @@ func WithHealthCheck(enable bool) Option {
} }
} }
// WithTimeout with get services timeout option.
func WithTimeout(timeout time.Duration) Option {
return func(o *Registry) {
o.timeout = timeout
}
}
// WithDatacenter with registry datacenter option // WithDatacenter with registry datacenter option
func WithDatacenter(dc Datacenter) Option { func WithDatacenter(dc Datacenter) Option {
return func(o *Registry) { return func(o *Registry) {
@ -90,6 +97,7 @@ type Registry struct {
enableHealthCheck bool enableHealthCheck bool
registry map[string]*serviceSet registry map[string]*serviceSet
lock sync.RWMutex lock sync.RWMutex
timeout time.Duration
dc Datacenter dc Datacenter
} }
@ -99,6 +107,7 @@ func New(apiClient *api.Client, opts ...Option) *Registry {
dc: SingleDatacenter, dc: SingleDatacenter,
registry: make(map[string]*serviceSet), registry: make(map[string]*serviceSet),
enableHealthCheck: true, enableHealthCheck: true,
timeout: 10 * time.Second,
} }
for _, o := range opts { for _, o := range opts {
o(r) o(r)
@ -178,11 +187,11 @@ func (r *Registry) Watch(ctx context.Context, name string) (registry.Watcher, er
r.registry[name] = set r.registry[name] = set
} }
// 初始化watcher // init watcher
w := &watcher{ w := &watcher{
event: make(chan struct{}, 1), event: make(chan struct{}, 1),
} }
w.ctx, w.cancel = context.WithCancel(context.Background()) w.ctx, w.cancel = context.WithCancel(ctx)
w.set = set w.set = set
set.lock.Lock() set.lock.Lock()
set.watcher[w] = struct{}{} set.watcher[w] = struct{}{}
@ -195,7 +204,7 @@ func (r *Registry) Watch(ctx context.Context, name string) (registry.Watcher, er
} }
if !ok { if !ok {
err := r.resolve(set) err := r.resolve(ctx, set)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -203,22 +212,26 @@ func (r *Registry) Watch(ctx context.Context, name string) (registry.Watcher, er
return w, nil return w, nil
} }
func (r *Registry) resolve(ss *serviceSet) error { func (r *Registry) resolve(ctx context.Context, ss *serviceSet) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) timeoutCtx, cancel := context.WithTimeout(ctx, r.timeout)
services, idx, err := r.cli.Service(ctx, ss.serviceName, 0, true) defer cancel()
cancel()
services, idx, err := r.cli.Service(timeoutCtx, ss.serviceName, 0, true)
if err != nil { if err != nil {
return err return err
} else if len(services) > 0 { }
if len(services) > 0 {
ss.broadcast(services) ss.broadcast(services)
} }
go func() { go func() {
ticker := time.NewTicker(time.Second) ticker := time.NewTicker(time.Second)
defer ticker.Stop() defer ticker.Stop()
for { for {
<-ticker.C select {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120) case <-ticker.C:
tmpService, tmpIdx, err := r.cli.Service(ctx, ss.serviceName, idx, true) timeoutCtx, cancel := context.WithTimeout(context.Background(), r.timeout)
tmpService, tmpIdx, err := r.cli.Service(timeoutCtx, ss.serviceName, idx, true)
cancel() cancel()
if err != nil { if err != nil {
time.Sleep(time.Second) time.Sleep(time.Second)
@ -229,6 +242,9 @@ func (r *Registry) resolve(ss *serviceSet) error {
ss.broadcast(services) ss.broadcast(services)
} }
idx = tmpIdx idx = tmpIdx
case <-ctx.Done():
return
}
} }
}() }()

@ -13,7 +13,7 @@ import (
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
) )
func tcpServer(t *testing.T, lis net.Listener) { func tcpServer(lis net.Listener) {
for { for {
conn, err := lis.Accept() conn, err := lis.Accept()
if err != nil { if err != nil {
@ -148,7 +148,7 @@ func TestRegistry_GetService(t *testing.T) {
t.Fail() t.Fail()
} }
defer lis.Close() defer lis.Close()
go tcpServer(t, lis) go tcpServer(lis)
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
cli, err := api.NewClient(&api.Config{Address: "127.0.0.1:8500"}) cli, err := api.NewClient(&api.Config{Address: "127.0.0.1:8500"})
if err != nil { if err != nil {
@ -284,9 +284,12 @@ func TestRegistry_Watch(t *testing.T) {
type args struct { type args struct {
ctx context.Context ctx context.Context
cancel func()
opts []Option opts []Option
instance *registry.ServiceInstance instance *registry.ServiceInstance
} }
canceledCtx, cancel := context.WithCancel(context.Background())
tests := []struct { tests := []struct {
name string name string
args args args args
@ -308,6 +311,21 @@ func TestRegistry_Watch(t *testing.T) {
preFunc: func(t *testing.T) { preFunc: func(t *testing.T) {
}, },
}, },
{
name: "ctx has been cancelled",
args: args{
ctx: canceledCtx,
cancel: cancel,
instance: instance1,
opts: []Option{
WithHealthCheck(false),
},
},
want: nil,
wantErr: true,
preFunc: func(t *testing.T) {
},
},
{ {
name: "register with healthCheck", name: "register with healthCheck",
args: args{ args: args{
@ -325,8 +343,9 @@ func TestRegistry_Watch(t *testing.T) {
lis, err := net.Listen("tcp", addr) lis, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
t.Errorf("listen tcp %s failed!", addr) t.Errorf("listen tcp %s failed!", addr)
return
} }
go tcpServer(t, lis) go tcpServer(lis)
}, },
}, },
} }
@ -355,6 +374,10 @@ func TestRegistry_Watch(t *testing.T) {
t.Error(err) t.Error(err)
} }
if tt.args.cancel != nil {
tt.args.cancel()
}
service, err := watch.Next() service, err := watch.Next()
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {

@ -19,6 +19,7 @@ func (w *watcher) Next() (services []*registry.ServiceInstance, err error) {
select { select {
case <-w.ctx.Done(): case <-w.ctx.Done():
err = w.ctx.Err() err = w.ctx.Err()
return
case <-w.event: case <-w.event:
} }

Loading…
Cancel
Save