diff --git a/transport/http/route_test.go b/transport/http/route_test.go index c9046c4f0..074df92ea 100644 --- a/transport/http/route_test.go +++ b/transport/http/route_test.go @@ -17,6 +17,17 @@ type User struct { Name string `json:"name"` } +func corsFilter(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions { + log.Println("cors:", r.Method, r.RequestURI) + w.Header().Set("Access-Control-Allow-Methods", r.Method) + return + } + next.ServeHTTP(w, r) + }) +} + func authFilter(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Do stuff here @@ -37,7 +48,7 @@ func loggingFilter(next http.Handler) http.Handler { func TestRoute(t *testing.T) { ctx := context.Background() srv := NewServer( - Filter(loggingFilter), + Filter(corsFilter, loggingFilter), ) route := srv.Route("/v1") route.GET("/users/{name}", func(ctx Context) error { @@ -141,4 +152,17 @@ func testRoute(t *testing.T, srv *Server) { if u.Name != "bar" { t.Fatalf("got %s want bar", u.Name) } + // OPTIONS + req, _ = http.NewRequest("OPTIONS", base+"/users", nil) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("code: %d", resp.StatusCode) + } + if resp.Header.Get("Access-Control-Allow-Methods") != "OPTIONS" { + t.Fatal("cors failed") + } } diff --git a/transport/http/server.go b/transport/http/server.go index f8cc7b35e..709fe4c7e 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -121,7 +121,9 @@ func NewServer(opts ...ServerOption) *Server { for _, o := range opts { o(srv) } - srv.Server = &http.Server{Handler: srv} + srv.Server = &http.Server{ + Handler: FilterChain(srv.filters...)(srv), + } srv.router = mux.NewRouter() srv.router.Use(srv.filter()) return srv @@ -154,7 +156,6 @@ func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) { func (s *Server) filter() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { - next = FilterChain(s.filters...)(next) return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { ctx, cancel := ic.Merge(req.Context(), s.ctx) defer cancel()