diff --git a/contrib/registry/consul/client.go b/contrib/registry/consul/client.go index 31d0da687..ea5fe846d 100644 --- a/contrib/registry/consul/client.go +++ b/contrib/registry/consul/client.go @@ -2,6 +2,7 @@ package consul import ( "context" + "errors" "fmt" "math/rand" "net" @@ -204,21 +205,32 @@ func (c *Client) Register(_ context.Context, svc *registry.ServiceInstance, enab defer ticker.Stop() for { select { + case <-c.ctx.Done(): + _ = c.cli.Agent().ServiceDeregister(svc.ID) + return + default: + } + select { + case <-c.ctx.Done(): + _ = c.cli.Agent().ServiceDeregister(svc.ID) + return case <-ticker.C: + // ensure that unregistered services will not be re-registered by mistake + if errors.Is(c.ctx.Err(), context.Canceled) || errors.Is(c.ctx.Err(), context.DeadlineExceeded) { + _ = c.cli.Agent().ServiceDeregister(svc.ID) + return + } err = c.cli.Agent().UpdateTTL("service:"+svc.ID, "pass", "pass") if err != nil { log.Errorf("[Consul] update ttl heartbeat to consul failed! err=%v", err) // when the previous report fails, try to re register the service - time.AfterFunc(time.Duration(rand.Intn(5))*time.Second, func() { - if err := c.cli.Agent().ServiceRegister(asr); err != nil { - log.Errorf("[Consul] re registry service failed!, err=%v", err) - } else { - log.Warn("[Consul] re registry of service occurred success") - } - }) + time.Sleep(time.Duration(rand.Intn(5)) * time.Second) + if err := c.cli.Agent().ServiceRegister(asr); err != nil { + log.Errorf("[Consul] re registry service failed!, err=%v", err) + } else { + log.Warn("[Consul] re registry of service occurred success") + } } - case <-c.ctx.Done(): - return } } }() @@ -228,6 +240,6 @@ func (c *Client) Register(_ context.Context, svc *registry.ServiceInstance, enab // Deregister service by service ID func (c *Client) Deregister(_ context.Context, serviceID string) error { - c.cancel() + defer c.cancel() return c.cli.Agent().ServiceDeregister(serviceID) } diff --git a/contrib/registry/consul/registry_test.go b/contrib/registry/consul/registry_test.go index a60251236..791c2efc0 100644 --- a/contrib/registry/consul/registry_test.go +++ b/contrib/registry/consul/registry_test.go @@ -74,14 +74,14 @@ func TestRegistry_Register(t *testing.T) { serverName: "server-1", server: []*registry.ServiceInstance{ { - ID: "1", + ID: "2", Name: "server-1", Version: "v0.0.1", Metadata: nil, Endpoints: []string{"http://127.0.0.1:8000"}, }, { - ID: "1", + ID: "2", Name: "server-1", Version: "v0.0.2", Metadata: nil, @@ -91,7 +91,7 @@ func TestRegistry_Register(t *testing.T) { }, want: []*registry.ServiceInstance{ { - ID: "1", + ID: "2", Name: "server-1", Version: "v0.0.2", Metadata: nil, @@ -168,6 +168,13 @@ func TestRegistry_GetService(t *testing.T) { Endpoints: []string{fmt.Sprintf("tcp://%s?isSecure=false", addr)}, } + instance2 := ®istry.ServiceInstance{ + ID: "2", + Name: "server-1", + Version: "v0.0.1", + Endpoints: []string{fmt.Sprintf("tcp://%s?isSecure=false", addr)}, + } + type fields struct { registry *Registry } @@ -223,10 +230,10 @@ func TestRegistry_GetService(t *testing.T) { want: nil, wantErr: true, preFunc: func(t *testing.T) { - if err := r.Register(context.Background(), instance1); err != nil { + if err := r.Register(context.Background(), instance2); err != nil { t.Error(err) } - watch, err := r.Watch(context.Background(), instance1.Name) + watch, err := r.Watch(context.Background(), instance2.Name) if err != nil { t.Error(err) } @@ -236,7 +243,7 @@ func TestRegistry_GetService(t *testing.T) { } }, deferFunc: func(t *testing.T) { - err := r.Deregister(context.Background(), instance1) + err := r.Deregister(context.Background(), instance2) if err != nil { t.Error(err) } @@ -282,6 +289,20 @@ func TestRegistry_Watch(t *testing.T) { Endpoints: []string{fmt.Sprintf("tcp://%s?isSecure=false", addr)}, } + instance2 := ®istry.ServiceInstance{ + ID: "2", + Name: "server-1", + Version: "v0.0.1", + Endpoints: []string{fmt.Sprintf("tcp://%s?isSecure=false", addr)}, + } + + instance3 := ®istry.ServiceInstance{ + ID: "3", + Name: "server-1", + Version: "v0.0.1", + Endpoints: []string{fmt.Sprintf("tcp://%s?isSecure=false", addr)}, + } + type args struct { ctx context.Context cancel func() @@ -316,7 +337,7 @@ func TestRegistry_Watch(t *testing.T) { args: args{ ctx: canceledCtx, cancel: cancel, - instance: instance1, + instance: instance2, opts: []Option{ WithHealthCheck(false), }, @@ -330,14 +351,14 @@ func TestRegistry_Watch(t *testing.T) { name: "register with healthCheck", args: args{ ctx: context.Background(), - instance: instance1, + instance: instance3, opts: []Option{ WithHeartbeat(true), WithHealthCheck(true), WithHealthCheckInterval(5), }, }, - want: []*registry.ServiceInstance{instance1}, + want: []*registry.ServiceInstance{instance3}, wantErr: false, preFunc: func(t *testing.T) { lis, err := net.Listen("tcp", addr)