You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
371 lines
9.7 KiB
371 lines
9.7 KiB
package http
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"reflect"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
kratoserrors "github.com/go-kratos/kratos/v2/errors"
|
|
"github.com/go-kratos/kratos/v2/middleware"
|
|
"github.com/go-kratos/kratos/v2/registry"
|
|
"github.com/go-kratos/kratos/v2/selector"
|
|
)
|
|
|
|
// mockRoundTripper is an inert http.RoundTripper used to verify that
// WithTransport stores the transport it is given.
type mockRoundTripper struct{}

// RoundTrip ignores the request and always yields a nil response and nil error.
func (rt *mockRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	return nil, nil
}
|
|
|
|
type mockCallOption struct {
|
|
needErr bool
|
|
}
|
|
|
|
func (x *mockCallOption) before(_ *callInfo) error {
|
|
if x.needErr {
|
|
return errors.New("option need return err")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (x *mockCallOption) after(_ *callInfo, _ *csAttempt) {
|
|
log.Println("run in mockCallOption.after")
|
|
}
|
|
|
|
func TestWithTransport(t *testing.T) {
|
|
ov := &mockRoundTripper{}
|
|
o := WithTransport(ov)
|
|
co := &clientOptions{}
|
|
o(co)
|
|
if !reflect.DeepEqual(co.transport, ov) {
|
|
t.Errorf("expected transport to be %v, got %v", ov, co.transport)
|
|
}
|
|
}
|
|
|
|
func TestWithTimeout(t *testing.T) {
|
|
ov := 1 * time.Second
|
|
o := WithTimeout(ov)
|
|
co := &clientOptions{}
|
|
o(co)
|
|
if !reflect.DeepEqual(co.timeout, ov) {
|
|
t.Errorf("expected timeout to be %v, got %v", ov, co.timeout)
|
|
}
|
|
}
|
|
|
|
func TestWithBlock(t *testing.T) {
|
|
o := WithBlock()
|
|
co := &clientOptions{}
|
|
o(co)
|
|
if !co.block {
|
|
t.Errorf("expected block to be true, got %v", co.block)
|
|
}
|
|
}
|
|
|
|
// TestWithBalancer is a placeholder for balancer-option coverage. It skips
// explicitly so the missing coverage shows up in test output instead of
// being reported as a vacuous pass.
func TestWithBalancer(t *testing.T) {
	t.Skip("TODO: add coverage for the balancer option")
}
|
|
|
|
func TestWithTLSConfig(t *testing.T) {
|
|
ov := &tls.Config{}
|
|
o := WithTLSConfig(ov)
|
|
co := &clientOptions{}
|
|
o(co)
|
|
if !reflect.DeepEqual(co.tlsConf, ov) {
|
|
t.Errorf("expected tls config to be %v, got %v", ov, co.tlsConf)
|
|
}
|
|
}
|
|
|
|
func TestWithUserAgent(t *testing.T) {
|
|
ov := "kratos"
|
|
o := WithUserAgent(ov)
|
|
co := &clientOptions{}
|
|
o(co)
|
|
if !reflect.DeepEqual(co.userAgent, ov) {
|
|
t.Errorf("expected user agent to be %v, got %v", ov, co.userAgent)
|
|
}
|
|
}
|
|
|
|
func TestWithMiddleware(t *testing.T) {
|
|
o := &clientOptions{}
|
|
v := []middleware.Middleware{
|
|
func(middleware.Handler) middleware.Handler { return nil },
|
|
}
|
|
WithMiddleware(v...)(o)
|
|
if !reflect.DeepEqual(o.middleware, v) {
|
|
t.Errorf("expected middleware to be %v, got %v", v, o.middleware)
|
|
}
|
|
}
|
|
|
|
func TestWithEndpoint(t *testing.T) {
|
|
ov := "some-endpoint"
|
|
o := WithEndpoint(ov)
|
|
co := &clientOptions{}
|
|
o(co)
|
|
if !reflect.DeepEqual(co.endpoint, ov) {
|
|
t.Errorf("expected endpoint to be %v, got %v", ov, co.endpoint)
|
|
}
|
|
}
|
|
|
|
func TestWithRequestEncoder(t *testing.T) {
|
|
o := &clientOptions{}
|
|
v := func(ctx context.Context, contentType string, in interface{}) (body []byte, err error) {
|
|
return nil, nil
|
|
}
|
|
WithRequestEncoder(v)(o)
|
|
if o.encoder == nil {
|
|
t.Errorf("expected encoder to be not nil")
|
|
}
|
|
}
|
|
|
|
func TestWithResponseDecoder(t *testing.T) {
|
|
o := &clientOptions{}
|
|
v := func(ctx context.Context, res *http.Response, out interface{}) error { return nil }
|
|
WithResponseDecoder(v)(o)
|
|
if o.decoder == nil {
|
|
t.Errorf("expected encoder to be not nil")
|
|
}
|
|
}
|
|
|
|
func TestWithErrorDecoder(t *testing.T) {
|
|
o := &clientOptions{}
|
|
v := func(ctx context.Context, res *http.Response) error { return nil }
|
|
WithErrorDecoder(v)(o)
|
|
if o.errorDecoder == nil {
|
|
t.Errorf("expected encoder to be not nil")
|
|
}
|
|
}
|
|
|
|
type mockDiscovery struct{}
|
|
|
|
func (*mockDiscovery) GetService(_ context.Context, _ string) ([]*registry.ServiceInstance, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (*mockDiscovery) Watch(_ context.Context, _ string) (registry.Watcher, error) {
|
|
return &mockWatcher{}, nil
|
|
}
|
|
|
|
type mockWatcher struct{}
|
|
|
|
func (m *mockWatcher) Next() ([]*registry.ServiceInstance, error) {
|
|
instance := ®istry.ServiceInstance{
|
|
ID: "1",
|
|
Name: "kratos",
|
|
Version: "v1",
|
|
Metadata: map[string]string{},
|
|
Endpoints: []string{fmt.Sprintf("http://127.0.0.1:9001?isSecure=%s", strconv.FormatBool(false))},
|
|
}
|
|
time.Sleep(time.Millisecond * 500)
|
|
return []*registry.ServiceInstance{instance}, nil
|
|
}
|
|
|
|
func (*mockWatcher) Stop() error {
|
|
return nil
|
|
}
|
|
|
|
func TestWithDiscovery(t *testing.T) {
|
|
ov := &mockDiscovery{}
|
|
o := WithDiscovery(ov)
|
|
co := &clientOptions{}
|
|
o(co)
|
|
if !reflect.DeepEqual(co.discovery, ov) {
|
|
t.Errorf("expected discovery to be %v, got %v", ov, co.discovery)
|
|
}
|
|
}
|
|
|
|
func TestWithNodeFilter(t *testing.T) {
|
|
ov := func(context.Context, []selector.Node) []selector.Node {
|
|
return []selector.Node{&selector.DefaultNode{}}
|
|
}
|
|
o := WithNodeFilter(ov)
|
|
co := &clientOptions{}
|
|
o(co)
|
|
for _, n := range co.nodeFilters {
|
|
ret := n(context.Background(), nil)
|
|
if len(ret) != 1 {
|
|
t.Errorf("expected node length to be 1, got %v", len(ret))
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDefaultRequestEncoder(t *testing.T) {
|
|
r, _ := http.NewRequest(http.MethodPost, "", io.NopCloser(bytes.NewBufferString(`{"a":"1", "b": 2}`)))
|
|
r.Header.Set("Content-Type", "application/xml")
|
|
|
|
v1 := &struct {
|
|
A string `json:"a"`
|
|
B int64 `json:"b"`
|
|
}{"a", 1}
|
|
b, err := DefaultRequestEncoder(context.TODO(), "application/json", v1)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
v1b := &struct {
|
|
A string `json:"a"`
|
|
B int64 `json:"b"`
|
|
}{}
|
|
err = json.Unmarshal(b, v1b)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !reflect.DeepEqual(v1b, v1) {
|
|
t.Errorf("expected %v, got %v", v1, v1b)
|
|
}
|
|
}
|
|
|
|
func TestDefaultResponseDecoder(t *testing.T) {
|
|
resp1 := &http.Response{
|
|
Header: make(http.Header),
|
|
StatusCode: 200,
|
|
Body: io.NopCloser(bytes.NewBufferString(`{"a":"1", "b": 2}`)),
|
|
}
|
|
v1 := &struct {
|
|
A string `json:"a"`
|
|
B int64 `json:"b"`
|
|
}{}
|
|
err := DefaultResponseDecoder(context.TODO(), resp1, &v1)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if v1.A != "1" {
|
|
t.Errorf("expected %v, got %v", "1", v1.A)
|
|
}
|
|
if v1.B != int64(2) {
|
|
t.Errorf("expected %v, got %v", 2, v1.B)
|
|
}
|
|
|
|
resp2 := &http.Response{
|
|
Header: make(http.Header),
|
|
StatusCode: 200,
|
|
Body: io.NopCloser(bytes.NewBufferString("{badjson}")),
|
|
}
|
|
v2 := &struct {
|
|
A string `json:"a"`
|
|
B int64 `json:"b"`
|
|
}{}
|
|
err = DefaultResponseDecoder(context.TODO(), resp2, &v2)
|
|
syntaxErr := &json.SyntaxError{}
|
|
if !errors.As(err, &syntaxErr) {
|
|
t.Errorf("expected %v, got %v", syntaxErr, err)
|
|
}
|
|
}
|
|
|
|
func TestDefaultErrorDecoder(t *testing.T) {
|
|
for i := 200; i < 300; i++ {
|
|
resp := &http.Response{Header: make(http.Header), StatusCode: i}
|
|
if DefaultErrorDecoder(context.TODO(), resp) != nil {
|
|
t.Errorf("expected no error, got %v", DefaultErrorDecoder(context.TODO(), resp))
|
|
}
|
|
}
|
|
resp1 := &http.Response{
|
|
Header: make(http.Header),
|
|
StatusCode: 300,
|
|
Body: io.NopCloser(bytes.NewBufferString("{\"foo\":\"bar\"}")),
|
|
}
|
|
if DefaultErrorDecoder(context.TODO(), resp1) == nil {
|
|
t.Errorf("expected error, got nil")
|
|
}
|
|
|
|
resp2 := &http.Response{
|
|
Header: make(http.Header),
|
|
StatusCode: 500,
|
|
Body: io.NopCloser(bytes.NewBufferString(`{"code":54321, "message": "hi", "reason": "FOO"}`)),
|
|
}
|
|
err := DefaultErrorDecoder(context.TODO(), resp2)
|
|
if err == nil {
|
|
t.Errorf("expected error, got nil")
|
|
}
|
|
if err.(*kratoserrors.Error).Code != int32(500) {
|
|
t.Errorf("expected %v, got %v", 500, err.(*kratoserrors.Error).Code)
|
|
}
|
|
if err.(*kratoserrors.Error).Message != "hi" {
|
|
t.Errorf("expected %v, got %v", "hi", err.(*kratoserrors.Error).Message)
|
|
}
|
|
if err.(*kratoserrors.Error).Reason != "FOO" {
|
|
t.Errorf("expected %v, got %v", "FOO", err.(*kratoserrors.Error).Reason)
|
|
}
|
|
}
|
|
|
|
func TestCodecForResponse(t *testing.T) {
|
|
resp := &http.Response{Header: make(http.Header)}
|
|
resp.Header.Set("Content-Type", "application/xml")
|
|
c := CodecForResponse(resp)
|
|
if !reflect.DeepEqual("xml", c.Name()) {
|
|
t.Errorf("expected %v, got %v", "xml", c.Name())
|
|
}
|
|
}
|
|
|
|
func TestNewClient(t *testing.T) {
|
|
_, err := NewClient(context.Background(), WithEndpoint("127.0.0.1:8888"))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
_, err = NewClient(context.Background(), WithEndpoint("127.0.0.1:9999"), WithTLSConfig(&tls.Config{ServerName: "www.kratos.com", RootCAs: nil}))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
_, err = NewClient(context.Background(), WithDiscovery(&mockDiscovery{}), WithEndpoint("discovery:///go-kratos"))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
_, err = NewClient(context.Background(), WithDiscovery(&mockDiscovery{}), WithEndpoint("127.0.0.1:8888"))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
_, err = NewClient(context.Background(), WithEndpoint("127.0.0.1:8888:xxxxa"))
|
|
if err == nil {
|
|
t.Error("except a parseTarget error")
|
|
}
|
|
_, err = NewClient(context.Background(), WithDiscovery(&mockDiscovery{}), WithEndpoint("https://go-kratos.dev/"))
|
|
if err == nil {
|
|
t.Error("err should not be equal to nil")
|
|
}
|
|
|
|
client, err := NewClient(
|
|
context.Background(),
|
|
WithDiscovery(&mockDiscovery{}),
|
|
WithEndpoint("discovery:///go-kratos"),
|
|
WithMiddleware(func(handler middleware.Handler) middleware.Handler {
|
|
t.Logf("handle in middleware")
|
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
return handler(ctx, req)
|
|
}
|
|
}),
|
|
)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err = client.Invoke(context.Background(), http.MethodPost, "/go", map[string]string{"name": "kratos"}, nil, EmptyCallOption{}, &mockCallOption{})
|
|
if err == nil {
|
|
t.Error("err should not be equal to nil")
|
|
}
|
|
err = client.Invoke(context.Background(), http.MethodPost, "/go", map[string]string{"name": "kratos"}, nil, EmptyCallOption{}, &mockCallOption{needErr: true})
|
|
if err == nil {
|
|
t.Error("err should be equal to callOption err")
|
|
}
|
|
client.opts.encoder = func(ctx context.Context, contentType string, in interface{}) (body []byte, err error) {
|
|
return nil, errors.New("mock test encoder error")
|
|
}
|
|
err = client.Invoke(context.Background(), http.MethodPost, "/go", map[string]string{"name": "kratos"}, nil, EmptyCallOption{})
|
|
if err == nil {
|
|
t.Error("err should be equal to encoder error")
|
|
}
|
|
headersCallOpt := Headers(map[string]string{
|
|
"trace_id": "xxxx",
|
|
})
|
|
err = client.Invoke(context.Background(), "POST", "/go", map[string]string{"name": "kratos"}, nil, headersCallOpt)
|
|
if err == nil {
|
|
t.Error("err should be equal to encoder error")
|
|
}
|
|
}
|
|
|