diff --git a/transport/http/context.go b/transport/http/context.go index b70e39035..4225483f1 100644 --- a/transport/http/context.go +++ b/transport/http/context.go @@ -43,10 +43,27 @@ type Context interface { Reset(http.ResponseWriter, *http.Request) } +type responseWriter struct { + code int + w http.ResponseWriter +} + +func (w *responseWriter) rest(res http.ResponseWriter) { + w.w = res + w.code = http.StatusOK +} +func (w *responseWriter) Header() http.Header { return w.w.Header() } +func (w *responseWriter) WriteHeader(statusCode int) { w.code = statusCode } +func (w *responseWriter) Write(data []byte) (int, error) { + w.w.WriteHeader(w.code) + return w.w.Write(data) +} + type wrapper struct { route *Route req *http.Request res http.ResponseWriter + w responseWriter } func (c *wrapper) Header() http.Header { @@ -83,14 +100,14 @@ func (c *wrapper) Returns(v interface{}, err error) error { if err != nil { return err } - if err := c.route.srv.enc(c.res, c.req, v); err != nil { + if err := c.route.srv.enc(&c.w, c.req, v); err != nil { return err } return nil } func (c *wrapper) Result(code int, v interface{}) error { - c.res.WriteHeader(code) - if err := c.route.srv.enc(c.res, c.req, v); err != nil { + c.w.WriteHeader(code) + if err := c.route.srv.enc(&c.w, c.req, v); err != nil { return err } return nil @@ -124,6 +141,7 @@ func (c *wrapper) Stream(code int, contentType string, rd io.Reader) error { return err } func (c *wrapper) Reset(res http.ResponseWriter, req *http.Request) { + c.w.rest(res) c.res = res c.req = req } diff --git a/transport/http/route_test.go b/transport/http/route_test.go index a32b28a6d..c9046c4f0 100644 --- a/transport/http/route_test.go +++ b/transport/http/route_test.go @@ -91,6 +91,9 @@ func testRoute(t *testing.T, srv *Server) { if resp.StatusCode != 200 { t.Fatalf("code: %d", resp.StatusCode) } + if v := resp.Header.Get("Content-Type"); v != "application/json" { + t.Fatalf("contentType: %s", v) + } u := new(User) if err := json.NewDecoder(resp.Body).Decode(u); err != nil { t.Fatal(err) @@ -107,6 +110,9 @@ func testRoute(t *testing.T, srv *Server) { if resp.StatusCode != 201 { t.Fatalf("code: %d", resp.StatusCode) } + if v := resp.Header.Get("Content-Type"); v != "application/json" { + t.Fatalf("contentType: %s", v) + } u = new(User) if err = json.NewDecoder(resp.Body).Decode(u); err != nil { t.Fatal(err) @@ -125,6 +131,9 @@ func testRoute(t *testing.T, srv *Server) { if resp.StatusCode != 200 { t.Fatalf("code: %d", resp.StatusCode) } + if v := resp.Header.Get("Content-Type"); v != "application/json" { + t.Fatalf("contentType: %s", v) + } u = new(User) if err = json.NewDecoder(resp.Body).Decode(u); err != nil { t.Fatal(err)