feat(http): add http router walk (#2181)

* add http router walk

Co-authored-by: chenzhihui <chenzhihui@bilibili.com>
pull/2085/head
Tony Chen 2 years ago committed by GitHub
parent afd108cdc7
commit d0a0edf67b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      .github/workflows/go.yml
  2. 7
      transport/grpc/server.go
  3. 8
      transport/grpc/server_test.go
  4. 9
      transport/http/router.go
  5. 27
      transport/http/server.go
  6. 57
      transport/http/server_test.go

@ -36,7 +36,7 @@ jobs:
- "8848:8848"
- "9848:9848"
polaris:
image: polarismesh/polaris-server-standalone:latest
image: polarismesh/polaris-server-standalone:v1.9.0
ports:
- 8090:8090
- 8091:8091

@ -199,13 +199,14 @@ func (s *Server) listenAndEndpoint() error {
s.err = err
return err
}
addr, err := host.Extract(s.address, lis)
s.lis = lis
}
if s.endpoint == nil {
addr, err := host.Extract(s.address, s.lis)
if err != nil {
_ = s.lis.Close()
s.err = err
return err
}
s.lis = lis
s.endpoint = endpoint.NewEndpoint(endpoint.Scheme("grpc", s.tlsConf != nil), addr)
}
return s.err

@ -289,10 +289,16 @@ func TestServer_unaryServerInterceptor(t *testing.T) {
}
func TestListener(t *testing.T) {
lis := &net.TCPListener{}
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
s := &Server{}
Listener(lis)(s)
if !reflect.DeepEqual(lis, s.lis) {
t.Errorf("expect %v, got %v", lis, s.lis)
}
if e, err := s.Endpoint(); err != nil || e == nil {
t.Errorf("expect not empty")
}
}

@ -6,6 +6,15 @@ import (
"sync"
)
// WalkRouteFunc is the type of the function called for each route visited by Walk.
type WalkRouteFunc func(RouteInfo) error
// RouteInfo is an HTTP route info.
type RouteInfo struct {
Path string
Method string
}
// HandlerFunc defines a function to serve HTTP requests.
type HandlerFunc func(Context) error

@ -157,6 +157,26 @@ func NewServer(opts ...ServerOption) *Server {
return srv
}
// WalkRoute walks the router and all its sub-routers, calling walkFn for each route in the tree.
func (s *Server) WalkRoute(fn WalkRouteFunc) error {
return s.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
methods, err := route.GetMethods()
if err != nil {
return nil // ignore no methods
}
path, err := route.GetPathTemplate()
if err != nil {
return err
}
for _, method := range methods {
if err := fn(RouteInfo{Method: method, Path: path}); err != nil {
return err
}
}
return nil
})
}
// Route registers an HTTP router.
func (s *Server) Route(prefix string, filters ...FilterFunc) *Router {
return newRouter(prefix, s, filters...)
@ -268,13 +288,14 @@ func (s *Server) listenAndEndpoint() error {
s.err = err
return err
}
addr, err := host.Extract(s.address, lis)
s.lis = lis
}
if s.endpoint == nil {
addr, err := host.Extract(s.address, s.lis)
if err != nil {
_ = s.lis.Close()
s.err = err
return err
}
s.lis = lis
s.endpoint = endpoint.NewEndpoint(endpoint.Scheme("http", s.tlsConf != nil), addr)
}
return s.err

@ -19,6 +19,10 @@ import (
"github.com/go-kratos/kratos/v2/internal/host"
)
var h = func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(testData{Path: r.RequestURI})
}
type testKey struct{}
type testData struct {
@ -38,15 +42,48 @@ func newHandleFuncWrapper(fn http.HandlerFunc) http.Handler {
return &handleFuncWrapper{fn: fn}
}
func TestServer(t *testing.T) {
fn := func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(testData{Path: r.RequestURI})
func TestServeHTTP(t *testing.T) {
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
mux := NewServer(Listener(ln))
mux.HandleFunc("/index", h)
mux.Route("/errors").GET("/cause", func(ctx Context) error {
return errors.BadRequest("xxx", "zzz").
WithMetadata(map[string]string{"foo": "bar"}).
WithCause(fmt.Errorf("error cause"))
})
if err = mux.WalkRoute(func(r RouteInfo) error {
t.Logf("WalkRoute: %+v", r)
return nil
}); err != nil {
t.Fatal(err)
}
if e, err := mux.Endpoint(); err != nil || e == nil || strings.HasSuffix(e.Host, ":0") {
t.Fatal(e, err)
}
srv := http.Server{Handler: mux}
go func() {
if err := srv.Serve(ln); err != nil {
if errors.Is(err, http.ErrServerClosed) {
return
}
panic(err)
}
}()
time.Sleep(time.Second)
if err := srv.Shutdown(context.Background()); err != nil {
t.Log(err)
}
}
func TestServer(t *testing.T) {
ctx := context.Background()
srv := NewServer()
srv.Handle("/index", newHandleFuncWrapper(fn))
srv.HandleFunc("/index/{id:[0-9]+}", fn)
srv.HandlePrefix("/test/prefix", newHandleFuncWrapper(fn))
srv.Handle("/index", newHandleFuncWrapper(h))
srv.HandleFunc("/index/{id:[0-9]+}", h)
srv.HandlePrefix("/test/prefix", newHandleFuncWrapper(h))
srv.HandleHeader("content-type", "application/grpc-web+json", func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(testData{Path: r.RequestURI})
})
@ -333,10 +370,16 @@ func TestStrictSlash(t *testing.T) {
}
func TestListener(t *testing.T) {
lis := &net.TCPListener{}
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
s := &Server{}
Listener(lis)(s)
if !reflect.DeepEqual(s.lis, lis) {
t.Errorf("expected %v got %v", lis, s.lis)
}
if e, err := s.Endpoint(); err != nil || e == nil {
t.Errorf("expected not empty")
}
}

Loading…
Cancel
Save