package http import ( "bytes" "context" "crypto/tls" "encoding/json" "errors" "fmt" "io" "log" nethttp "net/http" "reflect" "strconv" "testing" "time" kratosErrors "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/selector" ) type mockRoundTripper struct{} func (rt *mockRoundTripper) RoundTrip(req *nethttp.Request) (resp *nethttp.Response, err error) { return } type mockCallOption struct { needErr bool } func (x *mockCallOption) before(info *callInfo) error { if x.needErr { return fmt.Errorf("option need return err") } return nil } func (x *mockCallOption) after(info *callInfo, attempt *csAttempt) { log.Println("run in mockCallOption.after") } func TestWithTransport(t *testing.T) { ov := &mockRoundTripper{} o := WithTransport(ov) co := &clientOptions{} o(co) if !reflect.DeepEqual(co.transport, ov) { t.Errorf("expected transport to be %v, got %v", ov, co.transport) } } func TestWithTimeout(t *testing.T) { ov := 1 * time.Second o := WithTimeout(ov) co := &clientOptions{} o(co) if !reflect.DeepEqual(co.timeout, ov) { t.Errorf("expected timeout to be %v, got %v", ov, co.timeout) } } func TestWithBlock(t *testing.T) { o := WithBlock() co := &clientOptions{} o(co) if !co.block { t.Errorf("expected block to be true, got %v", co.block) } } func TestWithBalancer(t *testing.T) { } func TestWithTLSConfig(t *testing.T) { ov := &tls.Config{} o := WithTLSConfig(ov) co := &clientOptions{} o(co) if !reflect.DeepEqual(co.tlsConf, ov) { t.Errorf("expected tls config to be %v, got %v", ov, co.tlsConf) } } func TestWithUserAgent(t *testing.T) { ov := "kratos" o := WithUserAgent(ov) co := &clientOptions{} o(co) if !reflect.DeepEqual(co.userAgent, ov) { t.Errorf("expected user agent to be %v, got %v", ov, co.userAgent) } } func TestWithMiddleware(t *testing.T) { o := &clientOptions{} v := []middleware.Middleware{ func(middleware.Handler) middleware.Handler { return nil }, } WithMiddleware(v...)(o) if !reflect.DeepEqual(o.middleware, v) { t.Errorf("expected middleware to be %v, got %v", v, o.middleware) } } func TestWithEndpoint(t *testing.T) { ov := "some-endpoint" o := WithEndpoint(ov) co := &clientOptions{} o(co) if !reflect.DeepEqual(co.endpoint, ov) { t.Errorf("expected endpoint to be %v, got %v", ov, co.endpoint) } } func TestWithRequestEncoder(t *testing.T) { o := &clientOptions{} v := func(ctx context.Context, contentType string, in interface{}) (body []byte, err error) { return nil, nil } WithRequestEncoder(v)(o) if o.encoder == nil { t.Errorf("expected encoder to be not nil") } } func TestWithResponseDecoder(t *testing.T) { o := &clientOptions{} v := func(ctx context.Context, res *nethttp.Response, out interface{}) error { return nil } WithResponseDecoder(v)(o) if o.decoder == nil { t.Errorf("expected encoder to be not nil") } } func TestWithErrorDecoder(t *testing.T) { o := &clientOptions{} v := func(ctx context.Context, res *nethttp.Response) error { return nil } WithErrorDecoder(v)(o) if o.errorDecoder == nil { t.Errorf("expected encoder to be not nil") } } type mockDiscovery struct{} func (*mockDiscovery) GetService(ctx context.Context, serviceName string) ([]*registry.ServiceInstance, error) { return nil, nil } func (*mockDiscovery) Watch(ctx context.Context, serviceName string) (registry.Watcher, error) { return &mockWatcher{}, nil } type mockWatcher struct{} func (m *mockWatcher) Next() ([]*registry.ServiceInstance, error) { instance := ®istry.ServiceInstance{ ID: "1", Name: "kratos", Version: "v1", Metadata: map[string]string{}, Endpoints: []string{fmt.Sprintf("http://127.0.0.1:9001?isSecure=%s", strconv.FormatBool(false))}, } time.Sleep(time.Millisecond * 500) return []*registry.ServiceInstance{instance}, nil } func (*mockWatcher) Stop() error { return nil } func TestWithDiscovery(t *testing.T) { ov := &mockDiscovery{} o := WithDiscovery(ov) co := &clientOptions{} o(co) if !reflect.DeepEqual(co.discovery, ov) { t.Errorf("expected discovery to be %v, got %v", ov, co.discovery) } } func TestWithNodeFilter(t *testing.T) { ov := func(context.Context, []selector.Node) []selector.Node { return []selector.Node{&selector.DefaultNode{}} } o := WithNodeFilter(ov) co := &clientOptions{} o(co) for _, n := range co.nodeFilters { ret := n(context.Background(), nil) if len(ret) != 1 { t.Errorf("expected node length to be 1, got %v", len(ret)) } } } func TestDefaultRequestEncoder(t *testing.T) { req1 := &nethttp.Request{ Header: make(nethttp.Header), Body: io.NopCloser(bytes.NewBufferString("{\"a\":\"1\", \"b\": 2}")), } req1.Header.Set("Content-Type", "application/xml") v1 := &struct { A string `json:"a"` B int64 `json:"b"` }{"a", 1} b, err1 := DefaultRequestEncoder(context.TODO(), "application/json", v1) if err1 != nil { t.Errorf("expected no error, got %v", err1) } v1b := &struct { A string `json:"a"` B int64 `json:"b"` }{} err1 = json.Unmarshal(b, v1b) if err1 != nil { t.Errorf("expected no error, got %v", err1) } if !reflect.DeepEqual(v1b, v1) { t.Errorf("expected %v, got %v", v1, v1b) } } func TestDefaultResponseDecoder(t *testing.T) { resp1 := &nethttp.Response{ Header: make(nethttp.Header), StatusCode: 200, Body: io.NopCloser(bytes.NewBufferString("{\"a\":\"1\", \"b\": 2}")), } v1 := &struct { A string `json:"a"` B int64 `json:"b"` }{} err1 := DefaultResponseDecoder(context.TODO(), resp1, &v1) if err1 != nil { t.Errorf("expected no error, got %v", err1) } if !reflect.DeepEqual("1", v1.A) { t.Errorf("expected %v, got %v", "1", v1.A) } if !reflect.DeepEqual(int64(2), v1.B) { t.Errorf("expected %v, got %v", 2, v1.B) } resp2 := &nethttp.Response{ Header: make(nethttp.Header), StatusCode: 200, Body: io.NopCloser(bytes.NewBufferString("{badjson}")), } v2 := &struct { A string `json:"a"` B int64 `json:"b"` }{} err2 := DefaultResponseDecoder(context.TODO(), resp2, &v2) terr1 := &json.SyntaxError{} if !errors.As(err2, &terr1) { t.Errorf("expected %v, got %v", terr1, err2) } } func TestDefaultErrorDecoder(t *testing.T) { for i := 200; i < 300; i++ { resp := &nethttp.Response{Header: make(nethttp.Header), StatusCode: i} if DefaultErrorDecoder(context.TODO(), resp) != nil { t.Errorf("expected no error, got %v", DefaultErrorDecoder(context.TODO(), resp)) } } resp1 := &nethttp.Response{ Header: make(nethttp.Header), StatusCode: 300, Body: io.NopCloser(bytes.NewBufferString("{\"foo\":\"bar\"}")), } if DefaultErrorDecoder(context.TODO(), resp1) == nil { t.Errorf("expected error, got nil") } resp2 := &nethttp.Response{ Header: make(nethttp.Header), StatusCode: 500, Body: io.NopCloser(bytes.NewBufferString("{\"code\":54321, \"message\": \"hi\", \"reason\": \"FOO\"}")), } err2 := DefaultErrorDecoder(context.TODO(), resp2) if err2 == nil { t.Errorf("expected error, got nil") } if !reflect.DeepEqual(int32(500), err2.(*kratosErrors.Error).Code) { t.Errorf("expected %v, got %v", 500, err2.(*kratosErrors.Error).Code) } if !reflect.DeepEqual("hi", err2.(*kratosErrors.Error).Message) { t.Errorf("expected %v, got %v", "hi", err2.(*kratosErrors.Error).Message) } if !reflect.DeepEqual("FOO", err2.(*kratosErrors.Error).Reason) { t.Errorf("expected %v, got %v", "FOO", err2.(*kratosErrors.Error).Reason) } } func TestCodecForResponse(t *testing.T) { resp := &nethttp.Response{Header: make(nethttp.Header)} resp.Header.Set("Content-Type", "application/xml") c := CodecForResponse(resp) if !reflect.DeepEqual("xml", c.Name()) { t.Errorf("expected %v, got %v", "xml", c.Name()) } } func TestNewClient(t *testing.T) { _, err := NewClient(context.Background(), WithEndpoint("127.0.0.1:8888")) if err != nil { t.Error(err) } _, err = NewClient(context.Background(), WithEndpoint("127.0.0.1:9999"), WithTLSConfig(&tls.Config{ServerName: "www.kratos.com", RootCAs: nil})) if err != nil { t.Error(err) } _, err = NewClient(context.Background(), WithDiscovery(&mockDiscovery{}), WithEndpoint("discovery:///go-kratos")) if err != nil { t.Error(err) } _, err = NewClient(context.Background(), WithDiscovery(&mockDiscovery{}), WithEndpoint("127.0.0.1:8888")) if err != nil { t.Error(err) } _, err = NewClient(context.Background(), WithEndpoint("127.0.0.1:8888:xxxxa")) if err == nil { t.Error("except a parseTarget error") } _, err = NewClient(context.Background(), WithDiscovery(&mockDiscovery{}), WithEndpoint("https://go-kratos.dev/")) if err == nil { t.Error("err should not be equal to nil") } client, err := NewClient( context.Background(), WithDiscovery(&mockDiscovery{}), WithEndpoint("discovery:///go-kratos"), WithMiddleware(func(handler middleware.Handler) middleware.Handler { t.Logf("handle in middleware") return func(ctx context.Context, req interface{}) (interface{}, error) { return handler(ctx, req) } }), ) if err != nil { t.Error(err) } err = client.Invoke(context.Background(), "POST", "/go", map[string]string{"name": "kratos"}, nil, EmptyCallOption{}, &mockCallOption{}) if err == nil { t.Error("err should not be equal to nil") } err = client.Invoke(context.Background(), "POST", "/go", map[string]string{"name": "kratos"}, nil, EmptyCallOption{}, &mockCallOption{needErr: true}) if err == nil { t.Error("err should be equal to callOption err") } client.opts.encoder = func(ctx context.Context, contentType string, in interface{}) (body []byte, err error) { return nil, fmt.Errorf("mock test encoder error") } err = client.Invoke(context.Background(), "POST", "/go", map[string]string{"name": "kratos"}, nil, EmptyCallOption{}) if err == nil { t.Error("err should be equal to encoder error") } reqURL := fmt.Sprintf(client.target.Endpoint + "/go") req, err := nethttp.NewRequest("POST", reqURL, nil) err = client.DoWithMiddleware(req, nil, EmptyCallOption{}) if err == nil { t.Error("err should not be equal to nil") } }