diff --git a/transport/http/codec.go b/transport/http/codec.go index ef00ea6e9..ff734ba15 100644 --- a/transport/http/codec.go +++ b/transport/http/codec.go @@ -45,6 +45,9 @@ func DefaultResponseEncoder(w http.ResponseWriter, r *http.Request, v interface{ return err } w.Header().Set("Content-Type", httputil.ContentType(codec.Name())) + if code, ok := v.(StatusCoder); ok { + w.WriteHeader(code.StatusCode()) + } _, err = w.Write(data) if err != nil { return err @@ -76,3 +79,9 @@ func CodecForRequest(r *http.Request, name string) (encoding.Codec, bool) { } return encoding.GetCodec("json"), false } + +// StatusCoder is checked by DefaultResponseEncoder. If a response value implements +// StatusCoder, the StatusCode will be used to set http status code. +type StatusCoder interface { + StatusCode() int +} diff --git a/transport/http/codec_test.go b/transport/http/codec_test.go index 880890722..742dd3d9e 100644 --- a/transport/http/codec_test.go +++ b/transport/http/codec_test.go @@ -52,11 +52,21 @@ func (w *mockResponseWriter) WriteHeader(statusCode int) { w.StatusCode = statusCode } -type dataWithStatusCode struct { +type respData struct { A string `json:"a"` B int64 `json:"b"` } +type respDataWithStatusCode struct { + A string `json:"a"` + B int64 `json:"b"` + sCode int +} + +func (r *respDataWithStatusCode) StatusCode() int { + return r.sCode +} + func TestDefaultResponseEncoder(t *testing.T) { w := &mockResponseWriter{StatusCode: 200, header: make(nethttp.Header)} req1 := &nethttp.Request{ @@ -64,7 +74,7 @@ func TestDefaultResponseEncoder(t *testing.T) { } req1.Header.Set("Content-Type", "application/json") - v1 := &dataWithStatusCode{A: "1", B: 2} + v1 := &respData{A: "1", B: 2} err := DefaultResponseEncoder(w, req1, v1) if err != nil { t.Errorf("expected no error, got %v", err) @@ -129,3 +139,26 @@ func TestCodecForRequest(t *testing.T) { t.Errorf("expected %v, got %v", "json", c.Name()) } } + +func TestDefaultResponseEncoderWithStatusCoder(t *testing.T) { + w := &mockResponseWriter{StatusCode: 200, header: make(nethttp.Header)} + req1 := &nethttp.Request{ + Header: make(nethttp.Header), + } + req1.Header.Set("Content-Type", "application/json") + cusStatusCode := 201 + v1 := &respDataWithStatusCode{A: "1", B: 2, sCode: cusStatusCode} + err := DefaultResponseEncoder(w, req1, v1) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if !reflect.DeepEqual("application/json", w.Header().Get("Content-Type")) { + t.Errorf("expected %v, got %v", "application/json", w.Header().Get("Content-Type")) + } + if !reflect.DeepEqual(cusStatusCode, w.StatusCode) { + t.Errorf("expected %v, got %v", cusStatusCode, w.StatusCode) + } + if w.Data == nil { + t.Errorf("expected not nil, got %v", w.Data) + } +}