package http import ( "context" "crypto/tls" "errors" "net" "net/http" "net/url" "sync" "time" "github.com/go-kratos/kratos/v2/internal/endpoint" "github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" "github.com/gorilla/mux" ) var ( _ transport.Server = (*Server)(nil) _ transport.Endpointer = (*Server)(nil) ) // ServerOption is an HTTP server option. type ServerOption func(*Server) // Network with server network. func Network(network string) ServerOption { return func(s *Server) { s.network = network } } // Address with server address. func Address(addr string) ServerOption { return func(s *Server) { s.address = addr } } // Timeout with server timeout. func Timeout(timeout time.Duration) ServerOption { return func(s *Server) { s.timeout = timeout } } // Logger with server logger. func Logger(logger log.Logger) ServerOption { return func(s *Server) { s.log = log.NewHelper(logger) } } // Middleware with service middleware option. func Middleware(m ...middleware.Middleware) ServerOption { return func(o *Server) { o.ms = m } } // Filter with HTTP middleware option. func Filter(filters ...FilterFunc) ServerOption { return func(o *Server) { o.filters = filters } } // RequestDecoder with request decoder. func RequestDecoder(dec DecodeRequestFunc) ServerOption { return func(o *Server) { o.dec = dec } } // ResponseEncoder with response encoder. func ResponseEncoder(en EncodeResponseFunc) ServerOption { return func(o *Server) { o.enc = en } } // ErrorEncoder with error encoder. func ErrorEncoder(en EncodeErrorFunc) ServerOption { return func(o *Server) { o.ene = en } } // Endpoint with server endpoint. func Endpoint(endpoint *url.URL) ServerOption { return func(o *Server) { o.endpoint = endpoint } } // TLSConfig with TLS config. func TLSConfig(c *tls.Config) ServerOption { return func(o *Server) { o.tlsConf = c } } // Server is an HTTP server wrapper. type Server struct { *http.Server lis net.Listener tlsConf *tls.Config once sync.Once endpoint *url.URL err error network string address string timeout time.Duration filters []FilterFunc ms []middleware.Middleware dec DecodeRequestFunc enc EncodeResponseFunc ene EncodeErrorFunc router *mux.Router log *log.Helper } // NewServer creates an HTTP server by options. func NewServer(opts ...ServerOption) *Server { srv := &Server{ network: "tcp", address: ":0", timeout: 1 * time.Second, dec: DefaultRequestDecoder, enc: DefaultResponseEncoder, ene: DefaultErrorEncoder, log: log.NewHelper(log.DefaultLogger), } for _, o := range opts { o(srv) } srv.router = mux.NewRouter() srv.router.Use(srv.filter()) srv.Server = &http.Server{ Handler: FilterChain(srv.filters...)(srv.router), TLSConfig: srv.tlsConf, } return srv } // Route registers an HTTP router. func (s *Server) Route(prefix string, filters ...FilterFunc) *Router { return newRouter(prefix, s, filters...) } // Handle registers a new route with a matcher for the URL path. func (s *Server) Handle(path string, h http.Handler) { s.router.Handle(path, h) } // HandlePrefix registers a new route with a matcher for the URL path prefix. func (s *Server) HandlePrefix(prefix string, h http.Handler) { s.router.PathPrefix(prefix).Handler(h) } // HandleFunc registers a new route with a matcher for the URL path. func (s *Server) HandleFunc(path string, h http.HandlerFunc) { s.router.HandleFunc(path, h) } // ServeHTTP should write reply headers and data to the ResponseWriter and then return. func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) { s.Handler.ServeHTTP(res, req) } func (s *Server) filter() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() if s.timeout > 0 { ctx, cancel = context.WithTimeout(ctx, s.timeout) defer cancel() } pathTemplate := req.URL.Path if route := mux.CurrentRoute(req); route != nil { // /path/123 -> /path/{id} pathTemplate, _ = route.GetPathTemplate() } tr := &Transport{ endpoint: s.endpoint.String(), operation: pathTemplate, reqHeader: headerCarrier(req.Header), replyHeader: headerCarrier(w.Header()), request: req, pathTemplate: pathTemplate, } ctx = transport.NewServerContext(ctx, tr) next.ServeHTTP(w, req.WithContext(ctx)) }) } } // Endpoint return a real address to registry endpoint. // examples: // http://127.0.0.1:8000?isSecure=false func (s *Server) Endpoint() (*url.URL, error) { s.once.Do(func() { if s.endpoint != nil { return } lis, err := net.Listen(s.network, s.address) if err != nil { s.err = err return } addr, err := host.Extract(s.address, lis) if err != nil { lis.Close() s.err = err return } s.lis = lis s.endpoint = endpoint.NewEndpoint("http", addr, s.tlsConf != nil) }) if s.err != nil { return nil, s.err } return s.endpoint, nil } // Start start the HTTP server. func (s *Server) Start(ctx context.Context) error { if _, err := s.Endpoint(); err != nil { return err } s.BaseContext = func(net.Listener) context.Context { return ctx } s.log.Infof("[HTTP] server listening on: %s", s.lis.Addr().String()) var err error if s.tlsConf != nil { err = s.ServeTLS(s.lis, "", "") } else { err = s.Serve(s.lis) } if !errors.Is(err, http.ErrServerClosed) { return err } return nil } // Stop stop the HTTP server. func (s *Server) Stop(ctx context.Context) error { s.log.Info("[HTTP] server stopping") return s.Shutdown(ctx) }