diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 5acbbf1e7..2d3ddcc9f 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -4,7 +4,7 @@ import ( "context" "net" "net/url" - "strings" + "sync" "time" "github.com/go-kratos/kratos/v2/api/metadata" @@ -73,6 +73,8 @@ type Server struct { *grpc.Server ctx context.Context lis net.Listener + once sync.Once + err error network string address string endpoint *url.URL @@ -120,35 +122,33 @@ func NewServer(opts ...ServerOption) *Server { // examples: // grpc://127.0.0.1:9000?isSecure=false func (s *Server) Endpoint() (*url.URL, error) { - if s.lis == nil && strings.HasSuffix(s.address, ":0") { + s.once.Do(func() { lis, err := net.Listen(s.network, s.address) if err != nil { - return nil, err + s.err = err + return + } + addr, err := host.Extract(s.address, s.lis) + if err != nil { + lis.Close() + s.err = err + return } s.lis = lis + s.endpoint = &url.URL{Scheme: "grpc", Host: addr} + }) + if s.err != nil { + return nil, s.err } - addr, err := host.Extract(s.address, s.lis) - if err != nil { - return nil, err - } - u := &url.URL{ - Scheme: "grpc", - Host: addr, - } - s.endpoint = u - return u, nil + return s.endpoint, nil } // Start start the gRPC server. func (s *Server) Start(ctx context.Context) error { - s.ctx = ctx - if s.lis == nil { - lis, err := net.Listen(s.network, s.address) - if err != nil { - return err - } - s.lis = lis + if _, err := s.Endpoint(); err != nil { + return err } + s.ctx = ctx s.log.Infof("[gRPC] server listening on: %s", s.lis.Addr().String()) s.health.Resume() return s.Serve(s.lis) diff --git a/transport/http/server.go b/transport/http/server.go index 162746e32..302d4f10a 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -6,7 +6,7 @@ import ( "net" "net/http" "net/url" - "strings" + "sync" "time" ic "github.com/go-kratos/kratos/v2/internal/context" @@ -56,6 +56,8 @@ type Server struct { *http.Server ctx context.Context lis net.Listener + once sync.Once + err error network string address string endpoint *url.URL @@ -112,35 +114,33 @@ func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) { // examples: // http://127.0.0.1:8000?isSecure=false func (s *Server) Endpoint() (*url.URL, error) { - if s.lis == nil && strings.HasSuffix(s.address, ":0") { + s.once.Do(func() { lis, err := net.Listen(s.network, s.address) if err != nil { - return nil, err + s.err = err + return + } + addr, err := host.Extract(s.address, s.lis) + if err != nil { + lis.Close() + s.err = err + return } s.lis = lis + s.endpoint = &url.URL{Scheme: "http", Host: addr} + }) + if s.err != nil { + return nil, s.err } - addr, err := host.Extract(s.address, s.lis) - if err != nil { - return nil, err - } - u := &url.URL{ - Scheme: "http", - Host: addr, - } - s.endpoint = u - return u, nil + return s.endpoint, nil } // Start start the HTTP server. func (s *Server) Start(ctx context.Context) error { - s.ctx = ctx - if s.lis == nil { - lis, err := net.Listen(s.network, s.address) - if err != nil { - return err - } - s.lis = lis + if _, err := s.Endpoint(); err != nil { + return err } + s.ctx = ctx s.log.Infof("[HTTP] server listening on: %s", s.lis.Addr().String()) if err := s.Serve(s.lis); !errors.Is(err, http.ErrServerClosed) { return err