|
|
|
package grpc
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"crypto/tls"
|
|
|
|
"net"
|
|
|
|
"net/url"
|
|
|
|
"sync"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/go-kratos/kratos/v2/internal/endpoint"
|
|
|
|
|
|
|
|
apimd "github.com/go-kratos/kratos/v2/api/metadata"
|
|
|
|
ic "github.com/go-kratos/kratos/v2/internal/context"
|
|
|
|
"github.com/go-kratos/kratos/v2/internal/host"
|
|
|
|
"github.com/go-kratos/kratos/v2/log"
|
|
|
|
"github.com/go-kratos/kratos/v2/middleware"
|
|
|
|
"github.com/go-kratos/kratos/v2/transport"
|
|
|
|
|
|
|
|
"google.golang.org/grpc"
|
|
|
|
"google.golang.org/grpc/credentials"
|
|
|
|
"google.golang.org/grpc/health"
|
|
|
|
"google.golang.org/grpc/health/grpc_health_v1"
|
|
|
|
grpcmd "google.golang.org/grpc/metadata"
|
|
|
|
"google.golang.org/grpc/reflection"
|
|
|
|
)
|
|
|
|
|
|
|
|
var _ transport.Server = (*Server)(nil)
|
|
|
|
var _ transport.Endpointer = (*Server)(nil)
|
|
|
|
|
|
|
|
// ServerOption is gRPC server option.
|
|
|
|
type ServerOption func(o *Server)
|
|
|
|
|
|
|
|
// Network with server network.
|
|
|
|
func Network(network string) ServerOption {
|
|
|
|
return func(s *Server) {
|
|
|
|
s.network = network
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Address with server address.
|
|
|
|
func Address(addr string) ServerOption {
|
|
|
|
return func(s *Server) {
|
|
|
|
s.address = addr
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Timeout with server timeout.
|
|
|
|
func Timeout(timeout time.Duration) ServerOption {
|
|
|
|
return func(s *Server) {
|
|
|
|
s.timeout = timeout
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Logger with server logger.
|
|
|
|
func Logger(logger log.Logger) ServerOption {
|
|
|
|
return func(s *Server) {
|
|
|
|
s.log = log.NewHelper(logger)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Middleware with server middleware.
|
|
|
|
func Middleware(m ...middleware.Middleware) ServerOption {
|
|
|
|
return func(s *Server) {
|
|
|
|
s.middleware = m
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// TLSConfig with TLS config.
|
|
|
|
func TLSConfig(c *tls.Config) ServerOption {
|
|
|
|
return func(s *Server) {
|
|
|
|
s.tlsConf = c
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the server.
|
|
|
|
func UnaryInterceptor(in ...grpc.UnaryServerInterceptor) ServerOption {
|
|
|
|
return func(s *Server) {
|
|
|
|
s.ints = in
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Options with grpc options.
|
|
|
|
func Options(opts ...grpc.ServerOption) ServerOption {
|
|
|
|
return func(s *Server) {
|
|
|
|
s.grpcOpts = opts
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Server is a gRPC server wrapper.
|
|
|
|
type Server struct {
|
|
|
|
*grpc.Server
|
|
|
|
ctx context.Context
|
|
|
|
tlsConf *tls.Config
|
|
|
|
lis net.Listener
|
|
|
|
once sync.Once
|
|
|
|
err error
|
|
|
|
network string
|
|
|
|
address string
|
|
|
|
endpoint *url.URL
|
|
|
|
timeout time.Duration
|
|
|
|
log *log.Helper
|
|
|
|
middleware []middleware.Middleware
|
|
|
|
ints []grpc.UnaryServerInterceptor
|
|
|
|
grpcOpts []grpc.ServerOption
|
|
|
|
health *health.Server
|
|
|
|
metadata *apimd.Server
|
|
|
|
}
|
|
|
|
|
|
|
|
// NewServer creates a gRPC server by options.
|
|
|
|
func NewServer(opts ...ServerOption) *Server {
|
|
|
|
srv := &Server{
|
|
|
|
network: "tcp",
|
|
|
|
address: ":0",
|
|
|
|
timeout: 1 * time.Second,
|
|
|
|
health: health.NewServer(),
|
|
|
|
log: log.NewHelper(log.DefaultLogger),
|
|
|
|
}
|
|
|
|
for _, o := range opts {
|
|
|
|
o(srv)
|
|
|
|
}
|
|
|
|
var ints = []grpc.UnaryServerInterceptor{
|
|
|
|
srv.unaryServerInterceptor(),
|
|
|
|
}
|
|
|
|
if len(srv.ints) > 0 {
|
|
|
|
ints = append(ints, srv.ints...)
|
|
|
|
}
|
|
|
|
var grpcOpts = []grpc.ServerOption{
|
|
|
|
grpc.ChainUnaryInterceptor(ints...),
|
|
|
|
}
|
|
|
|
if srv.tlsConf != nil {
|
|
|
|
grpcOpts = append(grpcOpts, grpc.Creds(credentials.NewTLS(srv.tlsConf)))
|
|
|
|
}
|
|
|
|
if len(srv.grpcOpts) > 0 {
|
|
|
|
grpcOpts = append(grpcOpts, srv.grpcOpts...)
|
|
|
|
}
|
|
|
|
srv.Server = grpc.NewServer(grpcOpts...)
|
|
|
|
srv.metadata = apimd.NewServer(srv.Server)
|
|
|
|
// internal register
|
|
|
|
grpc_health_v1.RegisterHealthServer(srv.Server, srv.health)
|
|
|
|
apimd.RegisterMetadataServer(srv.Server, srv.metadata)
|
|
|
|
reflection.Register(srv.Server)
|
|
|
|
return srv
|
|
|
|
}
|
|
|
|
|
|
|
|
// Endpoint return a real address to registry endpoint.
|
|
|
|
// examples:
|
|
|
|
// grpc://127.0.0.1:9000?isSecure=false
|
|
|
|
func (s *Server) Endpoint() (*url.URL, error) {
|
|
|
|
s.once.Do(func() {
|
|
|
|
lis, err := net.Listen(s.network, s.address)
|
|
|
|
if err != nil {
|
|
|
|
s.err = err
|
|
|
|
return
|
|
|
|
}
|
|
|
|
addr, err := host.Extract(s.address, lis)
|
|
|
|
if err != nil {
|
|
|
|
_ = lis.Close()
|
|
|
|
s.err = err
|
|
|
|
return
|
|
|
|
}
|
|
|
|
s.lis = lis
|
|
|
|
s.endpoint = endpoint.NewEndpoint("grpc", addr, s.tlsConf != nil)
|
|
|
|
})
|
|
|
|
if s.err != nil {
|
|
|
|
return nil, s.err
|
|
|
|
}
|
|
|
|
return s.endpoint, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Start start the gRPC server.
|
|
|
|
func (s *Server) Start(ctx context.Context) error {
|
|
|
|
if _, err := s.Endpoint(); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
s.ctx = ctx
|
|
|
|
s.log.Infof("[gRPC] server listening on: %s", s.lis.Addr().String())
|
|
|
|
s.health.Resume()
|
|
|
|
return s.Serve(s.lis)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Stop stop the gRPC server.
|
|
|
|
func (s *Server) Stop(ctx context.Context) error {
|
|
|
|
s.GracefulStop()
|
|
|
|
s.health.Shutdown()
|
|
|
|
s.log.Info("[gRPC] server stopping")
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
|
|
|
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
|
|
|
ctx, cancel := ic.Merge(ctx, s.ctx)
|
|
|
|
defer cancel()
|
|
|
|
md, _ := grpcmd.FromIncomingContext(ctx)
|
|
|
|
replyHeader := grpcmd.MD{}
|
|
|
|
ctx = transport.NewServerContext(ctx, &Transport{
|
|
|
|
endpoint: s.endpoint.String(),
|
|
|
|
operation: info.FullMethod,
|
|
|
|
reqHeader: headerCarrier(md),
|
|
|
|
replyHeader: headerCarrier(replyHeader),
|
|
|
|
})
|
|
|
|
if s.timeout > 0 {
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, s.timeout)
|
|
|
|
defer cancel()
|
|
|
|
}
|
|
|
|
h := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
|
|
return handler(ctx, req)
|
|
|
|
}
|
|
|
|
if len(s.middleware) > 0 {
|
|
|
|
h = middleware.Chain(s.middleware...)(h)
|
|
|
|
}
|
|
|
|
reply, err := h(ctx, req)
|
|
|
|
if len(replyHeader) > 0 {
|
|
|
|
_ = grpc.SetHeader(ctx, replyHeader)
|
|
|
|
}
|
|
|
|
return reply, err
|
|
|
|
}
|
|
|
|
}
|