diff --git a/app.go b/app.go index 8658fad90..a0dd9d02e 100644 --- a/app.go +++ b/app.go @@ -52,24 +52,29 @@ func (a *App) Run() error { a.log.Infow( "service_id", a.opts.id, "service_name", a.opts.name, - "version", a.opts.version, + "service_version", a.opts.version, ) - instance, err := buildInstance(a.opts) + instance, err := a.buildInstance() if err != nil { return err } - eg, ctx := errgroup.WithContext(a.ctx) + ctx := NewContext(a.ctx, AppInfo{ + ID: a.opts.id, + Name: a.opts.name, + Version: a.opts.version, + }) + eg, ctx := errgroup.WithContext(ctx) wg := sync.WaitGroup{} for _, srv := range a.opts.servers { srv := srv eg.Go(func() error { <-ctx.Done() // wait for stop signal - return srv.Stop() + return srv.Stop(ctx) }) wg.Add(1) eg.Go(func() error { wg.Done() - return srv.Start() + return srv.Start(ctx) }) } wg.Wait() @@ -110,23 +115,23 @@ func (a *App) Stop() error { return nil } -func buildInstance(o options) (*registry.ServiceInstance, error) { - if len(o.endpoints) == 0 { - for _, srv := range o.servers { +func (a *App) buildInstance() (*registry.ServiceInstance, error) { + if len(a.opts.endpoints) == 0 { + for _, srv := range a.opts.servers { if r, ok := srv.(transport.Endpointer); ok { e, err := r.Endpoint() if err != nil { return nil, err } - o.endpoints = append(o.endpoints, e) + a.opts.endpoints = append(a.opts.endpoints, e) } } } return ®istry.ServiceInstance{ - ID: o.id, - Name: o.name, - Version: o.version, - Metadata: o.metadata, - Endpoints: o.endpoints, + ID: a.opts.id, + Name: a.opts.name, + Version: a.opts.version, + Metadata: a.opts.metadata, + Endpoints: a.opts.endpoints, }, nil } diff --git a/context.go b/context.go new file mode 100644 index 000000000..da070da51 --- /dev/null +++ b/context.go @@ -0,0 +1,23 @@ +package kratos + +import "context" + +// AppInfo is application context value. +type AppInfo struct { + ID string + Name string + Version string +} + +type appKey struct{} + +// NewContext returns a new Context that carries value. +func NewContext(ctx context.Context, s AppInfo) context.Context { + return context.WithValue(ctx, appKey{}, s) +} + +// FromContext returns the Transport value stored in ctx, if any. +func FromContext(ctx context.Context) (s AppInfo, ok bool) { + s, ok = ctx.Value(appKey{}).(AppInfo) + return +} diff --git a/internal/context/context.go b/internal/context/context.go new file mode 100644 index 000000000..0b2c3388e --- /dev/null +++ b/internal/context/context.go @@ -0,0 +1,115 @@ +package context + +import ( + "context" + "sync" + "sync/atomic" + "time" +) + +type mergeCtx struct { + parent1, parent2 context.Context + + done chan struct{} + doneMark uint32 + doneOnce sync.Once + doneErr error + + cancelCh chan struct{} + cancelOnce sync.Once +} + +// Merge merges two contexts into one. +func Merge(parent1, parent2 context.Context) (context.Context, context.CancelFunc) { + mc := &mergeCtx{ + parent1: parent1, + parent2: parent2, + done: make(chan struct{}), + cancelCh: make(chan struct{}), + } + select { + case <-parent1.Done(): + mc.finish(parent1.Err()) + case <-parent2.Done(): + mc.finish(parent2.Err()) + default: + go mc.wait() + } + return mc, mc.cancel +} + +func (mc *mergeCtx) finish(err error) error { + mc.doneOnce.Do(func() { + mc.doneErr = err + atomic.StoreUint32(&mc.doneMark, 1) + close(mc.done) + }) + return mc.doneErr +} + +func (mc *mergeCtx) wait() { + var err error + select { + case <-mc.parent1.Done(): + err = mc.parent1.Err() + case <-mc.parent2.Done(): + err = mc.parent2.Err() + case <-mc.cancelCh: + err = context.Canceled + } + mc.finish(err) +} + +func (mc *mergeCtx) cancel() { + mc.cancelOnce.Do(func() { + close(mc.cancelCh) + }) +} + +// Done implements context.Context. +func (mc *mergeCtx) Done() <-chan struct{} { + return mc.done +} + +// Err implements context.Context. +func (mc *mergeCtx) Err() error { + if atomic.LoadUint32(&mc.doneMark) != 0 { + return mc.doneErr + } + var err error + select { + case <-mc.parent1.Done(): + err = mc.parent1.Err() + case <-mc.parent2.Done(): + err = mc.parent2.Err() + case <-mc.cancelCh: + err = context.Canceled + default: + return nil + } + return mc.finish(err) +} + +// Deadline implements context.Context. +func (mc *mergeCtx) Deadline() (time.Time, bool) { + d1, ok1 := mc.parent1.Deadline() + d2, ok2 := mc.parent2.Deadline() + switch { + case !ok1: + return d2, ok2 + case !ok2: + return d1, ok1 + case d1.Before(d2): + return d1, true + default: + return d2, true + } +} + +// Value implements context.Context. +func (mc *mergeCtx) Value(key interface{}) interface{} { + if v := mc.parent1.Value(key); v != nil { + return v + } + return mc.parent2.Value(key) +} diff --git a/options.go b/options.go index cecb4eaa0..2bc8ba662 100644 --- a/options.go +++ b/options.go @@ -59,22 +59,22 @@ func Context(ctx context.Context) Option { return func(o *options) { o.ctx = ctx } } -// Signal with exit signals. -func Signal(sigs ...os.Signal) Option { - return func(o *options) { o.sigs = sigs } -} - // Logger with service logger. func Logger(logger log.Logger) Option { return func(o *options) { o.logger = logger } } -// Registrar with service registry. -func Registrar(r registry.Registrar) Option { - return func(o *options) { o.registrar = r } -} - // Server with transport servers. func Server(srv ...transport.Server) Option { return func(o *options) { o.servers = srv } } + +// Signal with exit signals. +func Signal(sigs ...os.Signal) Option { + return func(o *options) { o.sigs = sigs } +} + +// Registrar with service registry. +func Registrar(r registry.Registrar) Option { + return func(o *options) { o.registrar = r } +} diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 6b062cd38..c56b3ea91 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -103,7 +103,7 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien func unaryClientInterceptor(m middleware.Middleware, timeout time.Duration) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { ctx = transport.NewContext(ctx, transport.Transport{Kind: transport.KindGRPC}) - ctx = NewClientContext(ctx, ClientInfo{FullMethod: method}) + ctx = NewClientContext(ctx, ClientInfo{FullMethod: method, Target: cc.Target()}) if timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, timeout) diff --git a/transport/grpc/context.go b/transport/grpc/context.go index a1d7eb658..3ff5610c2 100644 --- a/transport/grpc/context.go +++ b/transport/grpc/context.go @@ -27,6 +27,7 @@ func FromServerContext(ctx context.Context) (info ServerInfo, ok bool) { type ClientInfo struct { // FullMethod is the full RPC method string, i.e., /package.service/method. FullMethod string + Target string } type clientKey struct{} diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 18b9e4eca..062b97e3e 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -8,6 +8,7 @@ import ( "time" "github.com/go-kratos/kratos/v2/api/metadata" + ic "github.com/go-kratos/kratos/v2/internal/context" "github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/middleware" @@ -70,6 +71,7 @@ func Options(opts ...grpc.ServerOption) ServerOption { // Server is a gRPC server wrapper. type Server struct { *grpc.Server + ctx context.Context lis net.Listener network string address string @@ -86,7 +88,7 @@ func NewServer(opts ...ServerOption) *Server { srv := &Server{ network: "tcp", address: ":0", - timeout: time.Second, + timeout: 1 * time.Second, middleware: middleware.Chain( recovery.Recovery(), ), @@ -98,7 +100,7 @@ func NewServer(opts ...ServerOption) *Server { } var grpcOpts = []grpc.ServerOption{ grpc.ChainUnaryInterceptor( - unaryServerInterceptor(srv.middleware, srv.timeout), + srv.unaryServerInterceptor(), ), } if len(srv.grpcOpts) > 0 { @@ -128,11 +130,13 @@ func (s *Server) Endpoint() (string, error) { if err != nil { return "", err } + s.address = addr return fmt.Sprintf("grpc://%s", addr), nil } // Start start the gRPC server. -func (s *Server) Start() error { +func (s *Server) Start(ctx context.Context) error { + s.ctx = ctx if s.lis == nil { lis, err := net.Listen(s.network, s.address) if err != nil { @@ -146,27 +150,29 @@ func (s *Server) Start() error { } // Stop stop the gRPC server. -func (s *Server) Stop() error { +func (s *Server) Stop(ctx context.Context) error { s.GracefulStop() s.health.Shutdown() s.log.Info("[gRPC] server stopping") return nil } -func unaryServerInterceptor(m middleware.Middleware, timeout time.Duration) grpc.UnaryServerInterceptor { +func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + ctx, cancel := ic.Merge(ctx, s.ctx) + defer cancel() ctx = transport.NewContext(ctx, transport.Transport{Kind: transport.KindGRPC}) ctx = NewServerContext(ctx, ServerInfo{Server: info.Server, FullMethod: info.FullMethod}) - if timeout > 0 { + if s.timeout > 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, timeout) + ctx, cancel = context.WithTimeout(ctx, s.timeout) defer cancel() } h := func(ctx context.Context, req interface{}) (interface{}, error) { return handler(ctx, req) } - if m != nil { - h = m(h) + if s.middleware != nil { + h = s.middleware(h) } return h(ctx, req) } diff --git a/transport/grpc/server_test.go b/transport/grpc/server_test.go index 9c927bc66..a101648a7 100644 --- a/transport/grpc/server_test.go +++ b/transport/grpc/server_test.go @@ -9,7 +9,11 @@ import ( "github.com/go-kratos/kratos/v2/internal/host" ) +type testKey struct{} + func TestServer(t *testing.T) { + ctx := context.Background() + ctx = context.WithValue(ctx, testKey{}, "test") srv := NewServer() if e, err := srv.Endpoint(); err != nil || e == "" { t.Fatal(e, err) @@ -17,13 +21,13 @@ func TestServer(t *testing.T) { go func() { // start server - if err := srv.Start(); err != nil { + if err := srv.Start(ctx); err != nil { panic(err) } }() time.Sleep(time.Second) testClient(t, srv) - srv.Stop() + srv.Stop(ctx) } func testClient(t *testing.T, srv *Server) { diff --git a/transport/http/client.go b/transport/http/client.go index 8198fd3e9..5d4cf5c89 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -27,6 +27,7 @@ type Client struct { b balancer.Balancer scheme string + endpoint string target Target userAgent string middleware middleware.Middleware @@ -148,7 +149,7 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { options := &clientOptions{ ctx: ctx, scheme: "http", - timeout: 1 * time.Second, + timeout: 500 * time.Millisecond, encoder: DefaultRequestEncoder, decoder: DefaultResponseDecoder, errorDecoder: DefaultErrorDecoder, @@ -196,6 +197,7 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { userAgent: options.userAgent, target: target, scheme: options.scheme, + endpoint: options.endpoint, discovery: options.discovery, b: options.balancer, }, nil @@ -240,6 +242,7 @@ func (client *Client) Invoke(ctx context.Context, path string, args interface{}, ctx = transport.NewContext(ctx, transport.Transport{Kind: transport.KindHTTP}) ctx = NewClientContext(ctx, ClientInfo{ + Target: client.endpoint, PathPattern: c.pathPattern, Request: req, }) diff --git a/transport/http/context.go b/transport/http/context.go index e8bab4a88..eecb7e2ea 100644 --- a/transport/http/context.go +++ b/transport/http/context.go @@ -28,6 +28,7 @@ func FromServerContext(ctx context.Context) (info ServerInfo, ok bool) { type ClientInfo struct { Request *http.Request PathPattern string + Target string } type clientKey struct{} diff --git a/transport/http/server.go b/transport/http/server.go index ae6af783c..358d5245e 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -9,6 +9,7 @@ import ( "strings" "time" + ic "github.com/go-kratos/kratos/v2/internal/context" "github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/transport" @@ -53,6 +54,7 @@ func Logger(logger log.Logger) ServerOption { // Server is an HTTP server wrapper. type Server struct { *http.Server + ctx context.Context lis net.Listener network string address string @@ -66,7 +68,7 @@ func NewServer(opts ...ServerOption) *Server { srv := &Server{ network: "tcp", address: ":0", - timeout: time.Second, + timeout: 1 * time.Second, log: log.NewHelper(log.DefaultLogger), } for _, o := range opts { @@ -94,10 +96,14 @@ func (s *Server) HandleFunc(path string, h http.HandlerFunc) { // ServeHTTP should write reply headers and data to the ResponseWriter and then return. func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) { - ctx, cancel := context.WithTimeout(req.Context(), s.timeout) + ctx, cancel := ic.Merge(req.Context(), s.ctx) defer cancel() ctx = transport.NewContext(ctx, transport.Transport{Kind: transport.KindHTTP}) ctx = NewServerContext(ctx, ServerInfo{Request: req, Response: res}) + if s.timeout > 0 { + ctx, cancel = context.WithTimeout(req.Context(), s.timeout) + defer cancel() + } s.router.ServeHTTP(res, req.WithContext(ctx)) } @@ -116,11 +122,13 @@ func (s *Server) Endpoint() (string, error) { if err != nil { return "", err } + s.address = addr return fmt.Sprintf("http://%s", addr), nil } // Start start the HTTP server. -func (s *Server) Start() error { +func (s *Server) Start(ctx context.Context) error { + s.ctx = ctx if s.lis == nil { lis, err := net.Listen(s.network, s.address) if err != nil { @@ -136,7 +144,7 @@ func (s *Server) Start() error { } // Stop stop the HTTP server. -func (s *Server) Stop() error { +func (s *Server) Stop(ctx context.Context) error { s.log.Info("[HTTP] server stopping") return s.Shutdown(context.Background()) } diff --git a/transport/http/server_test.go b/transport/http/server_test.go index e84ff8778..d5efb1417 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -12,6 +12,8 @@ import ( "github.com/go-kratos/kratos/v2/internal/host" ) +type testKey struct{} + type testData struct { Path string `json:"path"` } @@ -20,7 +22,13 @@ func TestServer(t *testing.T) { fn := func(w http.ResponseWriter, r *http.Request) { data := &testData{Path: r.RequestURI} json.NewEncoder(w).Encode(data) + + if r.Context().Value(testKey{}) != "test" { + w.WriteHeader(500) + } } + ctx := context.Background() + ctx = context.WithValue(ctx, testKey{}, "test") srv := NewServer() srv.HandleFunc("/index", fn) @@ -29,13 +37,13 @@ func TestServer(t *testing.T) { } go func() { - if err := srv.Start(); err != nil { + if err := srv.Start(ctx); err != nil { panic(err) } }() time.Sleep(time.Second) testClient(t, srv) - srv.Stop() + srv.Stop(ctx) } func testClient(t *testing.T, srv *Server) { @@ -68,6 +76,9 @@ func testClient(t *testing.T, srv *Server) { if err != nil { t.Fatal(err) } + if resp.StatusCode != 200 { + t.Fatalf("http status got %d", resp.StatusCode) + } content, err := ioutil.ReadAll(resp.Body) if err != nil { t.Fatalf("read resp error %v", err) diff --git a/transport/transport.go b/transport/transport.go index fb91b5c7e..aee456d1b 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -12,8 +12,8 @@ import ( // Server is transport server. type Server interface { - Start() error - Stop() error + Start(context.Context) error + Stop(context.Context) error } // Endpointer is registry endpoint.