diff --git a/transport/http/client.go b/transport/http/client.go index 79ac4f911..82425a7b9 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -264,6 +264,43 @@ func (client *Client) invoke(ctx context.Context, req *http.Request, args interf return err } +// DoWithMiddleware send an HTTP request WithMiddleware and decodes the body of response into target. +// returns an error (of type *Error) if the response status code is not 2xx. +func (client *Client) DoWithMiddleware(req *http.Request, opts ...CallOption) (*http.Response, error) { + c := defaultCallInfo(req.URL.Path) + for _, o := range opts { + if err := o.before(&c); err != nil { + return nil, err + } + } + h := func(ctx context.Context, in interface{}) (interface{}, error) { + res, err := client.do(req) + if res != nil { + cs := csAttempt{res: res} + for _, o := range opts { + o.after(&c, &cs) + } + } + if err != nil { + return nil, err + } + return res, nil + } + if len(client.opts.middleware) > 0 { + h = middleware.Chain(client.opts.middleware...)(h) + } + resp, err := h(req.Context(), req) + if err != nil { + return nil, err + } + + response, ok := resp.(*http.Response) + if ok { + return response, nil + } + return nil, errors.New(500, "Client DoWithMiddleware Failed", "response convert failed ") +} + // Do send an HTTP request and decodes the body of response into target. // returns an error (of type *Error) if the response status code is not 2xx. func (client *Client) Do(req *http.Request, opts ...CallOption) (*http.Response, error) { diff --git a/transport/http/server_test.go b/transport/http/server_test.go index 05bc8399e..f00f4cb25 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/json" "fmt" + "github.com/go-kratos/kratos/v2/middleware" "io" "net" "net/http" @@ -192,7 +193,12 @@ func testClient(t *testing.T, srv *Server) { if err != nil { t.Fatal(err) } - client, err := NewClient(context.Background(), WithEndpoint(e.Host)) + client, err := NewClient(context.Background(), WithEndpoint(e.Host), WithMiddleware(func(handler middleware.Handler) middleware.Handler { + t.Logf("handle in middleware") + return func(ctx context.Context, req interface{}) (interface{}, error) { + return handler(ctx, req) + } + })) if err != nil { t.Fatal(err) } @@ -228,6 +234,37 @@ func testClient(t *testing.T, srv *Server) { t.Errorf("expected %s got %s", test.path, res.Path) } } + for _, test := range tests { + var res testData + reqURL := fmt.Sprintf(e.String() + test.path) + req, err := http.NewRequest(test.method, reqURL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := client.DoWithMiddleware(req) + if errors.Code(err) != test.code { + t.Fatalf("want %v, but got %v", test, err) + } + if err != nil { + continue + } + if resp.StatusCode != 200 { + _ = resp.Body.Close() + t.Fatalf("http status got %d", resp.StatusCode) + } + content, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + t.Fatalf("read resp error %v", err) + } + err = json.Unmarshal(content, &res) + if err != nil { + t.Fatalf("unmarshal resp error %v", err) + } + if res.Path != test.path { + t.Errorf("expected %s got %s", test.path, res.Path) + } + } for _, test := range tests { var res testData err := client.Invoke(context.Background(), test.method, test.path, nil, &res)