diff --git a/app.go b/app.go index e3d81ffb1..436be2949 100644 --- a/app.go +++ b/app.go @@ -89,8 +89,15 @@ func (a *App) Run() error { a.mu.Lock() a.instance = instance a.mu.Unlock() - eg, ctx := errgroup.WithContext(NewContext(a.ctx, a)) + sctx := NewContext(a.ctx, a) + eg, ctx := errgroup.WithContext(sctx) wg := sync.WaitGroup{} + + for _, fn := range a.opts.beforeStart { + if err = fn(sctx); err != nil { + return err + } + } for _, srv := range a.opts.servers { srv := srv eg.Go(func() error { @@ -102,17 +109,23 @@ func (a *App) Run() error { wg.Add(1) eg.Go(func() error { wg.Done() // here is to ensure server start has begun running before register, so defer is not needed - return srv.Start(NewContext(a.opts.ctx, a)) + return srv.Start(sctx) }) } wg.Wait() if a.opts.registrar != nil { rctx, rcancel := context.WithTimeout(ctx, a.opts.registrarTimeout) defer rcancel() - if err := a.opts.registrar.Register(rctx, instance); err != nil { + if err = a.opts.registrar.Register(rctx, instance); err != nil { + return err + } + } + for _, fn := range a.opts.afterStart { + if err = fn(sctx); err != nil { return err } } + c := make(chan os.Signal, 1) signal.Notify(c, a.opts.sigs...) eg.Go(func() error { @@ -123,28 +136,36 @@ func (a *App) Run() error { return a.Stop() } }) - if err := eg.Wait(); err != nil && !errors.Is(err, context.Canceled) { + if err = eg.Wait(); err != nil && !errors.Is(err, context.Canceled) { return err } - return nil + for _, fn := range a.opts.afterStop { + err = fn(sctx) + } + return err } // Stop gracefully stops the application. -func (a *App) Stop() error { +func (a *App) Stop() (err error) { + sctx := NewContext(a.ctx, a) + for _, fn := range a.opts.beforeStop { + err = fn(sctx) + } + a.mu.Lock() instance := a.instance a.mu.Unlock() if a.opts.registrar != nil && instance != nil { ctx, cancel := context.WithTimeout(NewContext(a.ctx, a), a.opts.registrarTimeout) defer cancel() - if err := a.opts.registrar.Deregister(ctx, instance); err != nil { + if err = a.opts.registrar.Deregister(ctx, instance); err != nil { return err } } if a.cancel != nil { a.cancel() } - return nil + return err } func (a *App) buildInstance() (*registry.ServiceInstance, error) { diff --git a/app_test.go b/app_test.go index 2b8c8c9d2..92a2ca29f 100644 --- a/app_test.go +++ b/app_test.go @@ -47,6 +47,22 @@ func TestApp(t *testing.T) { Name("kratos"), Version("v1.0.0"), Server(hs, gs), + BeforeStart(func(_ context.Context) error { + t.Log("BeforeStart...") + return nil + }), + BeforeStop(func(_ context.Context) error { + t.Log("BeforeStop...") + return nil + }), + AfterStart(func(_ context.Context) error { + t.Log("AfterStart...") + return nil + }), + AfterStop(func(_ context.Context) error { + t.Log("AfterStop...") + return nil + }), Registrar(&mockRegistry{service: make(map[string]*registry.ServiceInstance)}), ) time.AfterFunc(time.Second, func() { diff --git a/options.go b/options.go index 8f9e9a55d..5e337a38e 100644 --- a/options.go +++ b/options.go @@ -30,6 +30,12 @@ type options struct { registrarTimeout time.Duration stopTimeout time.Duration servers []transport.Server + + // Before and After funcs + beforeStart []func(context.Context) error + beforeStop []func(context.Context) error + afterStart []func(context.Context) error + afterStop []func(context.Context) error } // ID with service id. @@ -91,3 +97,33 @@ func RegistrarTimeout(t time.Duration) Option { func StopTimeout(t time.Duration) Option { return func(o *options) { o.stopTimeout = t } } + +// Before and Afters + +// BeforeStart run funcs before app starts +func BeforeStart(fn func(context.Context) error) Option { + return func(o *options) { + o.beforeStart = append(o.beforeStart, fn) + } +} + +// BeforeStop run funcs before app stops +func BeforeStop(fn func(context.Context) error) Option { + return func(o *options) { + o.beforeStop = append(o.beforeStop, fn) + } +} + +// AfterStart run funcs after app starts +func AfterStart(fn func(context.Context) error) Option { + return func(o *options) { + o.afterStart = append(o.afterStart, fn) + } +} + +// AfterStop run funcs after app stops +func AfterStop(fn func(context.Context) error) Option { + return func(o *options) { + o.afterStop = append(o.afterStop, fn) + } +} diff --git a/options_test.go b/options_test.go index 438de3d83..606001737 100644 --- a/options_test.go +++ b/options_test.go @@ -152,3 +152,39 @@ func TestStopTimeout(t *testing.T) { t.Fatal("o.stopTimeout is not equal to v") } } + +func TestBeforeStart(t *testing.T) { + o := &options{} + v := func(_ context.Context) error { + t.Log("BeforeStart...") + return nil + } + BeforeStart(v)(o) +} + +func TestBeforeStop(t *testing.T) { + o := &options{} + v := func(_ context.Context) error { + t.Log("BeforeStop...") + return nil + } + BeforeStop(v)(o) +} + +func TestAfterStart(t *testing.T) { + o := &options{} + v := func(_ context.Context) error { + t.Log("AfterStart...") + return nil + } + AfterStart(v)(o) +} + +func TestAfterStop(t *testing.T) { + o := &options{} + v := func(_ context.Context) error { + t.Log("AfterStop...") + return nil + } + AfterStop(v)(o) +}