From d0a0edf67be8ec9787478984cf7a67d516e14e58 Mon Sep 17 00:00:00 2001 From: Tony Chen Date: Mon, 11 Jul 2022 19:32:54 +0800 Subject: [PATCH] feat(http): add http router walk (#2181) * add http router walk Co-authored-by: chenzhihui --- .github/workflows/go.yml | 2 +- transport/grpc/server.go | 7 +++-- transport/grpc/server_test.go | 8 ++++- transport/http/router.go | 9 ++++++ transport/http/server.go | 27 +++++++++++++++-- transport/http/server_test.go | 57 ++++++++++++++++++++++++++++++----- 6 files changed, 95 insertions(+), 15 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 4d65cdad5..3794a7c0f 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -36,7 +36,7 @@ jobs: - "8848:8848" - "9848:9848" polaris: - image: polarismesh/polaris-server-standalone:latest + image: polarismesh/polaris-server-standalone:v1.9.0 ports: - 8090:8090 - 8091:8091 diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 0131692e3..2b5bde745 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -199,13 +199,14 @@ func (s *Server) listenAndEndpoint() error { s.err = err return err } - addr, err := host.Extract(s.address, lis) + s.lis = lis + } + if s.endpoint == nil { + addr, err := host.Extract(s.address, s.lis) if err != nil { - _ = s.lis.Close() s.err = err return err } - s.lis = lis s.endpoint = endpoint.NewEndpoint(endpoint.Scheme("grpc", s.tlsConf != nil), addr) } return s.err diff --git a/transport/grpc/server_test.go b/transport/grpc/server_test.go index 3a8141c24..368c32953 100644 --- a/transport/grpc/server_test.go +++ b/transport/grpc/server_test.go @@ -289,10 +289,16 @@ func TestServer_unaryServerInterceptor(t *testing.T) { } func TestListener(t *testing.T) { - lis := &net.TCPListener{} + lis, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } s := &Server{} Listener(lis)(s) if !reflect.DeepEqual(lis, s.lis) { t.Errorf("expect %v, got %v", lis, s.lis) } + if e, err := s.Endpoint(); err != nil || e == nil { + t.Errorf("expect not empty") + } } diff --git a/transport/http/router.go b/transport/http/router.go index d8d5c08e8..1575ff112 100644 --- a/transport/http/router.go +++ b/transport/http/router.go @@ -6,6 +6,15 @@ import ( "sync" ) +// WalkRouteFunc is the type of the function called for each route visited by Walk. +type WalkRouteFunc func(RouteInfo) error + +// RouteInfo is an HTTP route info. +type RouteInfo struct { + Path string + Method string +} + // HandlerFunc defines a function to serve HTTP requests. type HandlerFunc func(Context) error diff --git a/transport/http/server.go b/transport/http/server.go index a0d4f14f3..9bb9ffea5 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -157,6 +157,26 @@ func NewServer(opts ...ServerOption) *Server { return srv } +// WalkRoute walks the router and all its sub-routers, calling walkFn for each route in the tree. +func (s *Server) WalkRoute(fn WalkRouteFunc) error { + return s.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { + methods, err := route.GetMethods() + if err != nil { + return nil // ignore no methods + } + path, err := route.GetPathTemplate() + if err != nil { + return err + } + for _, method := range methods { + if err := fn(RouteInfo{Method: method, Path: path}); err != nil { + return err + } + } + return nil + }) +} + // Route registers an HTTP router. func (s *Server) Route(prefix string, filters ...FilterFunc) *Router { return newRouter(prefix, s, filters...) @@ -268,13 +288,14 @@ func (s *Server) listenAndEndpoint() error { s.err = err return err } - addr, err := host.Extract(s.address, lis) + s.lis = lis + } + if s.endpoint == nil { + addr, err := host.Extract(s.address, s.lis) if err != nil { - _ = s.lis.Close() s.err = err return err } - s.lis = lis s.endpoint = endpoint.NewEndpoint(endpoint.Scheme("http", s.tlsConf != nil), addr) } return s.err diff --git a/transport/http/server_test.go b/transport/http/server_test.go index ea60d0cf4..a6f05080a 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -19,6 +19,10 @@ import ( "github.com/go-kratos/kratos/v2/internal/host" ) +var h = func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(testData{Path: r.RequestURI}) +} + type testKey struct{} type testData struct { @@ -38,15 +42,48 @@ func newHandleFuncWrapper(fn http.HandlerFunc) http.Handler { return &handleFuncWrapper{fn: fn} } -func TestServer(t *testing.T) { - fn := func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(testData{Path: r.RequestURI}) +func TestServeHTTP(t *testing.T) { + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + mux := NewServer(Listener(ln)) + mux.HandleFunc("/index", h) + mux.Route("/errors").GET("/cause", func(ctx Context) error { + return errors.BadRequest("xxx", "zzz"). + WithMetadata(map[string]string{"foo": "bar"}). + WithCause(fmt.Errorf("error cause")) + }) + if err = mux.WalkRoute(func(r RouteInfo) error { + t.Logf("WalkRoute: %+v", r) + return nil + }); err != nil { + t.Fatal(err) + } + if e, err := mux.Endpoint(); err != nil || e == nil || strings.HasSuffix(e.Host, ":0") { + t.Fatal(e, err) + } + srv := http.Server{Handler: mux} + go func() { + if err := srv.Serve(ln); err != nil { + if errors.Is(err, http.ErrServerClosed) { + return + } + panic(err) + } + }() + time.Sleep(time.Second) + if err := srv.Shutdown(context.Background()); err != nil { + t.Log(err) } +} + +func TestServer(t *testing.T) { ctx := context.Background() srv := NewServer() - srv.Handle("/index", newHandleFuncWrapper(fn)) - srv.HandleFunc("/index/{id:[0-9]+}", fn) - srv.HandlePrefix("/test/prefix", newHandleFuncWrapper(fn)) + srv.Handle("/index", newHandleFuncWrapper(h)) + srv.HandleFunc("/index/{id:[0-9]+}", h) + srv.HandlePrefix("/test/prefix", newHandleFuncWrapper(h)) srv.HandleHeader("content-type", "application/grpc-web+json", func(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(testData{Path: r.RequestURI}) }) @@ -333,10 +370,16 @@ func TestStrictSlash(t *testing.T) { } func TestListener(t *testing.T) { - lis := &net.TCPListener{} + lis, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } s := &Server{} Listener(lis)(s) if !reflect.DeepEqual(s.lis, lis) { t.Errorf("expected %v got %v", lis, s.lis) } + if e, err := s.Endpoint(); err != nil || e == nil { + t.Errorf("expected not empty") + } }