package grpc import ( "context" "crypto/tls" "net" "net/url" "time" "github.com/go-kratos/kratos/v2/internal/endpoint" apimd "github.com/go-kratos/kratos/v2/api/metadata" ic "github.com/go-kratos/kratos/v2/internal/context" "github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/health" "google.golang.org/grpc/health/grpc_health_v1" grpcmd "google.golang.org/grpc/metadata" "google.golang.org/grpc/reflection" ) var ( _ transport.Server = (*Server)(nil) _ transport.Endpointer = (*Server)(nil) ) // ServerOption is gRPC server option. type ServerOption func(o *Server) // Network with server network. func Network(network string) ServerOption { return func(s *Server) { s.network = network } } // Address with server address. func Address(addr string) ServerOption { return func(s *Server) { s.address = addr } } // Timeout with server timeout. func Timeout(timeout time.Duration) ServerOption { return func(s *Server) { s.timeout = timeout } } // Logger with server logger. func Logger(logger log.Logger) ServerOption { return func(s *Server) { s.log = log.NewHelper(logger) } } // Middleware with server middleware. func Middleware(m ...middleware.Middleware) ServerOption { return func(s *Server) { s.middleware = m } } // TLSConfig with TLS config. func TLSConfig(c *tls.Config) ServerOption { return func(s *Server) { s.tlsConf = c } } // Listener with server lis func Listener(lis net.Listener) ServerOption { return func(s *Server) { s.lis = lis } } // UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the server. func UnaryInterceptor(in ...grpc.UnaryServerInterceptor) ServerOption { return func(s *Server) { s.ints = in } } // Options with grpc options. 
//
// They are appended after the server's built-in options (the unary
// interceptor chain and, when configured, TLS credentials) and replace any
// previously supplied set.
func Options(opts ...grpc.ServerOption) ServerOption {
	return func(srv *Server) { srv.grpcOpts = opts }
}

// Server wraps a *grpc.Server with kratos lifecycle, middleware, health and
// metadata services.
type Server struct {
	*grpc.Server

	// baseCtx is replaced by the ctx given to Start and combined into every
	// unary request context by the built-in interceptor.
	baseCtx context.Context
	tlsConf *tls.Config
	lis     net.Listener
	// err holds the deferred failure from listenAndEndpoint; Endpoint and
	// Start report it instead of proceeding.
	err      error
	network  string
	address  string
	endpoint *url.URL
	// timeout is the per-request deadline; <= 0 disables it.
	timeout    time.Duration
	log        *log.Helper
	middleware []middleware.Middleware
	// ints are user interceptors, chained after the built-in one.
	ints     []grpc.UnaryServerInterceptor
	grpcOpts []grpc.ServerOption
	health   *health.Server
	metadata *apimd.Server
}

// NewServer builds a Server from the given options. It opens the listener
// eagerly; a failure there is stored and surfaced later by Endpoint and Start
// rather than returned here. The health, metadata and reflection services are
// registered on the underlying gRPC server.
func NewServer(opts ...ServerOption) *Server {
	srv := &Server{
		baseCtx: context.Background(),
		network: "tcp",
		address: ":0",
		timeout: time.Second,
		health:  health.NewServer(),
		log:     log.NewHelper(log.GetLogger()),
	}
	for _, apply := range opts {
		apply(srv)
	}

	// The kratos interceptor is always first in the chain; user interceptors follow.
	unaryChain := append(
		[]grpc.UnaryServerInterceptor{srv.unaryServerInterceptor()},
		srv.ints...,
	)

	serverOpts := []grpc.ServerOption{grpc.ChainUnaryInterceptor(unaryChain...)}
	if srv.tlsConf != nil {
		serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(srv.tlsConf)))
	}
	serverOpts = append(serverOpts, srv.grpcOpts...)

	srv.Server = grpc.NewServer(serverOpts...)
	srv.metadata = apimd.NewServer(srv.Server)

	// Listen now so Endpoint can report the real address before Start.
	srv.err = srv.listenAndEndpoint()

	// Built-in services.
	grpc_health_v1.RegisterHealthServer(srv.Server, srv.health)
	apimd.RegisterMetadataServer(srv.Server, srv.metadata)
	reflection.Register(srv.Server)
	return srv
}

// Endpoint returns the address to publish to a service registry, for example
// grpc://127.0.0.1:9000?isSecure=false, or the error recorded while listening.
func (s *Server) Endpoint() (*url.URL, error) {
	if err := s.err; err != nil {
		return nil, err
	}
	return s.endpoint, nil
}

// Start serves gRPC on the server's listener until the server is stopped.
// ctx becomes the base context combined into every unary request. Any error
// recorded during NewServer is returned immediately.
func (s *Server) Start(ctx context.Context) error {
	if err := s.err; err != nil {
		return err
	}
	s.baseCtx = ctx
	listenAddr := s.lis.Addr().String()
	s.log.Infof("[gRPC] server listening on: %s", listenAddr)
	s.health.Resume()
	return s.Serve(s.lis)
}

// Stop stops the gRPC server.
func (s *Server) Stop(ctx context.Context) error { s.GracefulStop() s.health.Shutdown() s.log.Info("[gRPC] server stopping") return nil } func (s *Server) listenAndEndpoint() error { if s.lis == nil { lis, err := net.Listen(s.network, s.address) if err != nil { return err } s.lis = lis } addr, err := host.Extract(s.address, s.lis) if err != nil { _ = s.lis.Close() return err } s.endpoint = endpoint.NewEndpoint("grpc", addr, s.tlsConf != nil) return nil } func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { ctx, cancel := ic.Merge(ctx, s.baseCtx) defer cancel() md, _ := grpcmd.FromIncomingContext(ctx) replyHeader := grpcmd.MD{} ctx = transport.NewServerContext(ctx, &Transport{ endpoint: s.endpoint.String(), operation: info.FullMethod, reqHeader: headerCarrier(md), replyHeader: headerCarrier(replyHeader), }) if s.timeout > 0 { ctx, cancel = context.WithTimeout(ctx, s.timeout) defer cancel() } h := func(ctx context.Context, req interface{}) (interface{}, error) { return handler(ctx, req) } if len(s.middleware) > 0 { h = middleware.Chain(s.middleware...)(h) } reply, err := h(ctx, req) if len(replyHeader) > 0 { _ = grpc.SetHeader(ctx, replyHeader) } return reply, err } }