Merge branch '#2228'

pull/2316/head
wangtingshun 2 years ago
commit 3b6a40c75b
  1. 37
      transport/http/client.go
  2. 39
      transport/http/server_test.go

@ -264,6 +264,43 @@ func (client *Client) invoke(ctx context.Context, req *http.Request, args interf
return err return err
} }
// DoWithMiddleware send an HTTP request WithMiddleware and decodes the body of response into target.
// returns an error (of type *Error) if the response status code is not 2xx.
func (client *Client) DoWithMiddleware(req *http.Request, opts ...CallOption) (*http.Response, error) {
c := defaultCallInfo(req.URL.Path)
for _, o := range opts {
if err := o.before(&c); err != nil {
return nil, err
}
}
h := func(ctx context.Context, in interface{}) (interface{}, error) {
res, err := client.do(req)
if res != nil {
cs := csAttempt{res: res}
for _, o := range opts {
o.after(&c, &cs)
}
}
if err != nil {
return nil, err
}
return res, nil
}
if len(client.opts.middleware) > 0 {
h = middleware.Chain(client.opts.middleware...)(h)
}
resp, err := h(req.Context(), req)
if err != nil {
return nil, err
}
response, ok := resp.(*http.Response)
if ok {
return response, nil
}
return nil, errors.New(500, "Client DoWithMiddleware Failed", "response convert failed ")
}
// Do send an HTTP request and decodes the body of response into target. // Do send an HTTP request and decodes the body of response into target.
// returns an error (of type *Error) if the response status code is not 2xx. // returns an error (of type *Error) if the response status code is not 2xx.
func (client *Client) Do(req *http.Request, opts ...CallOption) (*http.Response, error) { func (client *Client) Do(req *http.Request, opts ...CallOption) (*http.Response, error) {

@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/go-kratos/kratos/v2/middleware"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -192,7 +193,12 @@ func testClient(t *testing.T, srv *Server) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
client, err := NewClient(context.Background(), WithEndpoint(e.Host)) client, err := NewClient(context.Background(), WithEndpoint(e.Host), WithMiddleware(func(handler middleware.Handler) middleware.Handler {
t.Logf("handle in middleware")
return func(ctx context.Context, req interface{}) (interface{}, error) {
return handler(ctx, req)
}
}))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -228,6 +234,37 @@ func testClient(t *testing.T, srv *Server) {
t.Errorf("expected %s got %s", test.path, res.Path) t.Errorf("expected %s got %s", test.path, res.Path)
} }
} }
for _, test := range tests {
var res testData
reqURL := fmt.Sprintf(e.String() + test.path)
req, err := http.NewRequest(test.method, reqURL, nil)
if err != nil {
t.Fatal(err)
}
resp, err := client.DoWithMiddleware(req)
if errors.Code(err) != test.code {
t.Fatalf("want %v, but got %v", test, err)
}
if err != nil {
continue
}
if resp.StatusCode != 200 {
_ = resp.Body.Close()
t.Fatalf("http status got %d", resp.StatusCode)
}
content, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
t.Fatalf("read resp error %v", err)
}
err = json.Unmarshal(content, &res)
if err != nil {
t.Fatalf("unmarshal resp error %v", err)
}
if res.Path != test.path {
t.Errorf("expected %s got %s", test.path, res.Path)
}
}
for _, test := range tests { for _, test := range tests {
var res testData var res testData
err := client.Invoke(context.Background(), test.method, test.path, nil, &res) err := client.Invoke(context.Background(), test.method, test.path, nil, &res)

Loading…
Cancel
Save