diff --git a/app.go b/app.go index ebf119176..509a2e51b 100644 --- a/app.go +++ b/app.go @@ -81,7 +81,7 @@ func (a *App) Run() error { }) } if a.opts.registry != nil { - if err := a.opts.registry.Register(a.instance); err != nil { + if err := a.opts.registry.Register(a.opts.ctx, a.instance); err != nil { return err } } @@ -106,7 +106,7 @@ func (a *App) Run() error { // Stop gracefully stops the application. func (a *App) Stop() error { if a.opts.registry != nil { - if err := a.opts.registry.Deregister(a.instance); err != nil { + if err := a.opts.registry.Deregister(a.opts.ctx, a.instance); err != nil { return err } } diff --git a/registry/registry.go b/registry/registry.go index cd74ca694..6f4914553 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -1,15 +1,17 @@ package registry +import "context" + // Registry is service registry. type Registry interface { // Register the registration. - Register(service *ServiceInstance) error + Register(ctx context.Context, service *ServiceInstance) error // Deregister the registration. - Deregister(service *ServiceInstance) error + Deregister(ctx context.Context, service *ServiceInstance) error // Service return the service instances in memory according to the service name. - Service(name string) ([]*ServiceInstance, error) + Service(ctx context.Context, name string) ([]*ServiceInstance, error) // Watch creates a watcher according to the service name. - Watch(name string) (Watcher, error) + Watch(ctx context.Context, name string) (Watcher, error) } // Watcher is service watcher. diff --git a/transport/grpc/resolver/discovery/builder.go b/transport/grpc/resolver/discovery/builder.go index f5fa279ab..462c63c20 100644 --- a/transport/grpc/resolver/discovery/builder.go +++ b/transport/grpc/resolver/discovery/builder.go @@ -1,6 +1,8 @@ package discovery import ( + "context" + "github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/registry" "google.golang.org/grpc/resolver" @@ -36,7 +38,7 @@ func NewBuilder(r registry.Registry, opts ...Option) resolver.Builder { } func (d *builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { - w, err := d.registry.Watch(target.Endpoint) + w, err := d.registry.Watch(context.Background(), target.Endpoint) if err != nil { return nil, err }