You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
kratos/transport/grpc/server.go

246 lines
5.7 KiB

package grpc
import (
"context"
"crypto/tls"
"net"
"net/url"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/admin"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/health"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
apimd "github.com/go-kratos/kratos/v2/api/metadata"
"github.com/go-kratos/kratos/v2/internal/endpoint"
"github.com/go-kratos/kratos/v2/internal/host"
"github.com/go-kratos/kratos/v2/internal/matcher"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
)
// Compile-time assertions: the build fails if *Server stops implementing
// either kratos transport interface.
var (
	_ transport.Server     = (*Server)(nil) // Start / Stop lifecycle
	_ transport.Endpointer = (*Server)(nil) // Endpoint for service registries
)
// ServerOption mutates a Server while it is being built; NewServer applies
// the options it receives one after another, in argument order.
type ServerOption func(o *Server)
// Network sets the network passed to net.Listen when the server opens its own
// listener (default "tcp").
func Network(network string) ServerOption {
	setNetwork := func(srv *Server) {
		srv.network = network
	}
	return setNetwork
}
// Address sets the address passed to net.Listen when the server opens its own
// listener (default ":0"). It is also handed to host.Extract when the
// advertised endpoint is derived.
func Address(addr string) ServerOption {
	setAddress := func(srv *Server) {
		srv.address = addr
	}
	return setAddress
}
// Endpoint with server address.
func Endpoint(endpoint *url.URL) ServerOption {
return func(s *Server) {
s.endpoint = endpoint
}
}
// Timeout sets the server's timeout value (default one second, see NewServer).
// NOTE(review): the value is consumed by the interceptors defined elsewhere in
// this package — presumably as a per-request deadline; confirm there.
func Timeout(timeout time.Duration) ServerOption {
	setTimeout := func(srv *Server) {
		srv.timeout = timeout
	}
	return setTimeout
}
// Logger with server logger.
//
// Deprecated: use global logger instead. The returned option is a no-op: the
// server writes its log lines through the package-level functions of the
// kratos log package, so the supplied logger is never used.
func Logger(_ log.Logger) ServerOption {
	return func(s *Server) {}
}
// Middleware registers m with the server's middleware matcher through its Use
// method. Each call adds to what earlier calls registered rather than replacing
// it. For selector-scoped middleware, see (*Server).Use.
func Middleware(m ...middleware.Middleware) ServerOption {
	addMiddleware := func(srv *Server) {
		srv.middleware.Use(m...)
	}
	return addMiddleware
}
// CustomHealth stops NewServer from registering the built-in gRPC health
// service. This lets the caller register its own implementation.
func CustomHealth() ServerOption {
	enableCustomHealth := func(srv *Server) {
		srv.customHealth = true
	}
	return enableCustomHealth
}
// TLSConfig enables TLS. When c is non-nil, NewServer installs
// credentials.NewTLS(c) as the server's transport credentials. The derived
// endpoint is then also built with the secure scheme flag set.
func TLSConfig(c *tls.Config) ServerOption {
	setTLS := func(srv *Server) {
		srv.tlsConf = c
	}
	return setTLS
}
// Listener supplies a ready-made listener. The server then serves on lis
// instead of calling net.Listen with its configured network and address.
func Listener(lis net.Listener) ServerOption {
	setListener := func(srv *Server) {
		srv.lis = lis
	}
	return setListener
}
// UnaryInterceptor sets extra unary interceptors. NewServer chains them after
// the server's own unary interceptor. A later call replaces, rather than
// extends, the list given by an earlier one.
func UnaryInterceptor(in ...grpc.UnaryServerInterceptor) ServerOption {
	setUnary := func(srv *Server) {
		srv.unaryInts = in
	}
	return setUnary
}
// StreamInterceptor sets extra stream interceptors. NewServer chains them
// after the server's own stream interceptor. A later call replaces, rather
// than extends, the list given by an earlier one.
func StreamInterceptor(in ...grpc.StreamServerInterceptor) ServerOption {
	setStream := func(srv *Server) {
		srv.streamInts = in
	}
	return setStream
}
// Options passes raw grpc.ServerOptions to grpc.NewServer. They are appended
// after the options NewServer builds itself (interceptor chains and, if
// configured, TLS credentials). A later call replaces the list given by an
// earlier one.
func Options(opts ...grpc.ServerOption) ServerOption {
	setGRPCOptions := func(srv *Server) {
		srv.grpcOpts = opts
	}
	return setGRPCOptions
}
// Server is a gRPC server wrapper.
//
// It embeds the underlying *grpc.Server (created by NewServer), lazily opens
// its listener and derives its registry endpoint, and implements the kratos
// transport.Server and transport.Endpointer interfaces.
type Server struct {
	// Underlying gRPC server, built in NewServer from the collected options.
	*grpc.Server

	// context.Background() after NewServer; replaced by the ctx given to Start.
	// NOTE(review): presumably read by the interceptors defined elsewhere in
	// this package — confirm.
	baseCtx context.Context

	// TLS configuration from the TLSConfig option; nil means plaintext.
	tlsConf *tls.Config

	// Listener from the Listener option, or opened by listenAndEndpoint.
	lis net.Listener

	// Last error recorded by listenAndEndpoint.
	err error

	// Network for net.Listen (default "tcp").
	network string

	// Address for net.Listen and host.Extract (default ":0").
	address string

	// Endpoint from the Endpoint option, or derived by listenAndEndpoint.
	endpoint *url.URL

	// Timeout from the Timeout option (default one second).
	// NOTE(review): consumed outside this file — confirm its exact use.
	timeout time.Duration

	// Middleware matcher fed by the Middleware option and (*Server).Use.
	middleware matcher.Matcher

	// User unary interceptors, chained after the server's own one.
	unaryInts []grpc.UnaryServerInterceptor

	// User stream interceptors, chained after the server's own one.
	streamInts []grpc.StreamServerInterceptor

	// Raw gRPC options appended after the ones NewServer builds.
	grpcOpts []grpc.ServerOption

	// Health service: registered unless customHealth is set, resumed in
	// Start and shut down in Stop.
	health *health.Server

	// Set by the CustomHealth option to skip registering health above.
	customHealth bool

	// Metadata service, built and registered in NewServer.
	metadata *apimd.Server

	// Cleanup returned by admin.Register in NewServer; invoked by Stop when
	// non-nil.
	adminClean func()
}
// NewServer builds a gRPC server from opts.
//
// Defaults are network "tcp", address ":0", a one-second timeout, a fresh
// health server and an empty middleware matcher. After opts are applied, the
// underlying *grpc.Server is created with, in this order:
//   - the server's own unary and stream interceptors, followed by any user
//     interceptors;
//   - TLS credentials when a TLS config was supplied;
//   - any raw grpc.ServerOptions.
//
// The health service (unless CustomHealth was used) and the metadata,
// reflection and admin services are then registered. No listener is opened
// here; that happens on the first call to Endpoint or Start.
func NewServer(opts ...ServerOption) *Server {
	s := &Server{
		baseCtx:    context.Background(),
		network:    "tcp",
		address:    ":0",
		timeout:    1 * time.Second,
		health:     health.NewServer(),
		middleware: matcher.New(),
	}
	for _, apply := range opts {
		apply(s)
	}

	// The server's own interceptors head each chain; user ones follow.
	unary := append([]grpc.UnaryServerInterceptor{s.unaryServerInterceptor()}, s.unaryInts...)
	stream := append([]grpc.StreamServerInterceptor{s.streamServerInterceptor()}, s.streamInts...)

	serverOpts := []grpc.ServerOption{
		grpc.ChainUnaryInterceptor(unary...),
		grpc.ChainStreamInterceptor(stream...),
	}
	if s.tlsConf != nil {
		serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(s.tlsConf)))
	}
	serverOpts = append(serverOpts, s.grpcOpts...)

	s.Server = grpc.NewServer(serverOpts...)
	s.metadata = apimd.NewServer(s.Server)

	// Built-in services.
	if !s.customHealth {
		grpc_health_v1.RegisterHealthServer(s.Server, s.health)
	}
	apimd.RegisterMetadataServer(s.Server, s.metadata)
	reflection.Register(s.Server)

	// Admin services are best-effort: a registration error is deliberately
	// discarded. Only the cleanup func is kept, and Stop nil-checks it
	// before calling.
	s.adminClean, _ = admin.Register(s.Server)
	return s
}
// Use registers middleware m for the calls matched by selector. It forwards
// the pair to the server's middleware matcher via Add.
//
// Example selectors:
//   - "/*"
//   - "/helloworld.v1.Greeter/*"
//   - "/helloworld.v1.Greeter/SayHello"
func (s *Server) Use(selector string, m ...middleware.Middleware) {
	s.middleware.Add(selector, m...)
}
// Endpoint returns the URL under which this server should be registered,
// for example:
//
//	grpc://127.0.0.1:9000?isSecure=false
//
// On first use it opens the listener, unless one was supplied with the
// Listener option. Unless the Endpoint option fixed the URL, it then derives
// the URL from that listener via host.Extract. Any error from that
// preparation is returned.
func (s *Server) Endpoint() (*url.URL, error) {
	if err := s.listenAndEndpoint(); err != nil {
		return nil, err
	}
	return s.endpoint, nil
}
// Start starts the gRPC server and blocks while it serves.
//
// Steps:
//  1. Open the listener and derive the endpoint if that has not happened yet.
//  2. Record ctx as the server's base context.
//  3. Resume the health service.
//  4. Call Serve on the listener, which returns when the server is stopped or
//     the listener fails.
//
// A preparation error is returned without serving.
func (s *Server) Start(ctx context.Context) error {
	if err := s.listenAndEndpoint(); err != nil {
		return err
	}
	s.baseCtx = ctx
	log.Infof("[gRPC] server listening on: %s", s.lis.Addr().String())
	s.health.Resume()
	return s.Serve(s.lis)
}
// Stop stops the gRPC server.
//
// Steps:
//  1. Run the admin cleanup captured by NewServer, if any.
//  2. Shut down the health service.
//  3. Stop gracefully, waiting for in-flight RPCs to finish.
//
// If ctx is done before the graceful stop completes, the embedded
// grpc.Server is stopped forcibly instead. A nil ctx means "wait for the
// graceful stop, however long it takes". Stop always returns nil.
func (s *Server) Stop(ctx context.Context) error {
	if s.adminClean != nil {
		s.adminClean()
	}
	s.health.Shutdown()
	log.Info("[gRPC] server stopping")

	// Tolerate a nil ctx (ctx used to be ignored entirely): receiving from a
	// nil channel never fires, so such callers just wait for GracefulStop.
	var deadline <-chan struct{}
	if ctx != nil {
		deadline = ctx.Done()
	}

	done := make(chan struct{})
	go func() {
		defer close(done)
		s.GracefulStop()
	}()

	select {
	case <-done:
	case <-deadline:
		log.Warn("[gRPC] server couldn't stop gracefully in time, doing force stop")
		// Call the embedded grpc.Server's Stop explicitly; s.Stop would recurse
		// into this method. The goroutine above exits when GracefulStop returns.
		s.Server.Stop()
	}
	return nil
}
// listenAndEndpoint prepares the server for serving and registration.
//
// It opens a listener on s.network/s.address unless s.lis is already set, for
// example through the Listener option. It then derives s.endpoint from the
// listener via host.Extract unless s.endpoint is already set. Work that is
// already done is skipped, so the method may be called more than once.
//
// A failure is recorded in s.err and returned. On success s.err is cleared and
// nil is returned. This stops an error from an earlier failed attempt from
// being reported by a later attempt that succeeded.
func (s *Server) listenAndEndpoint() error {
	if s.lis == nil {
		lis, err := net.Listen(s.network, s.address)
		if err != nil {
			s.err = err
			return err
		}
		s.lis = lis
	}
	if s.endpoint == nil {
		addr, err := host.Extract(s.address, s.lis)
		if err != nil {
			s.err = err
			return err
		}
		s.endpoint = endpoint.NewEndpoint(endpoint.Scheme("grpc", s.tlsConf != nil), addr)
	}
	s.err = nil
	return nil
}