test:remove testify go mod (#1766)

* test:remove testify go mod

* tidy go mdo

* fix test
pull/1772/head
haiyux 3 years ago committed by GitHub
parent c6c5e4595c
commit 00c05e82a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 17
      app_test.go
  2. 1
      cmd/protoc-gen-go-http/go.mod
  3. 52
      cmd/protoc-gen-go-http/http_test.go
  4. 50
      config/config_test.go
  5. 33
      config/env/env_test.go
  6. 15
      config/env/watcher_test.go
  7. 30
      config/file/file_test.go
  8. 42
      config/options_test.go
  9. 110
      config/reader_test.go
  10. 124
      config/value_test.go
  11. 36
      container/group/group_test.go
  12. 35
      contrib/config/apollo/apollo_test.go
  13. 1
      contrib/config/apollo/go.mod
  14. 19
      contrib/config/consul/config_test.go
  15. 1
      contrib/config/consul/go.mod
  16. 19
      contrib/config/etcd/config_test.go
  17. 1
      contrib/config/etcd/go.mod
  18. 1
      contrib/encoding/msgpack/go.mod
  19. 1
      contrib/encoding/msgpack/go.sum
  20. 39
      contrib/encoding/msgpack/msgpack_test.go
  21. 1
      contrib/registry/consul/go.mod
  22. 3
      contrib/registry/consul/go.sum
  23. 38
      contrib/registry/consul/registry_test.go
  24. 88
      encoding/form/form_test.go
  25. 28
      encoding/proto/proto_test.go
  26. 30
      errors/errors_test.go
  27. 14
      errors/wrap_test.go
  28. 1
      examples/go.mod
  29. 1
      examples/go.sum
  30. 6
      examples/log/logrus_test.go
  31. 10
      examples/registry/registry_test.go
  32. 18
      examples/tls/tls_test.go
  33. 1
      go.mod
  34. 85
      internal/context/context_test.go
  35. 31
      internal/host/host_test.go
  36. 58
      middleware/auth/jwt/jwt_test.go
  37. 10
      middleware/metrics/metrics_test.go
  38. 15
      middleware/middleware_test.go
  39. 18
      middleware/selector/selector_test.go
  40. 14
      middleware/tracing/metadata_test.go
  41. 46
      middleware/tracing/tracing_test.go
  42. 46
      options_test.go
  43. 10
      selector/filter/version_test.go
  44. 35
      selector/node/direct/direct_test.go
  45. 57
      selector/node/ewma/node_test.go
  46. 58
      selector/p2c/p2c_test.go
  47. 33
      selector/random/random_test.go
  48. 71
      selector/selector_test.go
  49. 26
      selector/wrr/wrr_test.go
  50. 18
      transport/grpc/balancer_test.go
  51. 42
      transport/grpc/client_test.go
  52. 10
      transport/grpc/resolver/direct/builder_test.go
  53. 18
      transport/grpc/resolver/discovery/builder_test.go
  54. 14
      transport/grpc/resolver/discovery/resolver_test.go
  55. 76
      transport/grpc/server_test.go
  56. 40
      transport/grpc/transport_test.go
  57. 66
      transport/http/calloption_test.go
  58. 103
      transport/http/client_test.go
  59. 58
      transport/http/codec_test.go
  60. 99
      transport/http/context_test.go
  61. 50
      transport/http/resolver_test.go
  62. 7
      transport/http/router_test.go
  63. 70
      transport/http/server_test.go
  64. 47
      transport/http/transport_test.go
  65. 45
      transport/transport_test.go

@ -11,7 +11,6 @@ import (
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/transport/grpc" "github.com/go-kratos/kratos/v2/transport/grpc"
"github.com/go-kratos/kratos/v2/transport/http" "github.com/go-kratos/kratos/v2/transport/http"
"github.com/stretchr/testify/assert"
) )
type mockRegistry struct { type mockRegistry struct {
@ -60,19 +59,25 @@ func TestApp(t *testing.T) {
func TestApp_ID(t *testing.T) { func TestApp_ID(t *testing.T) {
v := "123" v := "123"
o := New(ID(v)) o := New(ID(v))
assert.Equal(t, v, o.ID()) if !reflect.DeepEqual(v, o.ID()) {
t.Fatalf("o.ID():%s is not equal to v:%s", o.ID(), v)
}
} }
func TestApp_Name(t *testing.T) { func TestApp_Name(t *testing.T) {
v := "123" v := "123"
o := New(Name(v)) o := New(Name(v))
assert.Equal(t, v, o.Name()) if !reflect.DeepEqual(v, o.Name()) {
t.Fatalf("o.Name():%s is not equal to v:%s", o.Name(), v)
}
} }
func TestApp_Version(t *testing.T) { func TestApp_Version(t *testing.T) {
v := "123" v := "123"
o := New(Version(v)) o := New(Version(v))
assert.Equal(t, v, o.Version()) if !reflect.DeepEqual(v, o.Version()) {
t.Fatalf("o.Version():%s is not equal to v:%s", o.Version(), v)
}
} }
func TestApp_Metadata(t *testing.T) { func TestApp_Metadata(t *testing.T) {
@ -81,7 +86,9 @@ func TestApp_Metadata(t *testing.T) {
"b": "2", "b": "2",
} }
o := New(Metadata(v)) o := New(Metadata(v))
assert.Equal(t, v, o.Metadata()) if !reflect.DeepEqual(v, o.Metadata()) {
t.Fatalf("o.Metadata():%s is not equal to v:%s", o.Metadata(), v)
}
} }
func TestApp_Endpoint(t *testing.T) { func TestApp_Endpoint(t *testing.T) {

@ -4,7 +4,6 @@ go 1.16
require ( require (
github.com/go-kratos/kratos/v2 v2.1.3 github.com/go-kratos/kratos/v2 v2.1.3
github.com/stretchr/testify v1.7.0
google.golang.org/genproto v0.0.0-20210805201207-89edb61ffb67 google.golang.org/genproto v0.0.0-20210805201207-89edb61ffb67
google.golang.org/protobuf v1.27.1 google.golang.org/protobuf v1.27.1
) )

@ -1,45 +1,63 @@
package main package main
import ( import (
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestNoParameters(t *testing.T) { func TestNoParameters(t *testing.T) {
path := "/test/noparams" path := "/test/noparams"
m := buildPathVars(path) m := buildPathVars(path)
assert.Emptyf(t, m, "Map should be empty") if !reflect.DeepEqual(m, map[string]*string{}) {
t.Fatalf("Map should be empty")
}
} }
func TestSingleParam(t *testing.T) { func TestSingleParam(t *testing.T) {
path := "/test/{message.id}" path := "/test/{message.id}"
m := buildPathVars(path) m := buildPathVars(path)
assert.Len(t, m, 1) if !reflect.DeepEqual(len(m), 1) {
assert.Empty(t, m["message.id"]) t.Fatalf("len(m) not is 1")
}
if m["message.id"] != nil {
t.Fatalf(`m["message.id"] should be empty`)
}
} }
func TestTwoParametersReplacement(t *testing.T) { func TestTwoParametersReplacement(t *testing.T) {
path := "/test/{message.id}/{message.name=messages/*}" path := "/test/{message.id}/{message.name=messages/*}"
m := buildPathVars(path) m := buildPathVars(path)
assert.Len(t, m, 2) if len(m) != 2 {
assert.Empty(t, m["message.id"]) t.Fatal("len(m) should be 2")
assert.NotEmpty(t, m["message.name"]) }
assert.Equal(t, *m["message.name"], "messages/*") if m["message.id"] != nil {
t.Fatal(`m["message.id"] should be nil`)
}
if m["message.name"] == nil {
t.Fatal(`m["message.name"] should not be nil`)
}
if *m["message.name"] != "messages/*" {
t.Fatal(`m["message.name"] should be "messages/*"`)
}
} }
func TestNoReplacePath(t *testing.T) { func TestNoReplacePath(t *testing.T) {
path := "/test/{message.id=test}" path := "/test/{message.id=test}"
assert.Equal(t, "/test/{message.id:test}", replacePath("message.id", "test", path)) if !reflect.DeepEqual(replacePath("message.id", "test", path), "/test/{message.id:test}") {
t.Fatal(`replacePath("message.id", "test", path) should be "/test/{message.id:test}"`)
}
path = "/test/{message.id=test/*}" path = "/test/{message.id=test/*}"
assert.Equal(t, "/test/{message.id:test/.*}", replacePath("message.id", "test/*", path)) if !reflect.DeepEqual(replacePath("message.id", "test/*", path), "/test/{message.id:test/.*}") {
t.Fatal(`replacePath("message.id", "test/*", path) should be "/test/{message.id:test/.*}"`)
}
} }
func TestReplacePath(t *testing.T) { func TestReplacePath(t *testing.T) {
path := "/test/{message.id}/{message.name=messages/*}" path := "/test/{message.id}/{message.name=messages/*}"
newPath := replacePath("message.name", "messages/*", path) newPath := replacePath("message.name", "messages/*", path)
assert.Equal(t, "/test/{message.id}/{message.name:messages/.*}", newPath) if !reflect.DeepEqual("/test/{message.id}/{message.name:messages/.*}", newPath) {
t.Fatal(`replacePath("message.name", "messages/*", path) should be "/test/{message.id}/{message.name:messages/.*}"`)
}
} }
func TestIteration(t *testing.T) { func TestIteration(t *testing.T) {
@ -50,7 +68,9 @@ func TestIteration(t *testing.T) {
path = replacePath(v, *s, path) path = replacePath(v, *s, path)
} }
} }
assert.Equal(t, "/test/{message.id}/{message.name:messages/.*}", path) if !reflect.DeepEqual("/test/{message.id}/{message.name:messages/.*}", path) {
t.Fatal(`replacePath("message.name", "messages/*", path) should be "/test/{message.id}/{message.name:messages/.*}"`)
}
} }
func TestIterationMiddle(t *testing.T) { func TestIterationMiddle(t *testing.T) {
@ -61,5 +81,7 @@ func TestIterationMiddle(t *testing.T) {
path = replacePath(v, *s, path) path = replacePath(v, *s, path)
} }
} }
assert.Equal(t, "/test/{message.name:messages/.*}/books", path) if !reflect.DeepEqual("/test/{message.name:messages/.*}/books", path) {
t.Fatal(`replacePath("message.name", "messages/*", path) should be "/test/{message.name:messages/.*}/books"`)
}
} }

@ -2,10 +2,10 @@ package config
import ( import (
"errors" "errors"
"reflect"
"testing" "testing"
"github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/log"
"github.com/stretchr/testify/assert"
) )
const ( const (
@ -126,7 +126,9 @@ func TestConfig(t *testing.T) {
WithLogger(log.GetLogger()), WithLogger(log.GetLogger()),
) )
err = c.Close() err = c.Close()
assert.Nil(t, err) if err != nil {
t.Fatal("t is not nil")
}
jSource := newTestJSONSource(_testJSON) jSource := newTestJSONSource(_testJSON)
opts := options{ opts := options{
@ -141,26 +143,48 @@ func TestConfig(t *testing.T) {
cf.log = log.NewHelper(opts.logger) cf.log = log.NewHelper(opts.logger)
err = cf.Load() err = cf.Load()
assert.Nil(t, err) if err != nil {
t.Fatal("t is not nil")
}
val, err := cf.Value("data.database.driver").String() val, err := cf.Value("data.database.driver").String()
assert.Nil(t, err) if err != nil {
assert.Equal(t, databaseDriver, val) t.Fatal("t is not nil")
}
if !reflect.DeepEqual(databaseDriver, val) {
t.Fatal(`databaseDriver is not equal to val`)
}
err = cf.Watch("endpoints", func(key string, value Value) { err = cf.Watch("endpoints", func(key string, value Value) {
}) })
assert.Nil(t, err) if err != nil {
t.Fatal("t is not nil")
}
jSource.sig <- struct{}{} jSource.sig <- struct{}{}
jSource.err <- struct{}{} jSource.err <- struct{}{}
var testConf testConfigStruct var testConf testConfigStruct
err = cf.Scan(&testConf) err = cf.Scan(&testConf)
assert.Nil(t, err) if err != nil {
assert.Equal(t, httpAddr, testConf.Server.HTTP.Addr) t.Fatal("t is not nil")
assert.Equal(t, httpTimeout, testConf.Server.HTTP.Timeout) }
assert.Equal(t, true, testConf.Server.HTTP.EnableSSL) if !reflect.DeepEqual(httpAddr, testConf.Server.HTTP.Addr) {
assert.Equal(t, grpcPort, testConf.Server.GRPC.Port) t.Fatal(`httpAddr is not equal to testConf.Server.HTTP.Addr`)
assert.Equal(t, endpoint1, testConf.Endpoints[0]) }
assert.Equal(t, 2, len(testConf.Endpoints)) if !reflect.DeepEqual(httpTimeout, testConf.Server.HTTP.Timeout) {
t.Fatal(`httpTimeout is not equal to testConf.Server.HTTP.Timeout`)
}
if !reflect.DeepEqual(true, testConf.Server.HTTP.EnableSSL) {
t.Fatal(`testConf.Server.HTTP.EnableSSL is not equal to true`)
}
if !reflect.DeepEqual(grpcPort, testConf.Server.GRPC.Port) {
t.Fatal(`grpcPort is not equal to testConf.Server.GRPC.Port`)
}
if !reflect.DeepEqual(endpoint1, testConf.Endpoints[0]) {
t.Fatal(`endpoint1 is not equal to testConf.Endpoints[0]`)
}
if !reflect.DeepEqual(len(testConf.Endpoints), 2) {
t.Fatal(`len(testConf.Endpoints) is not equal to 2`)
}
} }

@ -8,7 +8,6 @@ import (
"github.com/go-kratos/kratos/v2/config" "github.com/go-kratos/kratos/v2/config"
"github.com/go-kratos/kratos/v2/config/file" "github.com/go-kratos/kratos/v2/config/file"
"github.com/stretchr/testify/assert"
) )
const _testJSON = ` const _testJSON = `
@ -107,19 +106,27 @@ func TestEnvWithPrefix(t *testing.T) {
switch test.expect.(type) { switch test.expect.(type) {
case int: case int:
if actual, err = v.Int(); err == nil { if actual, err = v.Int(); err == nil {
assert.Equal(t, test.expect, int(actual.(int64)), "int value should be equal") if !reflect.DeepEqual(test.expect.(int), int(actual.(int64))) {
t.Errorf("expect %v, actual %v", test.expect, actual)
}
} }
case string: case string:
if actual, err = v.String(); err == nil { if actual, err = v.String(); err == nil {
assert.Equal(t, test.expect, actual, "string value should be equal") if !reflect.DeepEqual(test.expect.(string), actual.(string)) {
t.Errorf(`expect %v, actual %v`, test.expect, actual)
}
} }
case bool: case bool:
if actual, err = v.Bool(); err == nil { if actual, err = v.Bool(); err == nil {
assert.Equal(t, test.expect, actual, "bool value should be equal") if !reflect.DeepEqual(test.expect.(bool), actual.(bool)) {
t.Errorf(`expect %v, actual %v`, test.expect, actual)
}
} }
case float64: case float64:
if actual, err = v.Float(); err == nil { if actual, err = v.Float(); err == nil {
assert.Equal(t, test.expect, actual, "float64 value should be equal") if !reflect.DeepEqual(test.expect.(float64), actual.(float64)) {
t.Errorf(`expect %v, actual %v`, test.expect, actual)
}
} }
default: default:
actual = v.Load() actual = v.Load()
@ -213,19 +220,27 @@ func TestEnvWithoutPrefix(t *testing.T) {
switch test.expect.(type) { switch test.expect.(type) {
case int: case int:
if actual, err = v.Int(); err == nil { if actual, err = v.Int(); err == nil {
assert.Equal(t, test.expect, int(actual.(int64)), "int value should be equal") if !reflect.DeepEqual(test.expect.(int), int(actual.(int64))) {
t.Errorf("expect %v, actual %v", test.expect, actual)
}
} }
case string: case string:
if actual, err = v.String(); err == nil { if actual, err = v.String(); err == nil {
assert.Equal(t, test.expect, actual, "string value should be equal") if !reflect.DeepEqual(test.expect.(string), actual.(string)) {
t.Errorf(`expect %v, actual %v`, test.expect, actual)
}
} }
case bool: case bool:
if actual, err = v.Bool(); err == nil { if actual, err = v.Bool(); err == nil {
assert.Equal(t, test.expect, actual, "bool value should be equal") if !reflect.DeepEqual(test.expect.(bool), actual.(bool)) {
t.Errorf(`expect %v, actual %v`, test.expect, actual)
}
} }
case float64: case float64:
if actual, err = v.Float(); err == nil { if actual, err = v.Float(); err == nil {
assert.Equal(t, test.expect, actual, "float64 value should be equal") if !reflect.DeepEqual(test.expect.(float64), actual.(float64)) {
t.Errorf(`expect %v, actual %v`, test.expect, actual)
}
} }
default: default:
actual = v.Load() actual = v.Load()

@ -2,26 +2,29 @@ package env
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_watcher_next(t *testing.T) { func Test_watcher_next(t *testing.T) {
t.Run("next after stop should return err", func(t *testing.T) { t.Run("next after stop should return err", func(t *testing.T) {
w, err := NewWatcher() w, err := NewWatcher()
require.NoError(t, err) if err != nil {
t.Errorf("expect no error, got %v", err)
}
_ = w.Stop() _ = w.Stop()
_, err = w.Next() _, err = w.Next()
assert.Error(t, err) if err == nil {
t.Error("expect error, actual nil")
}
}) })
} }
func Test_watcher_stop(t *testing.T) { func Test_watcher_stop(t *testing.T) {
t.Run("stop multiple times should not panic", func(t *testing.T) { t.Run("stop multiple times should not panic", func(t *testing.T) {
w, err := NewWatcher() w, err := NewWatcher()
require.NoError(t, err) if err != nil {
t.Errorf("expect no error, got %v", err)
}
_ = w.Stop() _ = w.Stop()
_ = w.Stop() _ = w.Stop()

@ -4,12 +4,12 @@ import (
"errors" "errors"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/go-kratos/kratos/v2/config" "github.com/go-kratos/kratos/v2/config"
"github.com/stretchr/testify/assert"
) )
const ( const (
@ -120,19 +120,29 @@ func testWatchFile(t *testing.T, path string) {
t.Error(err) t.Error(err)
} }
kvs, err := watch.Next() kvs, err := watch.Next()
assert.Nil(t, err) if err != nil {
assert.Equal(t, string(kvs[0].Value), _testJSONUpdate) t.Errorf(`watch.Next() error(%v)`, err)
}
if !reflect.DeepEqual(string(kvs[0].Value), _testJSONUpdate) {
t.Errorf(`string(kvs[0].Value(%v) is not equal to _testJSONUpdate(%v)`, kvs[0].Value, _testJSONUpdate)
}
newFilepath := filepath.Join(filepath.Dir(path), "test1.json") newFilepath := filepath.Join(filepath.Dir(path), "test1.json")
if err = os.Rename(path, newFilepath); err != nil { if err = os.Rename(path, newFilepath); err != nil {
t.Error(err) t.Error(err)
} }
kvs, err = watch.Next() kvs, err = watch.Next()
assert.NotNil(t, err) if err == nil {
assert.Nil(t, kvs) t.Errorf(`watch.Next() error(%v)`, err)
}
if kvs != nil {
t.Errorf(`watch.Next() error(%v)`, err)
}
err = watch.Stop() err = watch.Stop()
assert.Nil(t, err) if err != nil {
t.Errorf(`watch.Stop() error(%v)`, err)
}
if err := os.Rename(newFilepath, path); err != nil { if err := os.Rename(newFilepath, path); err != nil {
t.Error(err) t.Error(err)
@ -160,8 +170,12 @@ func testWatchDir(t *testing.T, path, file string) {
} }
kvs, err := watch.Next() kvs, err := watch.Next()
assert.Nil(t, err) if err != nil {
assert.Equal(t, string(kvs[0].Value), _testJSONUpdate) t.Errorf(`watch.Next() error(%v)`, err)
}
if !reflect.DeepEqual(string(kvs[0].Value), _testJSONUpdate) {
t.Errorf(`string(kvs[0].Value(%v) is not equal to _testJSONUpdate(%v)`, kvs[0].Value, _testJSONUpdate)
}
} }
func testSource(t *testing.T, path string, data []byte) { func testSource(t *testing.T, path string, data []byte) {

@ -3,8 +3,6 @@ package config
import ( import (
"reflect" "reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestDefaultDecoder(t *testing.T) { func TestDefaultDecoder(t *testing.T) {
@ -15,10 +13,12 @@ func TestDefaultDecoder(t *testing.T) {
} }
target := make(map[string]interface{}) target := make(map[string]interface{})
err := defaultDecoder(src, target) err := defaultDecoder(src, target)
assert.Nil(t, err) if err != nil {
assert.Equal(t, map[string]interface{}{ t.Fatal("err is not nil")
"service": []byte("config"), }
}, target) if !reflect.DeepEqual(target, map[string]interface{}{"service": []byte("config")}) {
t.Fatal(`target is not equal to map[string]interface{}{"service": "config"}`)
}
src = &KeyValue{ src = &KeyValue{
Key: "service.name.alias", Key: "service.name.alias",
@ -27,14 +27,18 @@ func TestDefaultDecoder(t *testing.T) {
} }
target = make(map[string]interface{}) target = make(map[string]interface{})
err = defaultDecoder(src, target) err = defaultDecoder(src, target)
assert.Nil(t, err) if err != nil {
assert.Equal(t, map[string]interface{}{ t.Fatal("err is not nil")
}
if !reflect.DeepEqual(map[string]interface{}{
"service": map[string]interface{}{ "service": map[string]interface{}{
"name": map[string]interface{}{ "name": map[string]interface{}{
"alias": []byte("2233"), "alias": []byte("2233"),
}, },
}, },
}, target) }, target) {
t.Fatal(`target is not equal to map[string]interface{}{"service": map[string]interface{}{"name": map[string]interface{}{"alias": []byte("2233")}}}`)
}
} }
func TestDefaultResolver(t *testing.T) { func TestDefaultResolver(t *testing.T) {
@ -144,7 +148,9 @@ func TestDefaultResolver(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
err := defaultResolver(data) err := defaultResolver(data)
assert.NoError(t, err) if err != nil {
t.Fatal(`err is not nil`)
}
rd := reader{ rd := reader{
values: data, values: data,
} }
@ -153,19 +159,27 @@ func TestDefaultResolver(t *testing.T) {
switch test.expect.(type) { switch test.expect.(type) {
case int: case int:
if actual, err = v.Int(); err == nil { if actual, err = v.Int(); err == nil {
assert.Equal(t, test.expect, int(actual.(int64)), "int value should be equal") if !reflect.DeepEqual(test.expect.(int), int(actual.(int64))) {
t.Fatal(`expect is not equal to actual`)
}
} }
case string: case string:
if actual, err = v.String(); err == nil { if actual, err = v.String(); err == nil {
assert.Equal(t, test.expect, actual, "string value should be equal") if !reflect.DeepEqual(test.expect, actual) {
t.Fatal(`expect is not equal to actual`)
}
} }
case bool: case bool:
if actual, err = v.Bool(); err == nil { if actual, err = v.Bool(); err == nil {
assert.Equal(t, test.expect, actual, "bool value should be equal") if !reflect.DeepEqual(test.expect, actual) {
t.Fatal(`expect is not equal to actual`)
}
} }
case float64: case float64:
if actual, err = v.Float(); err == nil { if actual, err = v.Float(); err == nil {
assert.Equal(t, test.expect, actual, "float64 value should be equal") if !reflect.DeepEqual(test.expect, actual) {
t.Fatal(`expect is not equal to actual`)
}
} }
default: default:
actual = v.Load() actual = v.Load()

@ -2,10 +2,10 @@ package config
import ( import (
"fmt" "fmt"
"reflect"
"testing" "testing"
"github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/encoding"
"github.com/stretchr/testify/assert"
) )
func TestReader_Merge(t *testing.T) { func TestReader_Merge(t *testing.T) {
@ -28,31 +28,49 @@ func TestReader_Merge(t *testing.T) {
Value: []byte("bad"), Value: []byte("bad"),
Format: "json", Format: "json",
}) })
assert.Error(t, err) if err == nil {
t.Fatal(`err is nil`)
}
err = r.Merge(&KeyValue{ err = r.Merge(&KeyValue{
Key: "b", Key: "b",
Value: []byte(`{"nice": "boat", "x": 1}`), Value: []byte(`{"nice": "boat", "x": 1}`),
Format: "json", Format: "json",
}) })
assert.NoError(t, err) if err != nil {
t.Fatal(`err is not nil`)
}
vv, ok := r.Value("nice") vv, ok := r.Value("nice")
assert.True(t, ok) if !ok {
t.Fatal(`ok is false`)
}
vvv, err := vv.String() vvv, err := vv.String()
assert.NoError(t, err) if err != nil {
assert.Equal(t, "boat", vvv) t.Fatal(`err is not nil`)
}
if vvv != "boat" {
t.Fatal(`vvv is not equal to "boat"`)
}
err = r.Merge(&KeyValue{ err = r.Merge(&KeyValue{
Key: "b", Key: "b",
Value: []byte(`{"x": 2}`), Value: []byte(`{"x": 2}`),
Format: "json", Format: "json",
}) })
assert.NoError(t, err) if err != nil {
t.Fatal(`err is not nil`)
}
vv, ok = r.Value("x") vv, ok = r.Value("x")
assert.True(t, ok) if !ok {
t.Fatal(`ok is false`)
}
vvx, err := vv.Int() vvx, err := vv.Int()
assert.NoError(t, err) if err != nil {
assert.Equal(t, int64(2), vvx) t.Fatal(`err is not nil`)
}
if int64(2) != vvx {
t.Fatal(`vvx is not equal to 2`)
}
} }
func TestReader_Value(t *testing.T) { func TestReader_Value(t *testing.T) {
@ -99,35 +117,65 @@ a:
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
r := newReader(opts) r := newReader(opts)
err := r.Merge(&test.kv) err := r.Merge(&test.kv)
assert.NoError(t, err) if err != nil {
t.Fatal(`err is not nil`)
}
vv, ok := r.Value("a.b.X") vv, ok := r.Value("a.b.X")
assert.True(t, ok) if !ok {
t.Fatal(`ok is false`)
}
vvv, err := vv.Int() vvv, err := vv.Int()
assert.NoError(t, err) if err != nil {
assert.Equal(t, int64(1), vvv) t.Fatal(`err is not nil`)
}
if int64(1) != vvv {
t.Fatal(`vvv is not equal to 1`)
}
assert.NoError(t, err) if err != nil {
t.Fatal(`err is not nil`)
}
vv, ok = r.Value("a.b.Y") vv, ok = r.Value("a.b.Y")
assert.True(t, ok) if !ok {
t.Fatal(`ok is false`)
}
vvy, err := vv.String() vvy, err := vv.String()
assert.NoError(t, err) if err != nil {
assert.Equal(t, "lol", vvy) t.Fatal(`err is not nil`)
}
if vvy != "lol" {
t.Fatal(`vvy is not equal to "lol"`)
}
assert.NoError(t, err) if err != nil {
t.Fatal(`err is not nil`)
}
vv, ok = r.Value("a.b.z") vv, ok = r.Value("a.b.z")
assert.True(t, ok) if !ok {
t.Fatal(`ok is false`)
}
vvz, err := vv.Bool() vvz, err := vv.Bool()
assert.NoError(t, err) if err != nil {
assert.Equal(t, true, vvz) t.Fatal(`err is not nil`)
}
if !vvz {
t.Fatal(`vvz is not equal to true`)
}
_, ok = r.Value("aasasdg=234l.asdfk,") _, ok = r.Value("aasasdg=234l.asdfk,")
assert.False(t, ok) if ok {
t.Fatal(`ok is true`)
}
_, ok = r.Value("aas......asdg=234l.asdfk,") _, ok = r.Value("aas......asdg=234l.asdfk,")
assert.False(t, ok) if ok {
t.Fatal(`ok is true`)
}
_, ok = r.Value("a.b.Y.") _, ok = r.Value("a.b.Y.")
assert.False(t, ok) if ok {
t.Fatal(`ok is true`)
}
}) })
} }
} }
@ -149,8 +197,14 @@ func TestReader_Source(t *testing.T) {
Value: []byte(`{"a": {"b": {"X": 1}}}`), Value: []byte(`{"a": {"b": {"X": 1}}}`),
Format: "json", Format: "json",
}) })
assert.NoError(t, err) if err != nil {
t.Fatal(`err is not nil`)
}
b, err := r.Source() b, err := r.Source()
assert.NoError(t, err) if err != nil {
assert.Equal(t, []byte(`{"a":{"b":{"X":1}}}`), b) t.Fatal(`err is not nil`)
}
if !reflect.DeepEqual([]byte(`{"a":{"b":{"X":1}}}`), b) {
t.Fatal("[]byte(`{\"a\":{\"b\":{\"X\":1}}}`) is not equal to b")
}
} }

@ -4,8 +4,6 @@ import (
"fmt" "fmt"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func Test_atomicValue_Bool(t *testing.T) { func Test_atomicValue_Bool(t *testing.T) {
@ -14,8 +12,12 @@ func Test_atomicValue_Bool(t *testing.T) {
v := atomicValue{} v := atomicValue{}
v.Store(x) v.Store(x)
b, err := v.Bool() b, err := v.Bool()
assert.NoError(t, err, b) if err != nil {
assert.True(t, b, b) t.Fatal(`err is not nil`)
}
if !b {
t.Fatal(`b is not equal to true`)
}
} }
vlist = []interface{}{"0", "f", "F", "false", "FALSE", "False", false, 0, int32(0)} vlist = []interface{}{"0", "f", "F", "false", "FALSE", "False", false, 0, int32(0)}
@ -23,16 +25,22 @@ func Test_atomicValue_Bool(t *testing.T) {
v := atomicValue{} v := atomicValue{}
v.Store(x) v.Store(x)
b, err := v.Bool() b, err := v.Bool()
assert.NoError(t, err, b) if err != nil {
assert.False(t, b, b) t.Fatal(`err is not nil`)
}
if b {
t.Fatal(`b is not equal to false`)
}
} }
vlist = []interface{}{uint16(1), "bbb", "-1"} vlist = []interface{}{uint16(1), "bbb", "-1"}
for _, x := range vlist { for _, x := range vlist {
v := atomicValue{} v := atomicValue{}
v.Store(x) v.Store(x)
b, err := v.Bool() _, err := v.Bool()
assert.Error(t, err, b) if err == nil {
t.Fatal(`err is nil`)
}
} }
} }
@ -42,16 +50,22 @@ func Test_atomicValue_Int(t *testing.T) {
v := atomicValue{} v := atomicValue{}
v.Store(x) v.Store(x)
b, err := v.Int() b, err := v.Int()
assert.NoError(t, err, b) if err != nil {
assert.Equal(t, int64(123123), b, b) t.Fatal(`err is not nil`)
}
if b != 123123 {
t.Fatal(`b is not equal to 123123`)
}
} }
vlist = []interface{}{uint16(1), "bbb", "-x1", true} vlist = []interface{}{uint16(1), "bbb", "-x1", true}
for _, x := range vlist { for _, x := range vlist {
v := atomicValue{} v := atomicValue{}
v.Store(x) v.Store(x)
b, err := v.Int() _, err := v.Int()
assert.Error(t, err, b) if err == nil {
t.Fatal(`err is nil`)
}
} }
} }
@ -61,16 +75,22 @@ func Test_atomicValue_Float(t *testing.T) {
v := atomicValue{} v := atomicValue{}
v.Store(x) v.Store(x)
b, err := v.Float() b, err := v.Float()
assert.NoError(t, err, b) if err != nil {
assert.Equal(t, float64(123123.1), b, b) t.Fatal(`err is not nil`)
}
if b != float64(123123.1) {
t.Fatal(`b is not equal to 123123.1`)
}
} }
vlist = []interface{}{float32(1123123), uint16(1), "bbb", "-x1"} vlist = []interface{}{float32(1123123), uint16(1), "bbb", "-x1"}
for _, x := range vlist { for _, x := range vlist {
v := atomicValue{} v := atomicValue{}
v.Store(x) v.Store(x)
b, err := v.Float() _, err := v.Float()
assert.Error(t, err, b) if err == nil {
t.Fatal(`err is nil`)
}
} }
} }
@ -89,15 +109,23 @@ func Test_atomicValue_String(t *testing.T) {
v := atomicValue{} v := atomicValue{}
v.Store(x) v.Store(x)
b, err := v.String() b, err := v.String()
assert.NoError(t, err, b) if err != nil {
assert.Equal(t, "1", b, b) t.Fatal(`err is not nil`)
}
if b != "1" {
t.Fatal(`b is not equal to 1`)
}
} }
v := atomicValue{} v := atomicValue{}
v.Store(true) v.Store(true)
b, err := v.String() b, err := v.String()
assert.NoError(t, err, b) if err != nil {
assert.Equal(t, "true", b, b) t.Fatal(`err is not nil`)
}
if b != "true" {
t.Fatal(`b is not equal to "true"`)
}
v = atomicValue{} v = atomicValue{}
v.Store(ts{ v.Store(ts{
@ -105,8 +133,12 @@ func Test_atomicValue_String(t *testing.T) {
Age: 10, Age: 10,
}) })
b, err = v.String() b, err = v.String()
assert.NoError(t, err, b) if err != nil {
assert.Equal(t, "test10", b, "test Stringer should be equal") t.Fatal(`err is not nil`)
}
if b != "test10" {
t.Fatal(`b is not equal to "test10"`)
}
} }
func Test_atomicValue_Duration(t *testing.T) { func Test_atomicValue_Duration(t *testing.T) {
@ -115,8 +147,12 @@ func Test_atomicValue_Duration(t *testing.T) {
v := atomicValue{} v := atomicValue{}
v.Store(x) v.Store(x)
b, err := v.Duration() b, err := v.Duration()
assert.NoError(t, err) if err != nil {
assert.Equal(t, time.Duration(5), b) t.Fatal(`err is not nil`)
}
if b != time.Duration(5) {
t.Fatal(`b is not equal to time.Duration(5)`)
}
} }
} }
@ -125,11 +161,17 @@ func Test_atomicValue_Slice(t *testing.T) {
v := atomicValue{} v := atomicValue{}
v.Store(vlist) v.Store(vlist)
slices, err := v.Slice() slices, err := v.Slice()
assert.NoError(t, err) if err != nil {
t.Fatal(`err is not nil`)
}
for _, v := range slices { for _, v := range slices {
b, err := v.Duration() b, err := v.Duration()
assert.NoError(t, err) if err != nil {
assert.Equal(t, time.Duration(5), b) t.Fatal(`err is not nil`)
}
if b != time.Duration(5) {
t.Fatal(`b is not equal to time.Duration(5)`)
}
} }
} }
@ -140,16 +182,26 @@ func Test_atomicValue_Map(t *testing.T) {
v := atomicValue{} v := atomicValue{}
v.Store(vlist) v.Store(vlist)
m, err := v.Map() m, err := v.Map()
assert.NoError(t, err) if err != nil {
t.Fatal(`err is not nil`)
}
for k, v := range m { for k, v := range m {
if k == "5" { if k == "5" {
b, err := v.Duration() b, err := v.Duration()
assert.NoError(t, err) if err != nil {
assert.Equal(t, time.Duration(5), b) t.Fatal(`err is not nil`)
}
if b != time.Duration(5) {
t.Fatal(`b is not equal to time.Duration(5)`)
}
} else { } else {
b, err := v.String() b, err := v.String()
assert.NoError(t, err) if err != nil {
assert.Equal(t, "text", b) t.Fatal(`err is not nil`)
}
if b != "text" {
t.Fatal(`b is not equal to "text"`)
}
} }
} }
} }
@ -160,10 +212,14 @@ func Test_atomicValue_Scan(t *testing.T) {
err = v.Scan(&struct { err = v.Scan(&struct {
A string `json:"a"` A string `json:"a"`
}{"a"}) }{"a"})
assert.NoError(t, err) if err != nil {
t.Fatal(`err is not nil`)
}
err = v.Scan(&struct { err = v.Scan(&struct {
A string `json:"a"` A string `json:"a"`
}{"a"}) }{"a"})
assert.NoError(t, err) if err != nil {
t.Fatal(`err is not nil`)
}
} }

@ -1,9 +1,8 @@
package group package group
import ( import (
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestGroupGet(t *testing.T) { func TestGroupGet(t *testing.T) {
@ -13,14 +12,22 @@ func TestGroupGet(t *testing.T) {
return count return count
}) })
v := g.Get("key_0") v := g.Get("key_0")
assert.Equal(t, 1, v.(int)) if !reflect.DeepEqual(v.(int), 1) {
t.Errorf("expect 1, actual %v", v)
}
v = g.Get("key_1") v = g.Get("key_1")
assert.Equal(t, 2, v.(int)) if !reflect.DeepEqual(v.(int), 2) {
t.Errorf("expect 2, actual %v", v)
}
v = g.Get("key_0") v = g.Get("key_0")
assert.Equal(t, 1, v.(int)) if !reflect.DeepEqual(v.(int), 1) {
assert.Equal(t, 2, count) t.Errorf("expect 1, actual %v", v)
}
if !reflect.DeepEqual(count, 2) {
t.Errorf("expect count 2, actual %v", count)
}
} }
func TestGroupReset(t *testing.T) { func TestGroupReset(t *testing.T) {
@ -38,11 +45,14 @@ func TestGroupReset(t *testing.T) {
for range g.vals { for range g.vals {
length++ length++
} }
if !reflect.DeepEqual(length, 0) {
assert.Equal(t, 0, length) t.Errorf("expect length 0, actual %v", length)
}
g.Get("key") g.Get("key")
assert.Equal(t, true, call) if !reflect.DeepEqual(call, true) {
t.Errorf("expect call true, actual %v", call)
}
} }
func TestGroupClear(t *testing.T) { func TestGroupClear(t *testing.T) {
@ -54,12 +64,16 @@ func TestGroupClear(t *testing.T) {
for range g.vals { for range g.vals {
length++ length++
} }
assert.Equal(t, 1, length) if !reflect.DeepEqual(length, 1) {
t.Errorf("expect length 1, actual %v", length)
}
g.Clear() g.Clear()
length = 0 length = 0
for range g.vals { for range g.vals {
length++ length++
} }
assert.Equal(t, 0, length) if !reflect.DeepEqual(length, 0) {
t.Errorf("expect length 0, actual %v", length)
}
} }

@ -1,9 +1,8 @@
package apollo package apollo
import ( import (
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func Test_genKey(t *testing.T) { func Test_genKey(t *testing.T) {
@ -152,7 +151,9 @@ func Test_convertProperties(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
resolve(tt.args.key, tt.args.value, tt.args.target) resolve(tt.args.key, tt.args.value, tt.args.target)
assert.Equal(t, tt.want, tt.args.target) if !reflect.DeepEqual(tt.args.target, tt.want) {
t.Errorf("convertProperties() = %v, want %v", tt.args.target, tt.want)
}
}) })
} }
} }
@ -160,13 +161,29 @@ func Test_convertProperties(t *testing.T) {
func Test_convertProperties_duplicate(t *testing.T) { func Test_convertProperties_duplicate(t *testing.T) {
target := map[string]interface{}{} target := map[string]interface{}{}
resolve("application.name", "name", target) resolve("application.name", "name", target)
assert.Contains(t, target, "application") _, ok := target["application"]
assert.Contains(t, target["application"], "name") if !reflect.DeepEqual(ok, true) {
assert.Equal(t, "name", target["application"].(map[string]interface{})["name"]) t.Errorf("ok = %v, want %v", ok, true)
}
_, ok = target["application"].(map[string]interface{})["name"]
if !reflect.DeepEqual(ok, true) {
t.Errorf("ok = %v, want %v", ok, true)
}
if !reflect.DeepEqual(target["application"].(map[string]interface{})["name"], "name") {
t.Errorf("target[\"application\"][\"name\"] = %v, want %v", target["application"].(map[string]interface{})["name"], "name")
}
// cause duplicate, the oldest value will be kept // cause duplicate, the oldest value will be kept
resolve("application.name.first", "first name", target) resolve("application.name.first", "first name", target)
assert.Contains(t, target, "application") _, ok = target["application"]
assert.Contains(t, target["application"], "name") if !reflect.DeepEqual(ok, true) {
assert.Equal(t, "name", target["application"].(map[string]interface{})["name"]) t.Errorf("ok = %v, want %v", ok, true)
}
_, ok = target["application"].(map[string]interface{})["name"]
if !reflect.DeepEqual(ok, true) {
t.Errorf("ok = %v, want %v", ok, true)
}
if !reflect.DeepEqual(target["application"].(map[string]interface{})["name"], "name") {
t.Errorf("target[\"application\"][\"name\"] = %v, want %v", target["application"].(map[string]interface{})["name"], "name")
}
} }

@ -5,7 +5,6 @@ go 1.16
require ( require (
github.com/apolloconfig/agollo/v4 v4.0.8 github.com/apolloconfig/agollo/v4 v4.0.8
github.com/go-kratos/kratos/v2 v2.1.4 github.com/go-kratos/kratos/v2 v2.1.4
github.com/stretchr/testify v1.7.0
) )
replace github.com/go-kratos/kratos/v2 => ../../../ replace github.com/go-kratos/kratos/v2 => ../../../

@ -1,10 +1,10 @@
package consul package consul
import ( import (
"reflect"
"testing" "testing"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/stretchr/testify/assert"
) )
const testPath = "kratos/test/config" const testPath = "kratos/test/config"
@ -86,9 +86,16 @@ func TestExtToFormat(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !reflect.DeepEqual(len(kvs), 1) {
assert.Equal(t, 1, len(kvs)) t.Errorf("len(kvs) is %d", len(kvs))
assert.Equal(t, tn, kvs[0].Key) }
assert.Equal(t, tc, string(kvs[0].Value)) if !reflect.DeepEqual(tn, kvs[0].Key) {
assert.Equal(t, "json", kvs[0].Format) t.Errorf("kvs[0].Key is %s", kvs[0].Key)
}
if !reflect.DeepEqual(tc, string(kvs[0].Value)) {
t.Errorf("kvs[0].Value is %s", kvs[0].Value)
}
if !reflect.DeepEqual("json", kvs[0].Format) {
t.Errorf("kvs[0].Format is %s", kvs[0].Format)
}
} }

@ -5,7 +5,6 @@ go 1.15
require ( require (
github.com/go-kratos/kratos/v2 v2.1.4 github.com/go-kratos/kratos/v2 v2.1.4
github.com/hashicorp/consul/api v1.10.0 github.com/hashicorp/consul/api v1.10.0
github.com/stretchr/testify v1.7.0
) )
replace github.com/go-kratos/kratos/v2 => ../../../ replace github.com/go-kratos/kratos/v2 => ../../../

@ -2,10 +2,10 @@ package etcd
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -95,9 +95,16 @@ func TestExtToFormat(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !reflect.DeepEqual(len(kvs), 1) {
assert.Equal(t, 1, len(kvs)) t.Errorf("len(kvs) = %d", len(kvs))
assert.Equal(t, tk, kvs[0].Key) }
assert.Equal(t, tc, string(kvs[0].Value)) if !reflect.DeepEqual(tk, kvs[0].Key) {
assert.Equal(t, "json", kvs[0].Format) t.Errorf("kvs[0].Key is %s", kvs[0].Key)
}
if !reflect.DeepEqual(tc, string(kvs[0].Value)) {
t.Errorf("kvs[0].Value is %s", kvs[0].Value)
}
if !reflect.DeepEqual("json", kvs[0].Format) {
t.Errorf("kvs[0].Format is %s", kvs[0].Format)
}
} }

@ -4,7 +4,6 @@ go 1.16
require ( require (
github.com/go-kratos/kratos/v2 v2.1.4 github.com/go-kratos/kratos/v2 v2.1.4
github.com/stretchr/testify v1.7.0
go.etcd.io/etcd/client/v3 v3.5.0 go.etcd.io/etcd/client/v3 v3.5.0
google.golang.org/grpc v1.43.0 google.golang.org/grpc v1.43.0
) )

@ -4,7 +4,6 @@ go 1.16
require ( require (
github.com/go-kratos/kratos/v2 v2.1.4 github.com/go-kratos/kratos/v2 v2.1.4
github.com/stretchr/testify v1.7.0
github.com/vmihailenco/msgpack/v5 v5.3.4 github.com/vmihailenco/msgpack/v5 v5.3.4
) )

@ -152,7 +152,6 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

@ -1,9 +1,8 @@
package msgpack package msgpack
import ( import (
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
type loginRequest struct { type loginRequest struct {
@ -18,29 +17,47 @@ type testModel struct {
func TestName(t *testing.T) { func TestName(t *testing.T) {
c := new(codec) c := new(codec)
assert.Equal(t, c.Name(), "msgpack") if !reflect.DeepEqual("msgpack", c.Name()) {
t.Errorf("Name() should be msgpack, but got %s", c.Name())
}
} }
func TestCodec(t *testing.T) { func TestCodec(t *testing.T) {
c := new(codec) c := new(codec)
t2 := testModel{ID: 1, Name: "name"} t2 := testModel{ID: 1, Name: "name"}
m, err := c.Marshal(&t2) m, err := c.Marshal(&t2)
assert.Nil(t, err) if err != nil {
t.Errorf("Marshal() should be nil, but got %s", err)
}
var t3 testModel var t3 testModel
err = c.Unmarshal(m, &t3) err = c.Unmarshal(m, &t3)
assert.Nil(t, err) if err != nil {
assert.Equal(t, t3.ID, t2.ID) t.Errorf("Unmarshal() should be nil, but got %s", err)
assert.Equal(t, t3.Name, t2.Name) }
if !reflect.DeepEqual(t2.ID, t3.ID) {
t.Errorf("ID should be %d, but got %d", t2.ID, t3.ID)
}
if !reflect.DeepEqual(t3.Name, t2.Name) {
t.Errorf("Name should be %s, but got %s", t2.Name, t3.Name)
}
request := loginRequest{ request := loginRequest{
UserName: "username", UserName: "username",
Password: "password", Password: "password",
} }
m, err = c.Marshal(&request) m, err = c.Marshal(&request)
assert.Nil(t, err) if err != nil {
t.Errorf("Marshal() should be nil, but got %s", err)
}
var req loginRequest var req loginRequest
err = c.Unmarshal(m, &req) err = c.Unmarshal(m, &req)
assert.Nil(t, err) if err != nil {
assert.Equal(t, req.Password, request.Password) t.Errorf("Unmarshal() should be nil, but got %s", err)
assert.Equal(t, req.UserName, request.UserName) }
if !reflect.DeepEqual(req.Password, request.Password) {
t.Errorf("ID should be %s, but got %s", req.Password, request.Password)
}
if !reflect.DeepEqual(req.UserName, request.UserName) {
t.Errorf("Name should be %s, but got %s", req.UserName, request.UserName)
}
} }

@ -5,7 +5,6 @@ go 1.16
require ( require (
github.com/go-kratos/kratos/v2 v2.1.4 github.com/go-kratos/kratos/v2 v2.1.4
github.com/hashicorp/consul/api v1.9.1 github.com/hashicorp/consul/api v1.9.1
github.com/stretchr/testify v1.7.0
) )
replace github.com/go-kratos/kratos/v2 => ../../../ replace github.com/go-kratos/kratos/v2 => ../../../

@ -106,10 +106,8 @@ github.com/hashicorp/memberlist v0.2.2/go.mod h1:MS2lj3INKhZjWNqd3N0m3J+Jxf3DAOn
github.com/hashicorp/serf v0.9.5 h1:EBWvyu9tcRszt3Bxp3KNssBMP1KuHWyO51lz9+786iM= github.com/hashicorp/serf v0.9.5 h1:EBWvyu9tcRszt3Bxp3KNssBMP1KuHWyO51lz9+786iM=
github.com/hashicorp/serf v0.9.5/go.mod h1:UWDWwZeL5cuWDJdl0C6wrvrUwEqtQ4ZKBKKENpqIUyk= github.com/hashicorp/serf v0.9.5/go.mod h1:UWDWwZeL5cuWDJdl0C6wrvrUwEqtQ4ZKBKKENpqIUyk=
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
@ -252,7 +250,6 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

@ -4,13 +4,13 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"reflect"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/stretchr/testify/assert"
) )
func tcpServer(t *testing.T, lis net.Listener) { func tcpServer(t *testing.T, lis net.Listener) {
@ -44,7 +44,9 @@ func TestRegister(t *testing.T) {
WithHealthCheckInterval(5), WithHealthCheckInterval(5),
} }
r := New(cli, opts...) r := New(cli, opts...)
assert.Nil(t, err) if err != nil {
t.Errorf("new consul registry failed: %v", err)
}
version := strconv.FormatInt(time.Now().Unix(), 10) version := strconv.FormatInt(time.Now().Unix(), 10)
svc := &registry.ServiceInstance{ svc := &registry.ServiceInstance{
ID: "test2233", ID: "test2233",
@ -56,18 +58,34 @@ func TestRegister(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel() defer cancel()
err = r.Deregister(ctx, svc) err = r.Deregister(ctx, svc)
assert.Nil(t, err) if err != nil {
t.Errorf("Deregister failed: %v", err)
}
err = r.Register(ctx, svc) err = r.Register(ctx, svc)
assert.Nil(t, err) if err != nil {
t.Errorf("Register failed: %v", err)
}
w, err := r.Watch(ctx, "test-provider") w, err := r.Watch(ctx, "test-provider")
assert.Nil(t, err) if err != nil {
t.Errorf("Watchfailed: %v", err)
}
services, err := w.Next() services, err := w.Next()
assert.Nil(t, err) if err != nil {
assert.Equal(t, 1, len(services)) t.Errorf("Next failed: %v", err)
assert.EqualValues(t, "test2233", services[0].ID) }
assert.EqualValues(t, "test-provider", services[0].Name) if !reflect.DeepEqual(1, len(services)) {
assert.EqualValues(t, version, services[0].Version) t.Errorf("no expect float_key value: %v, but got: %v", len(services), 1)
}
if !reflect.DeepEqual("test2233", services[0].ID) {
t.Errorf("no expect float_key value: %v, but got: %v", services[0].ID, "test2233")
}
if !reflect.DeepEqual("test-provider", services[0].Name) {
t.Errorf("no expect float_key value: %v, but got: %v", services[0].Name, "test-provider")
}
if !reflect.DeepEqual(version, services[0].Version) {
t.Errorf("no expect float_key value: %v, but got: %v", services[0].Version, version)
}
} }
func getIntranetIP() string { func getIntranetIP() string {

@ -1,6 +1,7 @@
package form package form
import ( import (
"reflect"
"testing" "testing"
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
@ -10,7 +11,6 @@ import (
"github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/encoding"
"github.com/go-kratos/kratos/v2/internal/testdata/complex" "github.com/go-kratos/kratos/v2/internal/testdata/complex"
"github.com/stretchr/testify/require"
) )
type LoginRequest struct { type LoginRequest struct {
@ -31,16 +31,24 @@ func TestFormCodecMarshal(t *testing.T) {
Password: "kratos_pwd", Password: "kratos_pwd",
} }
content, err := encoding.GetCodec(contentType).Marshal(req) content, err := encoding.GetCodec(contentType).Marshal(req)
require.NoError(t, err) if err != nil {
require.Equal(t, []byte("password=kratos_pwd&username=kratos"), content) t.Errorf("marshal error: %v", err)
}
if !reflect.DeepEqual([]byte("password=kratos_pwd&username=kratos"), content) {
t.Errorf("expect %v, got %v", []byte("password=kratos_pwd&username=kratos"), content)
}
req = &LoginRequest{ req = &LoginRequest{
Username: "kratos", Username: "kratos",
Password: "", Password: "",
} }
content, err = encoding.GetCodec(contentType).Marshal(req) content, err = encoding.GetCodec(contentType).Marshal(req)
require.NoError(t, err) if err != nil {
require.Equal(t, []byte("username=kratos"), content) t.Errorf("expect %v, got %v", nil, err)
}
if !reflect.DeepEqual([]byte("username=kratos"), content) {
t.Errorf("expect %v, got %v", []byte("username=kratos"), content)
}
m := &TestModel{ m := &TestModel{
ID: 1, ID: 1,
@ -48,8 +56,12 @@ func TestFormCodecMarshal(t *testing.T) {
} }
content, err = encoding.GetCodec(contentType).Marshal(m) content, err = encoding.GetCodec(contentType).Marshal(m)
t.Log(string(content)) t.Log(string(content))
require.NoError(t, err) if err != nil {
require.Equal(t, []byte("id=1&name=kratos"), content) t.Errorf("expect %v, got %v", nil, err)
}
if !reflect.DeepEqual([]byte("id=1&name=kratos"), content) {
t.Errorf("expect %v, got %v", []byte("id=1&name=kratos"), content)
}
} }
func TestFormCodecUnmarshal(t *testing.T) { func TestFormCodecUnmarshal(t *testing.T) {
@ -58,13 +70,21 @@ func TestFormCodecUnmarshal(t *testing.T) {
Password: "kratos_pwd", Password: "kratos_pwd",
} }
content, err := encoding.GetCodec(contentType).Marshal(req) content, err := encoding.GetCodec(contentType).Marshal(req)
require.NoError(t, err) if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}
bindReq := new(LoginRequest) bindReq := new(LoginRequest)
err = encoding.GetCodec(contentType).Unmarshal(content, bindReq) err = encoding.GetCodec(contentType).Unmarshal(content, bindReq)
require.NoError(t, err) if err != nil {
require.Equal(t, "kratos", bindReq.Username) t.Errorf("expect %v, got %v", nil, err)
require.Equal(t, "kratos_pwd", bindReq.Password) }
if !reflect.DeepEqual("kratos", bindReq.Username) {
t.Errorf("expect %v, got %v", "kratos", bindReq.Username)
}
if !reflect.DeepEqual("kratos_pwd", bindReq.Password) {
t.Errorf("expect %v, got %v", "kratos_pwd", bindReq.Password)
}
} }
func TestProtoEncodeDecode(t *testing.T) { func TestProtoEncodeDecode(t *testing.T) {
@ -97,20 +117,42 @@ func TestProtoEncodeDecode(t *testing.T) {
Bytes: &wrapperspb.BytesValue{Value: []byte("123")}, Bytes: &wrapperspb.BytesValue{Value: []byte("123")},
} }
content, err := encoding.GetCodec(contentType).Marshal(in) content, err := encoding.GetCodec(contentType).Marshal(in)
require.NoError(t, err) if err != nil {
require.Equal(t, "a=19&age=18&b=true&bool=false&byte=MTIz&bytes=MTIz&count=3&d=22.22&double=12.33&duration="+ t.Errorf("expect %v, got %v", nil, err)
}
if !reflect.DeepEqual("a=19&age=18&b=true&bool=false&byte=MTIz&bytes=MTIz&count=3&d=22.22&double=12.33&duration="+
"2m0.000000022s&field=1%2C2&float=12.34&id=2233&int32=32&int64=64&map%5Bkratos%5D=https%3A%2F%2Fgo-kratos.dev%2F&"+ "2m0.000000022s&field=1%2C2&float=12.34&id=2233&int32=32&int64=64&map%5Bkratos%5D=https%3A%2F%2Fgo-kratos.dev%2F&"+
"numberOne=2233&price=11.23&sex=woman&simples=3344&simples=5566&string=go-kratos"+ "numberOne=2233&price=11.23&sex=woman&simples=3344&simples=5566&string=go-kratos"+
"&timestamp=1970-01-01T00%3A00%3A20.000000002Z&uint32=32&uint64=64&very_simple.component=5566", string(content)) "&timestamp=1970-01-01T00%3A00%3A20.000000002Z&uint32=32&uint64=64&very_simple.component=5566", string(content)) {
t.Errorf("rawpath is not equal to %v", string(content))
}
in2 := &complex.Complex{} in2 := &complex.Complex{}
err = encoding.GetCodec(contentType).Unmarshal(content, in2) err = encoding.GetCodec(contentType).Unmarshal(content, in2)
require.NoError(t, err) if err != nil {
require.Equal(t, int64(2233), in2.Id) t.Errorf("expect %v, got %v", nil, err)
require.Equal(t, "2233", in2.NoOne) }
require.NotEmpty(t, in2.Simple) if !reflect.DeepEqual(int64(2233), in2.Id) {
require.Equal(t, "5566", in2.Simple.Component) t.Errorf("expect %v, got %v", int64(2233), in2.Id)
require.NotEmpty(t, in2.Simples) }
require.Len(t, in2.Simples, 2) if !reflect.DeepEqual("2233", in2.NoOne) {
require.Equal(t, "3344", in2.Simples[0]) t.Errorf("expect %v, got %v", "2233", in2.NoOne)
require.Equal(t, "5566", in2.Simples[1]) }
if reflect.DeepEqual(in2.Simple, nil) {
t.Errorf("expect %v, got %v", nil, in2.Simple)
}
if !reflect.DeepEqual("5566", in2.Simple.Component) {
t.Errorf("expect %v, got %v", "5566", in2.Simple.Component)
}
if reflect.DeepEqual(in2.Simples, nil) {
t.Errorf("expect %v, got %v", nil, in2.Simples)
}
if !reflect.DeepEqual(len(in2.Simples), 2) {
t.Errorf("expect %v, got %v", 2, len(in2.Simples))
}
if !reflect.DeepEqual("3344", in2.Simples[0]) {
t.Errorf("expect %v, got %v", "3344", in2.Simples[0])
}
if !reflect.DeepEqual("5566", in2.Simples[1]) {
t.Errorf("expect %v, got %v", "5566", in2.Simples[1])
}
} }

@ -1,16 +1,17 @@
package proto package proto
import ( import (
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
testData "github.com/go-kratos/kratos/v2/internal/testdata/encoding" testData "github.com/go-kratos/kratos/v2/internal/testdata/encoding"
) )
func TestName(t *testing.T) { func TestName(t *testing.T) {
c := new(codec) c := new(codec)
assert.Equal(t, c.Name(), "proto") if !reflect.DeepEqual(c.Name(), "proto") {
t.Errorf("no expect float_key value: %v, but got: %v", c.Name(), "proto")
}
} }
func TestCodec(t *testing.T) { func TestCodec(t *testing.T) {
@ -23,14 +24,23 @@ func TestCodec(t *testing.T) {
} }
m, err := c.Marshal(&model) m, err := c.Marshal(&model)
assert.Nil(t, err) if err != nil {
t.Errorf("Marshal() should be nil, but got %s", err)
}
var res testData.TestModel var res testData.TestModel
err = c.Unmarshal(m, &res) err = c.Unmarshal(m, &res)
assert.Nil(t, err) if err != nil {
t.Errorf("Unmarshal() should be nil, but got %s", err)
assert.Equal(t, res.Id, model.Id) }
assert.Equal(t, res.Name, model.Name) if !reflect.DeepEqual(res.Id, model.Id) {
assert.Equal(t, res.Hobby, model.Hobby) t.Errorf("ID should be %d, but got %d", res.Id, model.Id)
}
if !reflect.DeepEqual(res.Name, model.Name) {
t.Errorf("Name should be %s, but got %s", res.Name, model.Name)
}
if !reflect.DeepEqual(res.Hobby, model.Hobby) {
t.Errorf("Hobby should be %s, but got %s", res.Hobby, model.Hobby)
}
} }

@ -4,9 +4,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/grpc/test/grpc_testing" "google.golang.org/grpc/test/grpc_testing"
@ -58,9 +58,13 @@ func TestError(t *testing.T) {
if se2.Code != http.StatusBadRequest { if se2.Code != http.StatusBadRequest {
t.Errorf("convert code err, got %d want %d", UnknownCode, http.StatusBadRequest) t.Errorf("convert code err, got %d want %d", UnknownCode, http.StatusBadRequest)
} }
assert.Nil(t, FromError(nil)) if FromError(nil) != nil {
t.Errorf("FromError(nil) should be nil")
}
e := FromError(errors.New("test")) e := FromError(errors.New("test"))
assert.Equal(t, e.Code, int32(UnknownCode)) if !reflect.DeepEqual(e.Code, int32(UnknownCode)) {
t.Errorf("no expect value: %v, but got: %v", e.Code, int32(UnknownCode))
}
} }
func TestIs(t *testing.T) { func TestIs(t *testing.T) {
@ -93,10 +97,20 @@ func TestIs(t *testing.T) {
} }
func TestOther(t *testing.T) { func TestOther(t *testing.T) {
assert.Equal(t, Code(nil), 200) if !reflect.DeepEqual(Code(nil), 200) {
assert.Equal(t, Code(errors.New("test")), UnknownCode) t.Errorf("Code(nil) = %v, want %v", Code(nil), 200)
assert.Equal(t, Reason(errors.New("test")), UnknownReason) }
if !reflect.DeepEqual(Code(errors.New("test")), UnknownCode) {
t.Errorf(`Code(errors.New("test")) = %v, want %v`, Code(nil), 200)
}
if !reflect.DeepEqual(Reason(errors.New("test")), UnknownReason) {
t.Errorf(`Reason(errors.New("test")) = %v, want %v`, Reason(nil), UnknownReason)
}
err := Errorf(10001, "test code 10001", "message") err := Errorf(10001, "test code 10001", "message")
assert.Equal(t, Code(err), 10001) if !reflect.DeepEqual(Code(err), 10001) {
assert.Equal(t, Reason(err), "test code 10001") t.Errorf(`Code(err) = %v, want %v`, Code(err), 10001)
}
if !reflect.DeepEqual(Reason(err), "test code 10001") {
t.Errorf(`Reason(err) = %v, want %v`, Reason(err), "test code 10001")
}
} }

@ -3,8 +3,6 @@ package errors
import ( import (
"fmt" "fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
type mockErr struct{} type mockErr struct{}
@ -16,8 +14,14 @@ func (*mockErr) Error() string {
func TestWarp(t *testing.T) { func TestWarp(t *testing.T) {
var err error = &mockErr{} var err error = &mockErr{}
err2 := fmt.Errorf("wrap %w", err) err2 := fmt.Errorf("wrap %w", err)
assert.Equal(t, err, Unwrap(err2)) if err != Unwrap(err2) {
assert.True(t, Is(err2, err)) t.Errorf("got %v want: %v", err, Unwrap(err2))
}
if !Is(err2, err) {
t.Errorf("Is(err2, err) got %v want: %v", Is(err2, err), true)
}
err3 := &mockErr{} err3 := &mockErr{}
assert.True(t, As(err2, &err3)) if !As(err2, &err3) {
t.Errorf("As(err2, &err3) got %v want: %v", As(err2, &err3), true)
}
} }

@ -34,7 +34,6 @@ require (
github.com/segmentio/kafka-go v0.4.17 github.com/segmentio/kafka-go v0.4.17
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/soheilhy/cmux v0.1.4 github.com/soheilhy/cmux v0.1.4
github.com/stretchr/testify v1.7.0
go.etcd.io/etcd/client/v3 v3.5.0 go.etcd.io/etcd/client/v3 v3.5.0
go.opentelemetry.io/otel v1.3.0 go.opentelemetry.io/otel v1.3.0
go.opentelemetry.io/otel/exporters/jaeger v1.3.0 go.opentelemetry.io/otel/exporters/jaeger v1.3.0

@ -1055,7 +1055,6 @@ google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG
google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM=
google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE=
google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34=
google.golang.org/grpc v1.42.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU=
google.golang.org/grpc v1.43.0 h1:Eeu7bZtDZ2DpRCsLhUlcrLnvYaMK1Gz86a+hMVvELmM= google.golang.org/grpc v1.43.0 h1:Eeu7bZtDZ2DpRCsLhUlcrLnvYaMK1Gz86a+hMVvELmM=
google.golang.org/grpc v1.43.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= google.golang.org/grpc v1.43.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU=
google.golang.org/grpc/examples v0.0.0-20220105183818-2fb1ac854b20 h1:E/V/xnVzrdcIgW1yYPuJONZnNKHb6OV2Jlj6lSTlGXQ= google.golang.org/grpc/examples v0.0.0-20220105183818-2fb1ac854b20 h1:E/V/xnVzrdcIgW1yYPuJONZnNKHb6OV2Jlj6lSTlGXQ=

@ -7,7 +7,6 @@ import (
"github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/log"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
) )
func TestLoggerLog(t *testing.T) { func TestLoggerLog(t *testing.T) {
@ -46,8 +45,9 @@ func TestLoggerLog(t *testing.T) {
output := new(bytes.Buffer) output := new(bytes.Buffer)
logger := NewLogrusLogger(Level(test.level), Formatter(test.formatter), Output(output)) logger := NewLogrusLogger(Level(test.level), Formatter(test.formatter), Output(output))
_ = logger.Log(test.logLevel, test.kvs...) _ = logger.Log(test.logLevel, test.kvs...)
if !strings.HasPrefix(output.String(), test.want) {
assert.True(t, strings.HasPrefix(output.String(), test.want)) t.Errorf("strings.HasPrefix(output.String(), test.want) got %v want: %v", strings.HasPrefix(output.String(), test.want), true)
}
}) })
} }
} }

@ -9,7 +9,6 @@ import (
"github.com/go-kratos/kratos/examples/helloworld/helloworld" "github.com/go-kratos/kratos/examples/helloworld/helloworld"
pb "github.com/go-kratos/kratos/examples/helloworld/helloworld" pb "github.com/go-kratos/kratos/examples/helloworld/helloworld"
"github.com/stretchr/testify/assert"
consulregistry "github.com/go-kratos/kratos/contrib/registry/consul/v2" consulregistry "github.com/go-kratos/kratos/contrib/registry/consul/v2"
etcdregistry "github.com/go-kratos/kratos/contrib/registry/etcd/v2" etcdregistry "github.com/go-kratos/kratos/contrib/registry/etcd/v2"
@ -106,7 +105,9 @@ func TestETCD(t *testing.T) {
} }
callHTTP(t, r) callHTTP(t, r)
callGRPC(t, r) callGRPC(t, r)
assert.NoError(t, srv.Stop()) if srv.Stop() != nil {
t.Errorf("srv.Stop() got error: %v", err)
}
} }
func TestConsul(t *testing.T) { func TestConsul(t *testing.T) {
@ -121,5 +122,8 @@ func TestConsul(t *testing.T) {
} }
callHTTP(t, r) callHTTP(t, r)
callGRPC(t, r) callGRPC(t, r)
assert.NoError(t, srv.Stop())
if srv.Stop() != nil {
t.Errorf("srv.Stop() got error: %v", err)
}
} }

@ -10,13 +10,11 @@ import (
"time" "time"
etcdregistry "github.com/go-kratos/kratos/contrib/registry/etcd/v2" etcdregistry "github.com/go-kratos/kratos/contrib/registry/etcd/v2"
"github.com/go-kratos/kratos/examples/helloworld/helloworld"
pb "github.com/go-kratos/kratos/examples/helloworld/helloworld" pb "github.com/go-kratos/kratos/examples/helloworld/helloworld"
"github.com/go-kratos/kratos/v2" "github.com/go-kratos/kratos/v2"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/transport/grpc" "github.com/go-kratos/kratos/v2/transport/grpc"
"github.com/go-kratos/kratos/v2/transport/http" "github.com/go-kratos/kratos/v2/transport/http"
"github.com/stretchr/testify/assert"
etcd "go.etcd.io/etcd/client/v3" etcd "go.etcd.io/etcd/client/v3"
) )
@ -65,8 +63,8 @@ func callGRPC(t *testing.T, r registry.Discovery, c *tls.Config) {
t.Fatal(err) t.Fatal(err)
} }
defer conn.Close() defer conn.Close()
client := helloworld.NewGreeterClient(conn) client := pb.NewGreeterClient(conn)
reply, err := client.SayHello(context.Background(), &helloworld.HelloRequest{Name: "kratos"}) reply, err := client.SayHello(context.Background(), &pb.HelloRequest{Name: "kratos"})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -85,8 +83,8 @@ func callHTTP(t *testing.T, r registry.Discovery, c *tls.Config) {
t.Fatal(err) t.Fatal(err)
} }
defer conn.Close() defer conn.Close()
client := helloworld.NewGreeterHTTPClient(conn) client := pb.NewGreeterHTTPClient(conn)
reply, err := client.SayHello(context.Background(), &helloworld.HelloRequest{Name: "kratos"}) reply, err := client.SayHello(context.Background(), &pb.HelloRequest{Name: "kratos"})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -128,6 +126,10 @@ func TestETCD(t *testing.T) {
} }
callHTTP(t, r, tlsConf) callHTTP(t, r, tlsConf)
callGRPC(t, r, tlsConf) callGRPC(t, r, tlsConf)
assert.NoError(t, srv.Stop()) if srv.Stop() != nil {
assert.NoError(t, srvTLS.Stop()) t.Errorf("srv.Stop() got error: %v", err)
}
if srvTLS.Stop() != nil {
t.Errorf("srv.Stop() got error: %v", err)
}
} }

@ -10,7 +10,6 @@ require (
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/gorilla/mux v1.8.0 github.com/gorilla/mux v1.8.0
github.com/imdario/mergo v0.3.12 github.com/imdario/mergo v0.3.12
github.com/stretchr/testify v1.7.0
go.opentelemetry.io/otel v1.3.0 go.opentelemetry.io/otel v1.3.0
go.opentelemetry.io/otel/sdk v1.3.0 go.opentelemetry.io/otel/sdk v1.3.0
go.opentelemetry.io/otel/trace v1.3.0 go.opentelemetry.io/otel/trace v1.3.0

@ -2,11 +2,10 @@ package context
import ( import (
"context" "context"
"errors"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func TestContext(t *testing.T) { func TestContext(t *testing.T) {
@ -20,13 +19,21 @@ func TestContext(t *testing.T) {
got := ctx.Value(ctxKey1{}) got := ctx.Value(ctxKey1{})
value1, ok := got.(string) value1, ok := got.(string)
assert.Equal(t, ok, true) if !ok {
assert.Equal(t, value1, "https://github.com/go-kratos/") t.Errorf("expect %v, got %v", true, ok)
}
if !reflect.DeepEqual(value1, "https://github.com/go-kratos/") {
t.Errorf("expect %v, got %v", "https://github.com/go-kratos/", value1)
}
got2 := ctx.Value(ctxKey2{}) got2 := ctx.Value(ctxKey2{})
value2, ok := got2.(string) value2, ok := got2.(string)
assert.Equal(t, ok, true) if !ok {
assert.Equal(t, value2, "https://go-kratos.dev/") t.Errorf("expect %v, got %v", true, ok)
}
if !reflect.DeepEqual("https://go-kratos.dev/", value2) {
t.Errorf("expect %v, got %v", "https://go-kratos.dev/", value2)
}
t.Log(value1) t.Log(value1)
t.Log(value2) t.Log(value2)
@ -45,13 +52,21 @@ func TestMerge(t *testing.T) {
got := ctx.Value(ctxKey1{}) got := ctx.Value(ctxKey1{})
value1, ok := got.(string) value1, ok := got.(string)
assert.Equal(t, ok, true) if !ok {
assert.Equal(t, value1, "https://github.com/go-kratos/") t.Errorf("expect %v, got %v", true, ok)
}
if !reflect.DeepEqual(value1, "https://github.com/go-kratos/") {
t.Errorf("expect %v, got %v", "https://github.com/go-kratos/", value1)
}
got2 := ctx.Value(ctxKey2{}) got2 := ctx.Value(ctxKey2{})
value2, ok := got2.(string) value2, ok := got2.(string)
assert.Equal(t, ok, true) if !ok {
assert.Equal(t, value2, "https://go-kratos.dev/") t.Errorf("expect %v, got %v", true, ok)
}
if !reflect.DeepEqual(value2, "https://go-kratos.dev/") {
t.Errorf("expect %v, got %v", " https://go-kratos.dev/", value2)
}
t.Log(ctx) t.Log(ctx)
} }
@ -63,8 +78,9 @@ func TestErr(t *testing.T) {
ctx, cancel := Merge(ctx1, context.Background()) ctx, cancel := Merge(ctx1, context.Background())
defer cancel() defer cancel()
if !errors.Is(ctx.Err(), context.DeadlineExceeded) {
assert.Equal(t, ctx.Err(), context.DeadlineExceeded) t.Errorf("expect %v, got %v", context.DeadlineExceeded, ctx.Err())
}
} }
func TestDone(t *testing.T) { func TestDone(t *testing.T) {
@ -77,7 +93,9 @@ func TestDone(t *testing.T) {
cancel() cancel()
}() }()
assert.Equal(t, <-ctx.Done(), struct{}{}) if <-ctx.Done() != struct{}{} {
t.Errorf("expect %v, got %v", struct{}{}, <-ctx.Done())
}
} }
func TestFinish(t *testing.T) { func TestFinish(t *testing.T) {
@ -88,9 +106,15 @@ func TestFinish(t *testing.T) {
cancelCh: make(chan struct{}), cancelCh: make(chan struct{}),
} }
err := mc.finish(context.DeadlineExceeded) err := mc.finish(context.DeadlineExceeded)
assert.Equal(t, err, context.DeadlineExceeded) if !errors.Is(err, context.DeadlineExceeded) {
assert.Equal(t, mc.doneMark, uint32(1)) t.Errorf("expect %v, got %v", context.DeadlineExceeded, err)
assert.Equal(t, <-mc.done, struct{}{}) }
if !reflect.DeepEqual(mc.doneMark, uint32(1)) {
t.Errorf("expect %v, got %v", 1, mc.doneMark)
}
if <-mc.done != struct{}{} {
t.Errorf("expect %v, got %v", struct{}{}, <-mc.done)
}
} }
func TestWait(t *testing.T) { func TestWait(t *testing.T) {
@ -109,7 +133,9 @@ func TestWait(t *testing.T) {
mc.wait() mc.wait()
t.Log(mc.doneErr) t.Log(mc.doneErr)
assert.Equal(t, mc.doneErr, context.Canceled) if !errors.Is(mc.doneErr, context.Canceled) {
t.Errorf("expect %v, got %v", context.Canceled, mc.doneErr)
}
ctx2, cancel2 := context.WithCancel(context.Background()) ctx2, cancel2 := context.WithCancel(context.Background())
@ -126,7 +152,9 @@ func TestWait(t *testing.T) {
mc.wait() mc.wait()
t.Log(mc.doneErr) t.Log(mc.doneErr)
assert.Equal(t, mc.doneErr, context.Canceled) if !errors.Is(mc.doneErr, context.Canceled) {
t.Errorf("expect %v, got %v", context.Canceled, mc.doneErr)
}
} }
func TestCancel(t *testing.T) { func TestCancel(t *testing.T) {
@ -137,8 +165,9 @@ func TestCancel(t *testing.T) {
cancelCh: make(chan struct{}), cancelCh: make(chan struct{}),
} }
mc.cancel() mc.cancel()
if <-mc.cancelCh != struct{}{} {
assert.Equal(t, <-mc.cancelCh, struct{}{}) t.Errorf("expect %v, got %v", struct{}{}, <-mc.cancelCh)
}
} }
func Test_mergeCtx_Deadline(t *testing.T) { func Test_mergeCtx_Deadline(t *testing.T) {
@ -215,7 +244,9 @@ func Test_Err2(t *testing.T) {
ctx, cancel := Merge(ctx1, context.Background()) ctx, cancel := Merge(ctx1, context.Background())
defer cancel() defer cancel()
assert.Equal(t, ctx.Err(), nil) if ctx.Err() != nil {
t.Errorf("expect %v, got %v", nil, ctx.Err())
}
ctx1, cancel1 := context.WithCancel(context.Background()) ctx1, cancel1 := context.WithCancel(context.Background())
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
@ -225,7 +256,9 @@ func Test_Err2(t *testing.T) {
cancel1() cancel1()
assert.Equal(t, ctx.Err(), context.Canceled) if !errors.Is(ctx.Err(), context.Canceled) {
t.Errorf("expect %v, got %v", context.Canceled, ctx.Err())
}
ctx1, cancel1 = context.WithCancel(context.Background()) ctx1, cancel1 = context.WithCancel(context.Background())
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
@ -235,9 +268,13 @@ func Test_Err2(t *testing.T) {
cancel1() cancel1()
assert.Equal(t, ctx.Err(), context.Canceled) if !errors.Is(ctx.Err(), context.Canceled) {
t.Errorf("expect %v, got %v", context.Canceled, ctx.Err())
}
ctx, cancel = Merge(context.Background(), context.Background()) ctx, cancel = Merge(context.Background(), context.Background())
cancel() cancel()
assert.Equal(t, ctx.Err(), context.Canceled) if !errors.Is(ctx.Err(), context.Canceled) {
t.Errorf("expect %v, got %v", context.Canceled, ctx.Err())
}
} }

@ -2,9 +2,8 @@ package host
import ( import (
"net" "net"
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestValidIP(t *testing.T) { func TestValidIP(t *testing.T) {
@ -77,25 +76,35 @@ func TestExtract(t *testing.T) {
}) })
} }
lis, err := net.Listen("tcp", ":12345") lis, err := net.Listen("tcp", ":12345")
assert.NoError(t, err) if err != nil {
t.Errorf("expected: %v got %v", nil, err)
}
res, err := Extract("", lis) res, err := Extract("", lis)
assert.NoError(t, err) if err != nil {
t.Errorf("expected: %v got %v", nil, err)
}
expect, err := Extract(lis.Addr().String(), nil) expect, err := Extract(lis.Addr().String(), nil)
assert.NoError(t, err) if err != nil {
assert.Equal(t, expect, res) t.Errorf("expected: %v got %v", nil, err)
}
if !reflect.DeepEqual(res, expect) {
t.Errorf("expected %s got %s", expect, res)
}
} }
func TestExtract2(t *testing.T) { func TestExtract2(t *testing.T) {
addr := "localhost:9001" addr := "localhost:9001"
lis, err := net.Listen("tcp", addr) lis, err := net.Listen("tcp", addr)
if err == nil { if err != nil {
assert.Nil(t, err) t.Errorf("expected: %v got %v", nil, err)
} }
res, err := Extract(addr, lis) res, err := Extract(addr, lis)
if err == nil { if err != nil {
assert.Nil(t, err) t.Errorf("expected: %v got %v", nil, err)
}
if !reflect.DeepEqual(res, "localhost:9001") {
t.Errorf("expected %s got %s", "localhost:9001", res)
} }
assert.Equal(t, res, "localhost:9001")
} }
func TestPort(t *testing.T) { func TestPort(t *testing.T) {

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"reflect"
"testing" "testing"
"time" "time"
@ -12,7 +13,6 @@ import (
"github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport"
"github.com/stretchr/testify/assert"
) )
type headerCarrier http.Header type headerCarrier http.Header
@ -145,11 +145,17 @@ func TestServer(t *testing.T) {
})(next) })(next)
} }
_, err2 := server(test.ctx, test.name) _, err2 := server(test.ctx, test.name)
assert.Equal(t, test.exceptErr, err2) if !errors.Is(test.exceptErr, err2) {
t.Errorf("except error %v, but got %v", test.exceptErr, err2)
}
if test.exceptErr == nil { if test.exceptErr == nil {
assert.NotNil(t, testToken) if testToken == nil {
t.Errorf("except testToken not nil, but got nil")
}
_, ok := testToken.(jwt.MapClaims) _, ok := testToken.(jwt.MapClaims)
assert.True(t, ok) if !ok {
t.Errorf("except testToken is jwt.MapClaims, but got %T", testToken)
}
} }
}) })
} }
@ -189,9 +195,13 @@ func TestClient(t *testing.T) {
handler := Client(test.tokenProvider)(next) handler := Client(test.tokenProvider)(next)
header := &headerCarrier{} header := &headerCarrier{}
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") _, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
assert.Equal(t, test.expectError, err2) if !errors.Is(test.expectError, err2) {
t.Errorf("except error %v, but got %v", test.expectError, err2)
}
if err2 == nil { if err2 == nil {
assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) if !reflect.DeepEqual(header.Get(authorizationKey), fmt.Sprintf(bearerFormat, token)) {
t.Errorf("except header %s, but got %s", fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey))
}
} }
}) })
} }
@ -217,7 +227,9 @@ func TestTokenExpire(t *testing.T) {
return []byte(testKey), nil return []byte(testKey), nil
}, WithSigningMethod(jwt.SigningMethodHS256))(next) }, WithSigningMethod(jwt.SigningMethodHS256))(next)
_, err2 := server(ctx, "test expire token") _, err2 := server(ctx, "test expire token")
assert.Equal(t, ErrTokenExpired, err2) if !errors.Is(ErrTokenExpired, err2) {
t.Errorf("except error %v, but got %v", ErrTokenExpired, err2)
}
} }
func TestMissingKeyFunc(t *testing.T) { func TestMissingKeyFunc(t *testing.T) {
@ -252,9 +264,13 @@ func TestMissingKeyFunc(t *testing.T) {
} }
server := Server(nil)(next) server := Server(nil)(next)
_, err2 := server(test.ctx, test.name) _, err2 := server(test.ctx, test.name)
assert.Equal(t, test.exceptErr, err2) if !errors.Is(test.exceptErr, err2) {
t.Errorf("except error %v, but got %v", test.exceptErr, err2)
}
if test.exceptErr == nil { if test.exceptErr == nil {
assert.NotNil(t, testToken) if testToken == nil {
t.Errorf("except testToken not nil, but got nil")
}
} }
} }
@ -287,9 +303,13 @@ func TestClientWithClaims(t *testing.T) {
handler := Client(test.tokenProvider, WithClaims(mapClaims))(next) handler := Client(test.tokenProvider, WithClaims(mapClaims))(next)
header := &headerCarrier{} header := &headerCarrier{}
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") _, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
assert.Equal(t, test.expectError, err2) if !errors.Is(test.expectError, err2) {
t.Errorf("except error %v, but got %v", test.expectError, err2)
}
if err2 == nil { if err2 == nil {
assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) if !reflect.DeepEqual(header.Get(authorizationKey), fmt.Sprintf(bearerFormat, token)) {
t.Errorf("except header %s, but got %s", fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey))
}
} }
}) })
} }
@ -318,8 +338,12 @@ func TestClientWithHeader(t *testing.T) {
handler := Client(tProvider, WithClaims(mapClaims), WithTokenHeader(tokenHeader))(next) handler := Client(tProvider, WithClaims(mapClaims), WithTokenHeader(tokenHeader))(next)
header := &headerCarrier{} header := &headerCarrier{}
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") _, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
assert.Equal(t, nil, err2) if err2 != nil {
assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) t.Errorf("except error nil, but got %v", err2)
}
if !reflect.DeepEqual(header.Get(authorizationKey), fmt.Sprintf(bearerFormat, token)) {
t.Errorf("except header %s, but got %s", fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey))
}
} }
func TestClientMissKey(t *testing.T) { func TestClientMissKey(t *testing.T) {
@ -351,9 +375,13 @@ func TestClientMissKey(t *testing.T) {
handler := Client(test.tokenProvider, WithClaims(mapClaims))(next) handler := Client(test.tokenProvider, WithClaims(mapClaims))(next)
header := &headerCarrier{} header := &headerCarrier{}
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") _, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
assert.Equal(t, test.expectError, err2) if !errors.Is(test.expectError, err2) {
t.Errorf("except error %v, but got %v", test.expectError, err2)
}
if err2 == nil { if err2 == nil {
assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) if !reflect.DeepEqual(header.Get(authorizationKey), fmt.Sprintf(bearerFormat, token)) {
t.Errorf("except header %s, but got %s", fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey))
}
} }
}) })
} }

@ -3,8 +3,6 @@ package metrics
import ( import (
"context" "context"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestMetrics(t *testing.T) { func TestMetrics(t *testing.T) {
@ -12,8 +10,12 @@ func TestMetrics(t *testing.T) {
return req.(string) + "https://go-kratos.dev", nil return req.(string) + "https://go-kratos.dev", nil
} }
_, err := Server()(next)(context.Background(), "test:") _, err := Server()(next)(context.Background(), "test:")
assert.Equal(t, err, nil) if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}
_, err = Client()(next)(context.Background(), "test:") _, err = Client()(next)(context.Background(), "test:")
assert.Equal(t, err, nil) if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}
} }

@ -3,9 +3,8 @@ package middleware
import ( import (
"context" "context"
"fmt" "fmt"
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
var i int var i int
@ -18,9 +17,15 @@ func TestChain(t *testing.T) {
} }
got, err := Chain(test1Middleware, test2Middleware, test3Middleware)(next)(context.Background(), "hello kratos!") got, err := Chain(test1Middleware, test2Middleware, test3Middleware)(next)(context.Background(), "hello kratos!")
assert.Nil(t, err) if err != nil {
assert.Equal(t, got, "reply") t.Errorf("expect %v, got %v", nil, err)
assert.Equal(t, i, 16) }
if !reflect.DeepEqual(got, "reply") {
t.Errorf("expect %v, got %v", "reply", got)
}
if !reflect.DeepEqual(i, 16) {
t.Errorf("expect %v, got %v", 16, i)
}
} }
func test1Middleware(handler Handler) Handler { func test1Middleware(handler Handler) Handler {

@ -3,12 +3,12 @@ package selector
import ( import (
"context" "context"
"fmt" "fmt"
"reflect"
"strings" "strings"
"testing" "testing"
"github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport"
"github.com/stretchr/testify/assert"
) )
var _ transport.Transporter = &Transport{} var _ transport.Transporter = &Transport{}
@ -167,8 +167,12 @@ func TestFunc(t *testing.T) {
return false return false
}).Build()(next) }).Build()(next)
reply, err := next(test.ctx, test.name) reply, err := next(test.ctx, test.name)
assert.Equal(t, reply, "reply") if err != nil {
assert.Nil(t, err) t.Errorf("expect error is nil, but got %v", err)
}
if !reflect.DeepEqual(reply, "reply") {
t.Errorf("expect reply is reply,but got %v", reply)
}
}) })
} }
} }
@ -227,8 +231,12 @@ func TestHeaderFunc(t *testing.T) {
return false return false
}).Build()(next) }).Build()(next)
reply, err := next(test.ctx, test.name) reply, err := next(test.ctx, test.name)
assert.Equal(t, reply, "reply") if err != nil {
assert.Nil(t, err) t.Errorf("expect error is nil, but got %v", err)
}
if !reflect.DeepEqual(reply, "reply") {
t.Errorf("expect reply is reply,but got %v", reply)
}
}) })
} }
} }

@ -2,12 +2,12 @@ package tracing
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"github.com/go-kratos/kratos/v2" "github.com/go-kratos/kratos/v2"
"github.com/go-kratos/kratos/v2/metadata" "github.com/go-kratos/kratos/v2/metadata"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/propagation"
) )
@ -77,13 +77,19 @@ func TestMetadata_Extract(t *testing.T) {
b := Metadata{} b := Metadata{}
ctx := b.Extract(tt.args.parent, tt.args.carrier) ctx := b.Extract(tt.args.parent, tt.args.carrier)
md, ok := metadata.FromServerContext(ctx) md, ok := metadata.FromServerContext(ctx)
assert.Equal(t, ok, true) if !ok {
assert.Equal(t, md.Get(serviceHeader), tt.want) t.Errorf("expect %v, got %v", true, ok)
}
if !reflect.DeepEqual(md.Get(serviceHeader), tt.want) {
t.Errorf("expect %v, got %v", tt.want, md.Get(serviceHeader))
}
}) })
} }
} }
func TestFields(t *testing.T) { func TestFields(t *testing.T) {
b := Metadata{} b := Metadata{}
assert.Equal(t, b.Fields(), []string{"x-md-service-name"}) if !reflect.DeepEqual(b.Fields(), []string{"x-md-service-name"}) {
t.Errorf("expect %v, got %v", []string{"x-md-service-name"}, b.Fields())
}
} }

@ -4,11 +4,11 @@ import (
"context" "context"
"net/http" "net/http"
"os" "os"
"reflect"
"testing" "testing"
"github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/propagation"
tracesdk "go.opentelemetry.io/otel/sdk/trace" tracesdk "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace"
@ -127,19 +127,33 @@ func TestServer(t *testing.T) {
)(next)(ctx, "test server: ") )(next)(ctx, "test server: ")
span.End() span.End()
assert.NoError(t, err) if err != nil {
assert.NotEmpty(t, childSpanID) t.Errorf("expected nil, got %v", err)
assert.NotEqual(t, span.SpanContext().SpanID().String(), childSpanID) }
assert.Equal(t, span.SpanContext().TraceID().String(), childTraceID) if childSpanID == "" {
t.Errorf("expected empty, got %v", childSpanID)
}
if reflect.DeepEqual(span.SpanContext().SpanID().String(), childSpanID) {
t.Errorf("span.SpanContext().SpanID().String()(%v) is not equal to childSpanID(%v)", span.SpanContext().SpanID().String(), childSpanID)
}
if !reflect.DeepEqual(span.SpanContext().TraceID().String(), childTraceID) {
t.Errorf("expected %v, got %v", childTraceID, span.SpanContext().TraceID().String())
}
_, err = Server( _, err = Server(
WithTracerProvider(tracesdk.NewTracerProvider()), WithTracerProvider(tracesdk.NewTracerProvider()),
WithPropagator(propagation.NewCompositeTextMapPropagator(propagation.Baggage{}, propagation.TraceContext{})), WithPropagator(propagation.NewCompositeTextMapPropagator(propagation.Baggage{}, propagation.TraceContext{})),
)(next)(context.Background(), "test server: ") )(next)(context.Background(), "test server: ")
assert.NoError(t, err) if err != nil {
assert.Empty(t, childSpanID) t.Errorf("expected error, got nil")
assert.Empty(t, childTraceID) }
if childSpanID != "" {
t.Errorf("expected empty, got %v", childSpanID)
}
if childTraceID != "" {
t.Errorf("expected empty, got %v", childTraceID)
}
} }
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
@ -185,8 +199,16 @@ func TestClient(t *testing.T) {
)(next)(ctx, "test client: ") )(next)(ctx, "test client: ")
span.End() span.End()
assert.NoError(t, err) if err != nil {
assert.NotEmpty(t, childSpanID) t.Errorf("expected nil, got %v", err)
assert.NotEqual(t, span.SpanContext().SpanID().String(), childSpanID) }
assert.Equal(t, span.SpanContext().TraceID().String(), childTraceID) if childSpanID == "" {
t.Errorf("expected empty, got %v", childSpanID)
}
if reflect.DeepEqual(span.SpanContext().SpanID().String(), childSpanID) {
t.Errorf("span.SpanContext().SpanID().String()(%v) is not equal to childSpanID(%v)", span.SpanContext().SpanID().String(), childSpanID)
}
if !reflect.DeepEqual(span.SpanContext().TraceID().String(), childTraceID) {
t.Errorf("expected %v, got %v", childTraceID, span.SpanContext().TraceID().String())
}
} }

@ -5,34 +5,40 @@ import (
"log" "log"
"net/url" "net/url"
"os" "os"
"reflect"
"testing" "testing"
"time" "time"
xlog "github.com/go-kratos/kratos/v2/log" xlog "github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport"
"github.com/stretchr/testify/assert"
) )
func TestID(t *testing.T) { func TestID(t *testing.T) {
o := &options{} o := &options{}
v := "123" v := "123"
ID(v)(o) ID(v)(o)
assert.Equal(t, v, o.id) if !reflect.DeepEqual(v, o.id) {
t.Fatalf("o.id:%s is not equal to v:%s", o.id, v)
}
} }
func TestName(t *testing.T) { func TestName(t *testing.T) {
o := &options{} o := &options{}
v := "abc" v := "abc"
Name(v)(o) Name(v)(o)
assert.Equal(t, v, o.name) if !reflect.DeepEqual(v, o.name) {
t.Fatalf("o.name:%s is not equal to v:%s", o.name, v)
}
} }
func TestVersion(t *testing.T) { func TestVersion(t *testing.T) {
o := &options{} o := &options{}
v := "123" v := "123"
Version(v)(o) Version(v)(o)
assert.Equal(t, v, o.version) if !reflect.DeepEqual(v, o.version) {
t.Fatalf("o.version:%s is not equal to v:%s", o.version, v)
}
} }
func TestMetadata(t *testing.T) { func TestMetadata(t *testing.T) {
@ -42,7 +48,9 @@ func TestMetadata(t *testing.T) {
"b": "2", "b": "2",
} }
Metadata(v)(o) Metadata(v)(o)
assert.Equal(t, v, o.metadata) if !reflect.DeepEqual(v, o.metadata) {
t.Fatalf("o.metadata:%s is not equal to v:%s", o.metadata, v)
}
} }
func TestEndpoint(t *testing.T) { func TestEndpoint(t *testing.T) {
@ -52,7 +60,9 @@ func TestEndpoint(t *testing.T) {
{Host: "foo.com"}, {Host: "foo.com"},
} }
Endpoint(v...)(o) Endpoint(v...)(o)
assert.Equal(t, v, o.endpoints) if !reflect.DeepEqual(v, o.endpoints) {
t.Fatalf("o.endpoints:%s is not equal to v:%s", o.endpoints, v)
}
} }
func TestContext(t *testing.T) { func TestContext(t *testing.T) {
@ -60,14 +70,18 @@ func TestContext(t *testing.T) {
o := &options{} o := &options{}
v := context.WithValue(context.TODO(), ctxKey{}, "b") v := context.WithValue(context.TODO(), ctxKey{}, "b")
Context(v)(o) Context(v)(o)
assert.Equal(t, v, o.ctx) if !reflect.DeepEqual(v, o.ctx) {
t.Fatalf("o.ctx:%s is not equal to v:%s", o.ctx, v)
}
} }
func TestLogger(t *testing.T) { func TestLogger(t *testing.T) {
o := &options{} o := &options{}
v := xlog.NewStdLogger(log.Writer()) v := xlog.NewStdLogger(log.Writer())
Logger(v)(o) Logger(v)(o)
assert.Equal(t, xlog.NewHelper(v), o.logger) if !reflect.DeepEqual(xlog.NewHelper(v), o.logger) {
t.Fatalf("o.logger:%s is not equal to xlog.NewHelper(v):%s", o.logger, xlog.NewHelper(v))
}
} }
type mockServer struct{} type mockServer struct{}
@ -81,7 +95,9 @@ func TestServer(t *testing.T) {
&mockServer{}, &mockServer{}, &mockServer{}, &mockServer{},
} }
Server(v...)(o) Server(v...)(o)
assert.Equal(t, v, o.servers) if !reflect.DeepEqual(v, o.servers) {
t.Fatalf("o.servers:%s is not equal to xlog.NewHelper(v):%s", o.servers, v)
}
} }
type mockSignal struct{} type mockSignal struct{}
@ -95,7 +111,9 @@ func TestSignal(t *testing.T) {
&mockSignal{}, &mockSignal{}, &mockSignal{}, &mockSignal{},
} }
Signal(v...)(o) Signal(v...)(o)
assert.Equal(t, v, o.sigs) if !reflect.DeepEqual(v, o.sigs) {
t.Fatal("o.sigs is not equal to v")
}
} }
type mockRegistrar struct{} type mockRegistrar struct{}
@ -112,12 +130,16 @@ func TestRegistrar(t *testing.T) {
o := &options{} o := &options{}
v := &mockRegistrar{} v := &mockRegistrar{}
Registrar(v)(o) Registrar(v)(o)
assert.Equal(t, v, o.registrar) if !reflect.DeepEqual(v, o.registrar) {
t.Fatal("o.registrar is not equal to v")
}
} }
func TestRegistrarTimeout(t *testing.T) { func TestRegistrarTimeout(t *testing.T) {
o := &options{} o := &options{}
v := time.Duration(123) v := time.Duration(123)
RegistrarTimeout(v)(o) RegistrarTimeout(v)(o)
assert.Equal(t, v, o.registrarTimeout) if !reflect.DeepEqual(v, o.registrarTimeout) {
t.Fatal("o.registrarTimeout is not equal to v")
}
} }

@ -2,11 +2,11 @@ package filter
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector"
"github.com/stretchr/testify/assert"
) )
func TestVersion(t *testing.T) { func TestVersion(t *testing.T) {
@ -31,6 +31,10 @@ func TestVersion(t *testing.T) {
})) }))
nodes = f(context.Background(), nodes) nodes = f(context.Background(), nodes)
assert.Equal(t, 1, len(nodes)) if !reflect.DeepEqual(len(nodes), 1) {
assert.Equal(t, "127.0.0.2:9090", nodes[0].Address()) t.Errorf("expect %v, got %v", 1, len(nodes))
}
if !reflect.DeepEqual(nodes[0].Address(), "127.0.0.2:9090") {
t.Errorf("expect %v, got %v", nodes[0].Address(), "127.0.0.2:9090")
}
} }

@ -2,13 +2,12 @@ package direct
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"time" "time"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector"
"github.com/stretchr/testify/assert"
) )
func TestDirect(t *testing.T) { func TestDirect(t *testing.T) {
@ -24,12 +23,20 @@ func TestDirect(t *testing.T) {
})) }))
done := wn.Pick() done := wn.Pick()
assert.NotNil(t, done) if done == nil {
t.Errorf("expect %v, got %v", nil, done)
}
time.Sleep(time.Millisecond * 10) time.Sleep(time.Millisecond * 10)
done(context.Background(), selector.DoneInfo{}) done(context.Background(), selector.DoneInfo{})
assert.Equal(t, float64(10), wn.Weight()) if !reflect.DeepEqual(float64(10), wn.Weight()) {
assert.Greater(t, time.Millisecond*15, wn.PickElapsed()) t.Errorf("expect %v, got %v", float64(10), wn.Weight())
assert.Less(t, time.Millisecond*5, wn.PickElapsed()) }
if time.Millisecond*15 <= wn.PickElapsed() {
t.Errorf("time.Millisecond*15 <= wn.PickElapsed()(%s)", wn.PickElapsed())
}
if time.Millisecond*5 >= wn.PickElapsed() {
t.Errorf("time.Millisecond*5 >= wn.PickElapsed()(%s)", wn.PickElapsed())
}
} }
func TestDirectDefaultWeight(t *testing.T) { func TestDirectDefaultWeight(t *testing.T) {
@ -44,10 +51,18 @@ func TestDirectDefaultWeight(t *testing.T) {
})) }))
done := wn.Pick() done := wn.Pick()
assert.NotNil(t, done) if done == nil {
t.Errorf("expect %v, got %v", nil, done)
}
time.Sleep(time.Millisecond * 10) time.Sleep(time.Millisecond * 10)
done(context.Background(), selector.DoneInfo{}) done(context.Background(), selector.DoneInfo{})
assert.Equal(t, float64(100), wn.Weight()) if !reflect.DeepEqual(float64(100), wn.Weight()) {
assert.Greater(t, time.Millisecond*20, wn.PickElapsed()) t.Errorf("expect %v, got %v", float64(100), wn.Weight())
assert.Less(t, time.Millisecond*5, wn.PickElapsed()) }
if time.Millisecond*20 <= wn.PickElapsed() {
t.Errorf("time.Millisecond*20 <= wn.PickElapsed()(%s)", wn.PickElapsed())
}
if time.Millisecond*5 >= wn.PickElapsed() {
t.Errorf("time.Millisecond*5 >= wn.PickElapsed()(%s)", wn.PickElapsed())
}
} }

@ -2,12 +2,12 @@ package ewma
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"time" "time"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector"
"github.com/stretchr/testify/assert"
) )
func TestDirect(t *testing.T) { func TestDirect(t *testing.T) {
@ -22,19 +22,32 @@ func TestDirect(t *testing.T) {
Metadata: map[string]string{"weight": "10"}, Metadata: map[string]string{"weight": "10"},
})) }))
assert.Equal(t, float64(100), wn.Weight()) if !reflect.DeepEqual(float64(100), wn.Weight()) {
t.Errorf("expect %v, got %v", 100, wn.Weight())
}
done := wn.Pick() done := wn.Pick()
assert.NotNil(t, done) if done == nil {
t.Errorf("done is equal to nil")
}
done2 := wn.Pick() done2 := wn.Pick()
assert.NotNil(t, done2) if done2 == nil {
t.Errorf("done2 is equal to nil")
}
time.Sleep(time.Millisecond * 10) time.Sleep(time.Millisecond * 10)
done(context.Background(), selector.DoneInfo{}) done(context.Background(), selector.DoneInfo{})
assert.Less(t, float64(30000), wn.Weight()) if float64(30000) >= wn.Weight() {
assert.Greater(t, float64(60000), wn.Weight()) t.Errorf("float64(30000) >= wn.Weight()(%v)", wn.Weight())
}
assert.Greater(t, time.Millisecond*15, wn.PickElapsed()) if float64(60000) <= wn.Weight() {
assert.Less(t, time.Millisecond*5, wn.PickElapsed()) t.Errorf("float64(60000) <= wn.Weight()(%v)", wn.Weight())
}
if time.Millisecond*15 <= wn.PickElapsed() {
t.Errorf("time.Millisecond*15 <= wn.PickElapsed()(%v)", wn.PickElapsed())
}
if time.Millisecond*5 >= wn.PickElapsed() {
t.Errorf("time.Millisecond*5 >= wn.PickElapsed()(%v)", wn.PickElapsed())
}
} }
func TestDirectError(t *testing.T) { func TestDirectError(t *testing.T) {
@ -55,13 +68,18 @@ func TestDirectError(t *testing.T) {
err = context.DeadlineExceeded err = context.DeadlineExceeded
} }
done := wn.Pick() done := wn.Pick()
assert.NotNil(t, done) if done == nil {
t.Errorf("expect not nil, got nil")
}
time.Sleep(time.Millisecond * 20) time.Sleep(time.Millisecond * 20)
done(context.Background(), selector.DoneInfo{Err: err}) done(context.Background(), selector.DoneInfo{Err: err})
} }
if float64(30000) >= wn.Weight() {
assert.Less(t, float64(30000), wn.Weight()) t.Errorf("float64(30000) >= wn.Weight()(%v)", wn.Weight())
assert.Greater(t, float64(60000), wn.Weight()) }
if float64(60000) <= wn.Weight() {
t.Errorf("float64(60000) <= wn.Weight()(%v)", wn.Weight())
}
} }
func TestDirectErrorHandler(t *testing.T) { func TestDirectErrorHandler(t *testing.T) {
@ -86,11 +104,16 @@ func TestDirectErrorHandler(t *testing.T) {
err = context.DeadlineExceeded err = context.DeadlineExceeded
} }
done := wn.Pick() done := wn.Pick()
assert.NotNil(t, done) if done == nil {
t.Errorf("expect not nil, got nil")
}
time.Sleep(time.Millisecond * 20) time.Sleep(time.Millisecond * 20)
done(context.Background(), selector.DoneInfo{Err: err}) done(context.Background(), selector.DoneInfo{Err: err})
} }
if float64(30000) >= wn.Weight() {
assert.Less(t, float64(30000), wn.Weight()) t.Errorf("float64(30000) >= wn.Weight()(%v)", wn.Weight())
assert.Greater(t, float64(60000), wn.Weight()) }
if float64(60000) <= wn.Weight() {
t.Errorf("float64(60000) <= wn.Weight()(%v)", wn.Weight())
}
} }

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"math/rand" "math/rand"
"reflect"
"sync" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
@ -12,7 +13,6 @@ import (
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/filter" "github.com/go-kratos/kratos/v2/selector/filter"
"github.com/stretchr/testify/assert"
) )
func TestWrr3(t *testing.T) { func TestWrr3(t *testing.T) {
@ -41,9 +41,15 @@ func TestWrr3(t *testing.T) {
lk.Unlock() lk.Unlock()
time.Sleep(d) time.Sleep(d)
n, done, err := p2c.Select(context.Background()) n, done, err := p2c.Select(context.Background())
assert.Nil(t, err) if err != nil {
assert.NotNil(t, done) t.Errorf("expect %v, got %v", nil, err)
assert.NotNil(t, n) }
if n == nil {
t.Errorf("expect %v, got %v", nil, n)
}
if done == nil {
t.Errorf("expect %v, got %v", nil, done)
}
time.Sleep(time.Millisecond * 10) time.Sleep(time.Millisecond * 10)
done(context.Background(), selector.DoneInfo{}) done(context.Background(), selector.DoneInfo{})
if n.Address() == "127.0.0.0:8080" { if n.Address() == "127.0.0.0:8080" {
@ -56,18 +62,32 @@ func TestWrr3(t *testing.T) {
}() }()
} }
group.Wait() group.Wait()
assert.Greater(t, count1, int64(1500)) if count1 <= int64(1500) {
assert.Less(t, count1, int64(4500)) t.Errorf("count1(%v) <= int64(1500)", count1)
assert.Greater(t, count2, int64(1500)) }
assert.Less(t, count2, int64(4500)) if count1 >= int64(4500) {
assert.Greater(t, count3, int64(1500)) t.Errorf("count1(%v) >= int64(4500),", count1)
assert.Less(t, count3, int64(4500)) }
if count2 <= int64(1500) {
t.Errorf("count2(%v) <= int64(1500)", count1)
}
if count2 >= int64(4500) {
t.Errorf("count2(%v) >= int64(4500),", count2)
}
if count3 <= int64(1500) {
t.Errorf("count3(%v) <= int64(1500)", count3)
}
if count3 >= int64(4500) {
t.Errorf("count3(%v) >= int64(4500),", count3)
}
} }
func TestEmpty(t *testing.T) { func TestEmpty(t *testing.T) {
b := &Balancer{} b := &Balancer{}
_, _, err := b.Pick(context.Background(), []selector.WeightedNode{}) _, _, err := b.Pick(context.Background(), []selector.WeightedNode{})
assert.NotNil(t, err) if err == nil {
t.Errorf("expect %v, got %v", nil, err)
}
} }
func TestOne(t *testing.T) { func TestOne(t *testing.T) {
@ -85,8 +105,16 @@ func TestOne(t *testing.T) {
} }
p2c.Apply(nodes) p2c.Apply(nodes)
n, done, err := p2c.Select(context.Background()) n, done, err := p2c.Select(context.Background())
assert.Nil(t, err) if err != nil {
assert.NotNil(t, done) t.Errorf("expect %v, got %v", nil, err)
assert.NotNil(t, n) }
assert.Equal(t, "127.0.0.0:8080", n.Address()) if n == nil {
t.Errorf("expect %v, got %v", nil, n)
}
if done == nil {
t.Errorf("expect %v, got %v", nil, done)
}
if !reflect.DeepEqual("127.0.0.0:8080", n.Address()) {
t.Errorf("expect %v, got %v", "127.0.0.0:8080", n.Address())
}
} }

@ -7,7 +7,6 @@ import (
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/filter" "github.com/go-kratos/kratos/v2/selector/filter"
"github.com/stretchr/testify/assert"
) )
func TestWrr(t *testing.T) { func TestWrr(t *testing.T) {
@ -31,9 +30,15 @@ func TestWrr(t *testing.T) {
var count1, count2 int var count1, count2 int
for i := 0; i < 200; i++ { for i := 0; i < 200; i++ {
n, done, err := random.Select(context.Background()) n, done, err := random.Select(context.Background())
assert.Nil(t, err) if err != nil {
assert.NotNil(t, done) t.Errorf("expect no error, got %v", err)
assert.NotNil(t, n) }
if done == nil {
t.Errorf("expect not nil, got:%v", done)
}
if n == nil {
t.Errorf("expect not nil, got:%v", n)
}
done(context.Background(), selector.DoneInfo{}) done(context.Background(), selector.DoneInfo{})
if n.Address() == "127.0.0.1:8080" { if n.Address() == "127.0.0.1:8080" {
count1++ count1++
@ -41,14 +46,24 @@ func TestWrr(t *testing.T) {
count2++ count2++
} }
} }
assert.Greater(t, count1, 80) if count1 <= 80 {
assert.Less(t, count1, 120) t.Errorf("count1(%v) <= 80", count1)
assert.Greater(t, count2, 80) }
assert.Less(t, count2, 120) if count1 >= 120 {
t.Errorf("count1(%v) >= 120", count1)
}
if count2 <= 80 {
t.Errorf("count2(%v) <= 80", count2)
}
if count2 >= 120 {
t.Errorf("count2(%v) >= 120", count2)
}
} }
func TestEmpty(t *testing.T) { func TestEmpty(t *testing.T) {
b := &Balancer{} b := &Balancer{}
_, _, err := b.Pick(context.Background(), []selector.WeightedNode{}) _, _, err := b.Pick(context.Background(), []selector.WeightedNode{})
assert.NotNil(t, err) if err == nil {
t.Errorf("expect nil, got %v", err)
}
} }

@ -2,13 +2,14 @@ package selector
import ( import (
"context" "context"
"errors"
"math/rand" "math/rand"
"reflect"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/stretchr/testify/assert"
) )
type mockWeightedNode struct { type mockWeightedNode struct {
@ -107,33 +108,67 @@ func TestDefault(t *testing.T) {
})) }))
selector.Apply(nodes) selector.Apply(nodes)
n, done, err := selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) n, done, err := selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0")))
assert.Nil(t, err) if err != nil {
assert.NotNil(t, n) t.Errorf("expect %v, got %v", nil, err)
assert.NotNil(t, done) }
assert.Equal(t, "v2.0.0", n.Version()) if n == nil {
assert.NotNil(t, n.Address()) t.Errorf("expect %v, got %v", nil, n)
assert.Equal(t, int64(10), *n.InitialWeight()) }
assert.NotNil(t, n.Metadata()) if done == nil {
assert.Equal(t, "helloworld", n.ServiceName()) t.Errorf("expect %v, got %v", nil, done)
}
if !reflect.DeepEqual("v2.0.0", n.Version()) {
t.Errorf("expect %v, got %v", "v2.0.0", n.Version())
}
if n.Address() == "" {
t.Errorf("expect %v, got %v", "", n.Address())
}
if !reflect.DeepEqual(int64(10), *n.InitialWeight()) {
t.Errorf("expect %v, got %v", 10, *n.InitialWeight())
}
if n.Metadata() == nil {
t.Errorf("expect %v, got %v", nil, n.Metadata())
}
if !reflect.DeepEqual("helloworld", n.ServiceName()) {
t.Errorf("expect %v, got %v", "helloworld", n.ServiceName())
}
done(context.Background(), DoneInfo{}) done(context.Background(), DoneInfo{})
// no v3.0.0 instance // no v3.0.0 instance
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v3.0.0"))) n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v3.0.0")))
assert.Equal(t, ErrNoAvailable, err) if !errors.Is(ErrNoAvailable, err) {
assert.Nil(t, done) t.Errorf("expect %v, got %v", ErrNoAvailable, err)
assert.Nil(t, n) }
if done != nil {
t.Errorf("expect %v, got %v", nil, done)
}
if n != nil {
t.Errorf("expect %v, got %v", nil, n)
}
// apply zero instance // apply zero instance
selector.Apply([]Node{}) selector.Apply([]Node{})
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0")))
assert.Equal(t, ErrNoAvailable, err) if !errors.Is(ErrNoAvailable, err) {
assert.Nil(t, done) t.Errorf("expect %v, got %v", ErrNoAvailable, err)
assert.Nil(t, n) }
if done != nil {
t.Errorf("expect %v, got %v", nil, done)
}
if n != nil {
t.Errorf("expect %v, got %v", nil, n)
}
// apply zero instance // apply zero instance
selector.Apply(nil) selector.Apply(nil)
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0")))
assert.Equal(t, ErrNoAvailable, err) if !errors.Is(ErrNoAvailable, err) {
assert.Nil(t, done) t.Errorf("expect %v, got %v", ErrNoAvailable, err)
assert.Nil(t, n) }
if done != nil {
t.Errorf("expect %v, got %v", nil, done)
}
if n != nil {
t.Errorf("expect %v, got %v", nil, n)
}
} }

@ -2,12 +2,12 @@ package wrr
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/filter" "github.com/go-kratos/kratos/v2/selector/filter"
"github.com/stretchr/testify/assert"
) )
func TestWrr(t *testing.T) { func TestWrr(t *testing.T) {
@ -31,9 +31,15 @@ func TestWrr(t *testing.T) {
var count1, count2 int var count1, count2 int
for i := 0; i < 90; i++ { for i := 0; i < 90; i++ {
n, done, err := wrr.Select(context.Background()) n, done, err := wrr.Select(context.Background())
assert.Nil(t, err) if err != nil {
assert.NotNil(t, done) t.Errorf("expect no error, got %v", err)
assert.NotNil(t, n) }
if done == nil {
t.Errorf("expect done callback, got nil")
}
if n == nil {
t.Errorf("expect node, got nil")
}
done(context.Background(), selector.DoneInfo{}) done(context.Background(), selector.DoneInfo{})
if n.Address() == "127.0.0.1:8080" { if n.Address() == "127.0.0.1:8080" {
count1++ count1++
@ -41,12 +47,18 @@ func TestWrr(t *testing.T) {
count2++ count2++
} }
} }
assert.Equal(t, 30, count1) if !reflect.DeepEqual(count1, 30) {
assert.Equal(t, 60, count2) t.Errorf("expect 30, got %d", count1)
}
if !reflect.DeepEqual(count2, 60) {
t.Errorf("expect 60, got %d", count2)
}
} }
func TestEmpty(t *testing.T) { func TestEmpty(t *testing.T) {
b := &Balancer{} b := &Balancer{}
_, _, err := b.Pick(context.Background(), []selector.WeightedNode{}) _, _, err := b.Pick(context.Background(), []selector.WeightedNode{})
assert.NotNil(t, err) if err == nil {
t.Errorf("expect no error, got %v", err)
}
} }

@ -2,24 +2,30 @@ package grpc
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
) )
func TestTrailer(t *testing.T) { func TestTrailer(t *testing.T) {
trailer := Trailer(metadata.New(map[string]string{"a": "b"})) trailer := Trailer(metadata.New(map[string]string{"a": "b"}))
assert.Equal(t, "b", trailer.Get("a")) if !reflect.DeepEqual("b", trailer.Get("a")) {
assert.Equal(t, "", trailer.Get("3")) t.Errorf("expect %v, got %v", "b", trailer.Get("a"))
}
if !reflect.DeepEqual("", trailer.Get("notfound")) {
t.Errorf("expect %v, got %v", "", trailer.Get("notfound"))
}
} }
func TestBalancerName(t *testing.T) { func TestBalancerName(t *testing.T) {
o := &clientOptions{} o := &clientOptions{}
WithBalancerName("p2c")(o) WithBalancerName("p2c")(o)
assert.Equal(t, "p2c", o.balancerName) if !reflect.DeepEqual("p2c", o.balancerName) {
t.Errorf("expect %v, got %v", "p2c", o.balancerName)
}
} }
func TestFilters(t *testing.T) { func TestFilters(t *testing.T) {
@ -28,5 +34,7 @@ func TestFilters(t *testing.T) {
WithFilter(func(_ context.Context, nodes []selector.Node) []selector.Node { WithFilter(func(_ context.Context, nodes []selector.Node) []selector.Node {
return nodes return nodes
})(o) })(o)
assert.Equal(t, 1, len(o.filters)) if !reflect.DeepEqual(1, len(o.filters)) {
t.Errorf("expect %v, got %v", 1, len(o.filters))
}
} }

@ -3,13 +3,13 @@ package grpc
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"reflect"
"testing" "testing"
"time" "time"
"github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -17,14 +17,18 @@ func TestWithEndpoint(t *testing.T) {
o := &clientOptions{} o := &clientOptions{}
v := "abc" v := "abc"
WithEndpoint(v)(o) WithEndpoint(v)(o)
assert.Equal(t, v, o.endpoint) if !reflect.DeepEqual(v, o.endpoint) {
t.Errorf("expect %v but got %v", v, o.endpoint)
}
} }
func TestWithTimeout(t *testing.T) { func TestWithTimeout(t *testing.T) {
o := &clientOptions{} o := &clientOptions{}
v := time.Duration(123) v := time.Duration(123)
WithTimeout(v)(o) WithTimeout(v)(o)
assert.Equal(t, v, o.timeout) if !reflect.DeepEqual(v, o.timeout) {
t.Errorf("expect %v but got %v", v, o.timeout)
}
} }
func TestWithMiddleware(t *testing.T) { func TestWithMiddleware(t *testing.T) {
@ -33,7 +37,9 @@ func TestWithMiddleware(t *testing.T) {
func(middleware.Handler) middleware.Handler { return nil }, func(middleware.Handler) middleware.Handler { return nil },
} }
WithMiddleware(v...)(o) WithMiddleware(v...)(o)
assert.Equal(t, v, o.middleware) if !reflect.DeepEqual(v, o.middleware) {
t.Errorf("expect %v but got %v", v, o.middleware)
}
} }
type mockRegistry struct{} type mockRegistry struct{}
@ -50,21 +56,27 @@ func TestWithDiscovery(t *testing.T) {
o := &clientOptions{} o := &clientOptions{}
v := &mockRegistry{} v := &mockRegistry{}
WithDiscovery(v)(o) WithDiscovery(v)(o)
assert.Equal(t, v, o.discovery) if !reflect.DeepEqual(v, o.discovery) {
t.Errorf("expect %v but got %v", v, o.discovery)
}
} }
func TestWithTLSConfig(t *testing.T) { func TestWithTLSConfig(t *testing.T) {
o := &clientOptions{} o := &clientOptions{}
v := &tls.Config{} v := &tls.Config{}
WithTLSConfig(v)(o) WithTLSConfig(v)(o)
assert.Equal(t, v, o.tlsConf) if !reflect.DeepEqual(v, o.tlsConf) {
t.Errorf("expect %v but got %v", v, o.tlsConf)
}
} }
func TestWithLogger(t *testing.T) { func TestWithLogger(t *testing.T) {
o := &clientOptions{} o := &clientOptions{}
v := log.DefaultLogger v := log.DefaultLogger
WithLogger(v)(o) WithLogger(v)(o)
assert.Equal(t, v, o.logger) if !reflect.DeepEqual(v, o.logger) {
t.Errorf("expect %v but got %v", v, o.logger)
}
} }
func EmptyMiddleware() middleware.Middleware { func EmptyMiddleware() middleware.Middleware {
@ -84,7 +96,9 @@ func TestUnaryClientInterceptor(t *testing.T) {
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return nil return nil
}) })
assert.NoError(t, err) if err != nil {
t.Errorf("unexpected error: %v", err)
}
} }
func TestWithUnaryInterceptor(t *testing.T) { func TestWithUnaryInterceptor(t *testing.T) {
@ -100,7 +114,9 @@ func TestWithUnaryInterceptor(t *testing.T) {
}, },
} }
WithUnaryInterceptor(v...)(o) WithUnaryInterceptor(v...)(o)
assert.Equal(t, v, o.ints) if !reflect.DeepEqual(v, o.ints) {
t.Errorf("expect %v but got %v", v, o.ints)
}
} }
func TestWithOptions(t *testing.T) { func TestWithOptions(t *testing.T) {
@ -109,7 +125,9 @@ func TestWithOptions(t *testing.T) {
grpc.EmptyDialOption{}, grpc.EmptyDialOption{},
} }
WithOptions(v...)(o) WithOptions(v...)(o)
assert.Equal(t, v, o.grpcOpts) if !reflect.DeepEqual(v, o.grpcOpts) {
t.Errorf("expect %v but got %v", v, o.grpcOpts)
}
} }
func TestDial(t *testing.T) { func TestDial(t *testing.T) {
@ -118,7 +136,9 @@ func TestDial(t *testing.T) {
grpc.EmptyDialOption{}, grpc.EmptyDialOption{},
} }
WithOptions(v...)(o) WithOptions(v...)(o)
assert.Equal(t, v, o.grpcOpts) if !reflect.DeepEqual(v, o.grpcOpts) {
t.Errorf("expect %v but got %v", v, o.grpcOpts)
}
} }
func TestDialConn(t *testing.T) { func TestDialConn(t *testing.T) {

@ -1,16 +1,18 @@
package direct package direct
import ( import (
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
) )
func TestDirectBuilder_Scheme(t *testing.T) { func TestDirectBuilder_Scheme(t *testing.T) {
b := NewBuilder() b := NewBuilder()
assert.Equal(t, "direct", b.Scheme()) if !reflect.DeepEqual(b.Scheme(), "direct") {
t.Errorf("expect %v, got %v", "direct", b.Scheme())
}
} }
type mockConn struct{} type mockConn struct{}
@ -32,6 +34,8 @@ func (m *mockConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.P
func TestDirectBuilder_Build(t *testing.T) { func TestDirectBuilder_Build(t *testing.T) {
b := NewBuilder() b := NewBuilder()
r, err := b.Build(resolver.Target{}, &mockConn{}, resolver.BuildOptions{}) r, err := b.Build(resolver.Target{}, &mockConn{}, resolver.BuildOptions{})
assert.NoError(t, err) if err != nil {
t.Errorf("expect no error, got %v", err)
}
r.ResolveNow(resolver.ResolveNowOptions{}) r.ResolveNow(resolver.ResolveNowOptions{})
} }

@ -2,12 +2,12 @@ package discovery
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"time" "time"
"github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
) )
@ -33,14 +33,18 @@ func TestWithLogger(t *testing.T) {
func TestWithInsecure(t *testing.T) { func TestWithInsecure(t *testing.T) {
b := &builder{} b := &builder{}
WithInsecure(true)(b) WithInsecure(true)(b)
assert.True(t, b.insecure) if !b.insecure {
t.Errorf("expected insecure to be true")
}
} }
func TestWithTimeout(t *testing.T) { func TestWithTimeout(t *testing.T) {
o := &builder{} o := &builder{}
v := time.Duration(123) v := time.Duration(123)
WithTimeout(v)(o) WithTimeout(v)(o)
assert.Equal(t, v, o.timeout) if !reflect.DeepEqual(v, o.timeout) {
t.Errorf("expected %v, got %v", v, o.timeout)
}
} }
type mockDiscovery struct{} type mockDiscovery struct{}
@ -55,7 +59,9 @@ func (m *mockDiscovery) Watch(ctx context.Context, serviceName string) (registry
func TestBuilder_Scheme(t *testing.T) { func TestBuilder_Scheme(t *testing.T) {
b := NewBuilder(&mockDiscovery{}) b := NewBuilder(&mockDiscovery{})
assert.Equal(t, "discovery", b.Scheme()) if !reflect.DeepEqual("discovery", b.Scheme()) {
t.Errorf("expected %v, got %v", "discovery", b.Scheme())
}
} }
type mockConn struct{} type mockConn struct{}
@ -77,5 +83,7 @@ func (m *mockConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.P
func TestBuilder_Build(t *testing.T) { func TestBuilder_Build(t *testing.T) {
b := NewBuilder(&mockDiscovery{}) b := NewBuilder(&mockDiscovery{})
_, err := b.Build(resolver.Target{Scheme: resolver.GetDefaultScheme(), Endpoint: "gprc://authority/endpoint"}, &mockConn{}, resolver.BuildOptions{}) _, err := b.Build(resolver.Target{Scheme: resolver.GetDefaultScheme(), Endpoint: "gprc://authority/endpoint"}, &mockConn{}, resolver.BuildOptions{})
assert.NoError(t, err) if err != nil {
t.Errorf("expected no error, got %v", err)
}
} }

@ -3,12 +3,12 @@ package discovery
import ( import (
"context" "context"
"errors" "errors"
"reflect"
"testing" "testing"
"time" "time"
"github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
) )
@ -107,8 +107,14 @@ func TestWatchContextCancel(t *testing.T) {
func TestParseAttributes(t *testing.T) { func TestParseAttributes(t *testing.T) {
a := parseAttributes(map[string]string{"a": "b"}) a := parseAttributes(map[string]string{"a": "b"})
assert.Equal(t, "b", a.Value("a").(string)) if !reflect.DeepEqual("b", a.Value("a").(string)) {
t.Errorf("expect b, got %v", a.Value("a"))
}
x := a.WithValue("qq", "ww") x := a.WithValue("qq", "ww")
assert.Equal(t, "ww", x.Value("qq").(string)) if !reflect.DeepEqual("ww", x.Value("qq").(string)) {
assert.Nil(t, x.Value("notfound")) t.Errorf("expect ww, got %v", x.Value("qq"))
}
if x.Value("notfound") != nil {
t.Errorf("expect nil, got %v", x.Value("notfound"))
}
} }

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
"reflect"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -16,7 +17,6 @@ import (
"github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -104,9 +104,13 @@ func testClient(t *testing.T, srv *Server) {
} }
client := pb.NewGreeterClient(conn) client := pb.NewGreeterClient(conn)
reply, err := client.SayHello(context.Background(), &pb.HelloRequest{Name: "kratos"}) reply, err := client.SayHello(context.Background(), &pb.HelloRequest{Name: "kratos"})
fmt.Println(err) t.Log(err)
assert.Nil(t, err) if err != nil {
assert.Equal(t, "Hello kratos", reply.Message) t.Errorf("failed to call: %v", err)
}
if !reflect.DeepEqual(reply.Message, "Hello kratos") {
t.Errorf("expect %s, got %s", "Hello kratos", reply.Message)
}
_ = conn.Close() _ = conn.Close()
} }
@ -114,23 +118,33 @@ func TestNetwork(t *testing.T) {
o := &Server{} o := &Server{}
v := "abc" v := "abc"
Network(v)(o) Network(v)(o)
assert.Equal(t, v, o.network) if !reflect.DeepEqual(v, o.network) {
t.Errorf("expect %s, got %s", v, o.network)
}
} }
func TestAddress(t *testing.T) { func TestAddress(t *testing.T) {
v := "abc" v := "abc"
o := NewServer(Address(v)) o := NewServer(Address(v))
assert.Equal(t, v, o.address) if !reflect.DeepEqual(v, o.address) {
t.Errorf("expect %s, got %s", v, o.address)
}
u, err := o.Endpoint() u, err := o.Endpoint()
assert.NotNil(t, err) if err == nil {
assert.Nil(t, u) t.Errorf("expect %s, got %s", v, err)
}
if u != nil {
t.Errorf("expect %s, got %s", v, u)
}
} }
func TestTimeout(t *testing.T) { func TestTimeout(t *testing.T) {
o := &Server{} o := &Server{}
v := time.Duration(123) v := time.Duration(123)
Timeout(v)(o) Timeout(v)(o)
assert.Equal(t, v, o.timeout) if !reflect.DeepEqual(v, o.timeout) {
t.Errorf("expect %s, got %s", v, o.timeout)
}
} }
func TestMiddleware(t *testing.T) { func TestMiddleware(t *testing.T) {
@ -139,7 +153,9 @@ func TestMiddleware(t *testing.T) {
func(middleware.Handler) middleware.Handler { return nil }, func(middleware.Handler) middleware.Handler { return nil },
} }
Middleware(v...)(o) Middleware(v...)(o)
assert.Equal(t, v, o.middleware) if !reflect.DeepEqual(v, o.middleware) {
t.Errorf("expect %v, got %v", v, o.middleware)
}
} }
type mockLogger struct { type mockLogger struct {
@ -160,16 +176,24 @@ func TestLogger(t *testing.T) {
v := &mockLogger{} v := &mockLogger{}
Logger(v)(o) Logger(v)(o)
o.log.Log(log.LevelWarn, "foo", "bar") o.log.Log(log.LevelWarn, "foo", "bar")
assert.Equal(t, "foo", v.key) if !reflect.DeepEqual("foo", v.key) {
assert.Equal(t, "bar", v.val) t.Errorf("expect %s, got %s", "foo", v.key)
assert.Equal(t, log.LevelWarn, v.level) }
if !reflect.DeepEqual("bar", v.val) {
t.Errorf("expect %s, got %s", "bar", v.val)
}
if !reflect.DeepEqual(log.LevelWarn, v.level) {
t.Errorf("expect %s, got %s", log.LevelWarn, v.level)
}
} }
func TestTLSConfig(t *testing.T) { func TestTLSConfig(t *testing.T) {
o := &Server{} o := &Server{}
v := &tls.Config{} v := &tls.Config{}
TLSConfig(v)(o) TLSConfig(v)(o)
assert.Equal(t, v, o.tlsConf) if !reflect.DeepEqual(v, o.tlsConf) {
t.Errorf("expect %v, got %v", v, o.tlsConf)
}
} }
func TestUnaryInterceptor(t *testing.T) { func TestUnaryInterceptor(t *testing.T) {
@ -183,7 +207,9 @@ func TestUnaryInterceptor(t *testing.T) {
}, },
} }
UnaryInterceptor(v...)(o) UnaryInterceptor(v...)(o)
assert.Equal(t, v, o.ints) if !reflect.DeepEqual(v, o.ints) {
t.Errorf("expect %v, got %v", v, o.ints)
}
} }
func TestOptions(t *testing.T) { func TestOptions(t *testing.T) {
@ -192,7 +218,9 @@ func TestOptions(t *testing.T) {
grpc.EmptyServerOption{}, grpc.EmptyServerOption{},
} }
Options(v...)(o) Options(v...)(o)
assert.Equal(t, v, o.grpcOpts) if !reflect.DeepEqual(v, o.grpcOpts) {
t.Errorf("expect %v, got %v", v, o.grpcOpts)
}
} }
type testResp struct { type testResp struct {
@ -201,7 +229,9 @@ type testResp struct {
func TestServer_unaryServerInterceptor(t *testing.T) { func TestServer_unaryServerInterceptor(t *testing.T) {
u, err := url.Parse("grpc://hello/world") u, err := url.Parse("grpc://hello/world")
assert.NoError(t, err) if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}
srv := &Server{ srv := &Server{
baseCtx: context.Background(), baseCtx: context.Background(),
endpoint: u, endpoint: u,
@ -212,13 +242,19 @@ func TestServer_unaryServerInterceptor(t *testing.T) {
rv, err := srv.unaryServerInterceptor()(context.TODO(), req, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (i interface{}, e error) { rv, err := srv.unaryServerInterceptor()(context.TODO(), req, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (i interface{}, e error) {
return &testResp{Data: "hi"}, nil return &testResp{Data: "hi"}, nil
}) })
assert.NoError(t, err) if err != nil {
assert.Equal(t, "hi", rv.(*testResp).Data) t.Errorf("expect %v, got %v", nil, err)
}
if !reflect.DeepEqual("hi", rv.(*testResp).Data) {
t.Errorf("expect %s, got %s", "hi", rv.(*testResp).Data)
}
} }
func TestListener(t *testing.T) { func TestListener(t *testing.T) {
lis := &net.TCPListener{} lis := &net.TCPListener{}
s := &Server{} s := &Server{}
Listener(lis)(s) Listener(lis)(s)
assert.Equal(t, s.lis, lis) if !reflect.DeepEqual(lis, s.lis) {
t.Errorf("expect %v, got %v", lis, s.lis)
}
} }

@ -1,48 +1,70 @@
package grpc package grpc
import ( import (
"reflect"
"sort"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport"
) )
func TestTransport_Kind(t *testing.T) { func TestTransport_Kind(t *testing.T) {
o := &Transport{} o := &Transport{}
assert.Equal(t, transport.KindGRPC, o.Kind()) if !reflect.DeepEqual(transport.KindGRPC, o.Kind()) {
t.Errorf("expect %v, got %v", transport.KindGRPC, o.Kind())
}
} }
func TestTransport_Endpoint(t *testing.T) { func TestTransport_Endpoint(t *testing.T) {
v := "hello" v := "hello"
o := &Transport{endpoint: v} o := &Transport{endpoint: v}
assert.Equal(t, v, o.Endpoint()) if !reflect.DeepEqual(v, o.Endpoint()) {
t.Errorf("expect %v, got %v", v, o.Endpoint())
}
} }
func TestTransport_Operation(t *testing.T) { func TestTransport_Operation(t *testing.T) {
v := "hello" v := "hello"
o := &Transport{operation: v} o := &Transport{operation: v}
assert.Equal(t, v, o.Operation()) if !reflect.DeepEqual(v, o.Operation()) {
t.Errorf("expect %v, got %v", v, o.Operation())
}
} }
func TestTransport_RequestHeader(t *testing.T) { func TestTransport_RequestHeader(t *testing.T) {
v := headerCarrier{} v := headerCarrier{}
v.Set("a", "1") v.Set("a", "1")
o := &Transport{reqHeader: v} o := &Transport{reqHeader: v}
assert.Equal(t, "1", o.RequestHeader().Get("a")) if !reflect.DeepEqual("1", o.RequestHeader().Get("a")) {
assert.Equal(t, "", o.RequestHeader().Get("notfound")) t.Errorf("expect %v, got %v", "1", o.RequestHeader().Get("a"))
}
if !reflect.DeepEqual("", o.RequestHeader().Get("notfound")) {
t.Errorf("expect %v, got %v", "", o.RequestHeader().Get("notfound"))
}
} }
func TestTransport_ReplyHeader(t *testing.T) { func TestTransport_ReplyHeader(t *testing.T) {
v := headerCarrier{} v := headerCarrier{}
v.Set("a", "1") v.Set("a", "1")
o := &Transport{replyHeader: v} o := &Transport{replyHeader: v}
assert.Equal(t, "1", o.ReplyHeader().Get("a")) if !reflect.DeepEqual("1", o.ReplyHeader().Get("a")) {
t.Errorf("expect %v, got %v", "1", o.ReplyHeader().Get("a"))
}
} }
func TestHeaderCarrier_Keys(t *testing.T) { func TestHeaderCarrier_Keys(t *testing.T) {
v := headerCarrier{} v := headerCarrier{}
v.Set("abb", "1") v.Set("abb", "1")
v.Set("bcc", "2") v.Set("bcc", "2")
assert.ElementsMatch(t, []string{"abb", "bcc"}, v.Keys()) want := []string{"abb", "bcc"}
keys := v.Keys()
sort.Slice(want, func(i, j int) bool {
return want[i] < want[j]
})
sort.Slice(keys, func(i, j int) bool {
return keys[i] < keys[j]
})
if !reflect.DeepEqual(want, keys) {
t.Errorf("expect %v, got %v", want, keys)
}
} }

@ -2,60 +2,88 @@ package http
import ( import (
"net/http" "net/http"
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestEmptyCallOptions(t *testing.T) { func TestEmptyCallOptions(t *testing.T) {
assert.NoError(t, EmptyCallOption{}.before(&callInfo{})) e := EmptyCallOption{}
EmptyCallOption{}.after(&callInfo{}, &csAttempt{}) if e.before(&callInfo{}) != nil {
t.Error("EmptyCallOption should be ignored")
}
e.after(&callInfo{}, &csAttempt{})
} }
func TestContentType(t *testing.T) { func TestContentType(t *testing.T) {
assert.Equal(t, "aaa", ContentType("aaa").(ContentTypeCallOption).ContentType) if !reflect.DeepEqual(ContentType("aaa").(ContentTypeCallOption).ContentType, "aaa") {
t.Errorf("want: %v,got: %v", "aaa", ContentType("aaa").(ContentTypeCallOption).ContentType)
}
} }
func TestContentTypeCallOption_before(t *testing.T) { func TestContentTypeCallOption_before(t *testing.T) {
c := &callInfo{} c := &callInfo{}
err := ContentType("aaa").before(c) err := ContentType("aaa").before(c)
assert.NoError(t, err) if err != nil {
assert.Equal(t, "aaa", c.contentType) t.Errorf("unexpected error: %v", err)
}
if !reflect.DeepEqual("aaa", c.contentType) {
t.Errorf("want: %v, got: %v", "aaa", c.contentType)
}
} }
func TestDefaultCallInfo(t *testing.T) { func TestDefaultCallInfo(t *testing.T) {
path := "hi" path := "hi"
rv := defaultCallInfo(path) rv := defaultCallInfo(path)
assert.Equal(t, path, rv.pathTemplate) if !reflect.DeepEqual(path, rv.pathTemplate) {
assert.Equal(t, path, rv.operation) t.Errorf("expect %v, got %v", path, rv.pathTemplate)
assert.Equal(t, "application/json", rv.contentType) }
if !reflect.DeepEqual(path, rv.operation) {
t.Errorf("expect %v, got %v", path, rv.operation)
}
if !reflect.DeepEqual("application/json", rv.contentType) {
t.Errorf("expect %v, got %v", "application/json", rv.contentType)
}
} }
func TestOperation(t *testing.T) { func TestOperation(t *testing.T) {
assert.Equal(t, "aaa", Operation("aaa").(OperationCallOption).Operation) if !reflect.DeepEqual("aaa", Operation("aaa").(OperationCallOption).Operation) {
t.Errorf("want: %v,got: %v", "aaa", Operation("aaa").(OperationCallOption).Operation)
}
} }
func TestOperationCallOption_before(t *testing.T) { func TestOperationCallOption_before(t *testing.T) {
c := &callInfo{} c := &callInfo{}
err := Operation("aaa").before(c) err := Operation("aaa").before(c)
assert.NoError(t, err) if err != nil {
assert.Equal(t, "aaa", c.operation) t.Errorf("unexpected error: %v", err)
}
if !reflect.DeepEqual("aaa", c.operation) {
t.Errorf("want: %v, got: %v", "aaa", c.operation)
}
} }
func TestPathTemplate(t *testing.T) { func TestPathTemplate(t *testing.T) {
assert.Equal(t, "aaa", PathTemplate("aaa").(PathTemplateCallOption).Pattern) if !reflect.DeepEqual("aaa", PathTemplate("aaa").(PathTemplateCallOption).Pattern) {
t.Errorf("want: %v,got: %v", "aaa", PathTemplate("aaa").(PathTemplateCallOption).Pattern)
}
} }
func TestPathTemplateCallOption_before(t *testing.T) { func TestPathTemplateCallOption_before(t *testing.T) {
c := &callInfo{} c := &callInfo{}
err := PathTemplate("aaa").before(c) err := PathTemplate("aaa").before(c)
assert.NoError(t, err) if err != nil {
assert.Equal(t, "aaa", c.pathTemplate) t.Errorf("unexpected error: %v", err)
}
if !reflect.DeepEqual("aaa", c.pathTemplate) {
t.Errorf("want: %v, got: %v", "aaa", c.pathTemplate)
}
} }
func TestHeader(t *testing.T) { func TestHeader(t *testing.T) {
h := http.Header{"A": []string{"123"}} h := http.Header{"A": []string{"123"}}
assert.Equal(t, "123", Header(&h).(HeaderCallOption).header.Get("A")) if !reflect.DeepEqual(Header(&h).(HeaderCallOption).header.Get("A"), "123") {
t.Errorf("want: %v,got: %v", "123", Header(&h).(HeaderCallOption).header.Get("A"))
}
} }
func TestHeaderCallOption_after(t *testing.T) { func TestHeaderCallOption_after(t *testing.T) {
@ -64,5 +92,7 @@ func TestHeaderCallOption_after(t *testing.T) {
cs := &csAttempt{res: &http.Response{Header: h}} cs := &csAttempt{res: &http.Response{Header: h}}
o := Header(&h) o := Header(&h)
o.after(c, cs) o.after(c, cs)
assert.Equal(t, &h, o.(HeaderCallOption).header) if !reflect.DeepEqual(&h, o.(HeaderCallOption).header) {
t.Errorf("want: %v,got: %v", &h, o.(HeaderCallOption).header)
}
} }

@ -5,9 +5,11 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
nethttp "net/http" nethttp "net/http"
"reflect"
"strconv" "strconv"
"testing" "testing"
"time" "time"
@ -15,7 +17,6 @@ import (
kratosErrors "github.com/go-kratos/kratos/v2/errors" kratosErrors "github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/stretchr/testify/assert"
) )
type mockRoundTripper struct{} type mockRoundTripper struct{}
@ -29,7 +30,9 @@ func TestWithTransport(t *testing.T) {
o := WithTransport(ov) o := WithTransport(ov)
co := &clientOptions{} co := &clientOptions{}
o(co) o(co)
assert.Equal(t, co.transport, ov) if !reflect.DeepEqual(co.transport, ov) {
t.Errorf("expected transport to be %v, got %v", ov, co.transport)
}
} }
func TestWithTimeout(t *testing.T) { func TestWithTimeout(t *testing.T) {
@ -37,14 +40,18 @@ func TestWithTimeout(t *testing.T) {
o := WithTimeout(ov) o := WithTimeout(ov)
co := &clientOptions{} co := &clientOptions{}
o(co) o(co)
assert.Equal(t, co.timeout, ov) if !reflect.DeepEqual(co.timeout, ov) {
t.Errorf("expected timeout to be %v, got %v", ov, co.timeout)
}
} }
func TestWithBlock(t *testing.T) { func TestWithBlock(t *testing.T) {
o := WithBlock() o := WithBlock()
co := &clientOptions{} co := &clientOptions{}
o(co) o(co)
assert.True(t, co.block) if !co.block {
t.Errorf("expected block to be true, got %v", co.block)
}
} }
func TestWithBalancer(t *testing.T) { func TestWithBalancer(t *testing.T) {
@ -55,7 +62,9 @@ func TestWithTLSConfig(t *testing.T) {
o := WithTLSConfig(ov) o := WithTLSConfig(ov)
co := &clientOptions{} co := &clientOptions{}
o(co) o(co)
assert.Same(t, ov, co.tlsConf) if !reflect.DeepEqual(co.tlsConf, ov) {
t.Errorf("expected tls config to be %v, got %v", ov, co.tlsConf)
}
} }
func TestWithUserAgent(t *testing.T) { func TestWithUserAgent(t *testing.T) {
@ -63,7 +72,9 @@ func TestWithUserAgent(t *testing.T) {
o := WithUserAgent(ov) o := WithUserAgent(ov)
co := &clientOptions{} co := &clientOptions{}
o(co) o(co)
assert.Equal(t, co.userAgent, ov) if !reflect.DeepEqual(co.userAgent, ov) {
t.Errorf("expected user agent to be %v, got %v", ov, co.userAgent)
}
} }
func TestWithMiddleware(t *testing.T) { func TestWithMiddleware(t *testing.T) {
@ -72,7 +83,9 @@ func TestWithMiddleware(t *testing.T) {
func(middleware.Handler) middleware.Handler { return nil }, func(middleware.Handler) middleware.Handler { return nil },
} }
WithMiddleware(v...)(o) WithMiddleware(v...)(o)
assert.Equal(t, v, o.middleware) if !reflect.DeepEqual(o.middleware, v) {
t.Errorf("expected middleware to be %v, got %v", v, o.middleware)
}
} }
func TestWithEndpoint(t *testing.T) { func TestWithEndpoint(t *testing.T) {
@ -80,7 +93,9 @@ func TestWithEndpoint(t *testing.T) {
o := WithEndpoint(ov) o := WithEndpoint(ov)
co := &clientOptions{} co := &clientOptions{}
o(co) o(co)
assert.Equal(t, co.endpoint, ov) if !reflect.DeepEqual(co.endpoint, ov) {
t.Errorf("expected endpoint to be %v, got %v", ov, co.endpoint)
}
} }
func TestWithRequestEncoder(t *testing.T) { func TestWithRequestEncoder(t *testing.T) {
@ -89,21 +104,27 @@ func TestWithRequestEncoder(t *testing.T) {
return nil, nil return nil, nil
} }
WithRequestEncoder(v)(o) WithRequestEncoder(v)(o)
assert.NotNil(t, o.encoder) if o.encoder == nil {
t.Errorf("expected encoder to be not nil")
}
} }
func TestWithResponseDecoder(t *testing.T) { func TestWithResponseDecoder(t *testing.T) {
o := &clientOptions{} o := &clientOptions{}
v := func(ctx context.Context, res *nethttp.Response, out interface{}) error { return nil } v := func(ctx context.Context, res *nethttp.Response, out interface{}) error { return nil }
WithResponseDecoder(v)(o) WithResponseDecoder(v)(o)
assert.NotNil(t, o.decoder) if o.decoder == nil {
t.Errorf("expected encoder to be not nil")
}
} }
func TestWithErrorDecoder(t *testing.T) { func TestWithErrorDecoder(t *testing.T) {
o := &clientOptions{} o := &clientOptions{}
v := func(ctx context.Context, res *nethttp.Response) error { return nil } v := func(ctx context.Context, res *nethttp.Response) error { return nil }
WithErrorDecoder(v)(o) WithErrorDecoder(v)(o)
assert.NotNil(t, o.errorDecoder) if o.errorDecoder == nil {
t.Errorf("expected encoder to be not nil")
}
} }
type mockDiscovery struct{} type mockDiscovery struct{}
@ -139,7 +160,9 @@ func TestWithDiscovery(t *testing.T) {
o := WithDiscovery(ov) o := WithDiscovery(ov)
co := &clientOptions{} co := &clientOptions{}
o(co) o(co)
assert.Equal(t, co.discovery, ov) if !reflect.DeepEqual(co.discovery, ov) {
t.Errorf("expected discovery to be %v, got %v", ov, co.discovery)
}
} }
func TestDefaultRequestEncoder(t *testing.T) { func TestDefaultRequestEncoder(t *testing.T) {
@ -154,14 +177,20 @@ func TestDefaultRequestEncoder(t *testing.T) {
B int64 `json:"b"` B int64 `json:"b"`
}{"a", 1} }{"a", 1}
b, err1 := DefaultRequestEncoder(context.TODO(), "application/json", v1) b, err1 := DefaultRequestEncoder(context.TODO(), "application/json", v1)
assert.Nil(t, err1) if err1 != nil {
t.Errorf("expected no error, got %v", err1)
}
v1b := &struct { v1b := &struct {
A string `json:"a"` A string `json:"a"`
B int64 `json:"b"` B int64 `json:"b"`
}{} }{}
err1 = json.Unmarshal(b, v1b) err1 = json.Unmarshal(b, v1b)
assert.Nil(t, err1) if err1 != nil {
assert.Equal(t, v1, v1b) t.Errorf("expected no error, got %v", err1)
}
if !reflect.DeepEqual(v1b, v1) {
t.Errorf("expected %v, got %v", v1, v1b)
}
} }
func TestDefaultResponseDecoder(t *testing.T) { func TestDefaultResponseDecoder(t *testing.T) {
@ -175,9 +204,15 @@ func TestDefaultResponseDecoder(t *testing.T) {
B int64 `json:"b"` B int64 `json:"b"`
}{} }{}
err1 := DefaultResponseDecoder(context.TODO(), resp1, &v1) err1 := DefaultResponseDecoder(context.TODO(), resp1, &v1)
assert.Nil(t, err1) if err1 != nil {
assert.Equal(t, "1", v1.A) t.Errorf("expected no error, got %v", err1)
assert.Equal(t, int64(2), v1.B) }
if !reflect.DeepEqual("1", v1.A) {
t.Errorf("expected %v, got %v", "1", v1.A)
}
if !reflect.DeepEqual(int64(2), v1.B) {
t.Errorf("expected %v, got %v", 2, v1.B)
}
resp2 := &nethttp.Response{ resp2 := &nethttp.Response{
Header: make(nethttp.Header), Header: make(nethttp.Header),
@ -190,20 +225,26 @@ func TestDefaultResponseDecoder(t *testing.T) {
}{} }{}
err2 := DefaultResponseDecoder(context.TODO(), resp2, &v2) err2 := DefaultResponseDecoder(context.TODO(), resp2, &v2)
terr1 := &json.SyntaxError{} terr1 := &json.SyntaxError{}
assert.ErrorAs(t, err2, &terr1) if !errors.As(err2, &terr1) {
t.Errorf("expected %v, got %v", terr1, err2)
}
} }
func TestDefaultErrorDecoder(t *testing.T) { func TestDefaultErrorDecoder(t *testing.T) {
for i := 200; i < 300; i++ { for i := 200; i < 300; i++ {
resp := &nethttp.Response{Header: make(nethttp.Header), StatusCode: i} resp := &nethttp.Response{Header: make(nethttp.Header), StatusCode: i}
assert.Nil(t, DefaultErrorDecoder(context.TODO(), resp)) if DefaultErrorDecoder(context.TODO(), resp) != nil {
t.Errorf("expected no error, got %v", DefaultErrorDecoder(context.TODO(), resp))
}
} }
resp1 := &nethttp.Response{ resp1 := &nethttp.Response{
Header: make(nethttp.Header), Header: make(nethttp.Header),
StatusCode: 300, StatusCode: 300,
Body: io.NopCloser(bytes.NewBufferString("{\"foo\":\"bar\"}")), Body: io.NopCloser(bytes.NewBufferString("{\"foo\":\"bar\"}")),
} }
assert.Error(t, DefaultErrorDecoder(context.TODO(), resp1)) if DefaultErrorDecoder(context.TODO(), resp1) == nil {
t.Errorf("expected error, got nil")
}
resp2 := &nethttp.Response{ resp2 := &nethttp.Response{
Header: make(nethttp.Header), Header: make(nethttp.Header),
@ -211,17 +252,27 @@ func TestDefaultErrorDecoder(t *testing.T) {
Body: io.NopCloser(bytes.NewBufferString("{\"code\":54321, \"message\": \"hi\", \"reason\": \"FOO\"}")), Body: io.NopCloser(bytes.NewBufferString("{\"code\":54321, \"message\": \"hi\", \"reason\": \"FOO\"}")),
} }
err2 := DefaultErrorDecoder(context.TODO(), resp2) err2 := DefaultErrorDecoder(context.TODO(), resp2)
assert.Error(t, err2) if err2 == nil {
assert.Equal(t, int32(500), err2.(*kratosErrors.Error).GetCode()) t.Errorf("expected error, got nil")
assert.Equal(t, "hi", err2.(*kratosErrors.Error).GetMessage()) }
assert.Equal(t, "FOO", err2.(*kratosErrors.Error).GetReason()) if !reflect.DeepEqual(int32(500), err2.(*kratosErrors.Error).GetCode()) {
t.Errorf("expected %v, got %v", 500, err2.(*kratosErrors.Error).GetCode())
}
if !reflect.DeepEqual("hi", err2.(*kratosErrors.Error).GetMessage()) {
t.Errorf("expected %v, got %v", "hi", err2.(*kratosErrors.Error).GetMessage())
}
if !reflect.DeepEqual("FOO", err2.(*kratosErrors.Error).GetReason()) {
t.Errorf("expected %v, got %v", "FOO", err2.(*kratosErrors.Error).GetReason())
}
} }
func TestCodecForResponse(t *testing.T) { func TestCodecForResponse(t *testing.T) {
resp := &nethttp.Response{Header: make(nethttp.Header)} resp := &nethttp.Response{Header: make(nethttp.Header)}
resp.Header.Set("Content-Type", "application/xml") resp.Header.Set("Content-Type", "application/xml")
c := CodecForResponse(resp) c := CodecForResponse(resp)
assert.Equal(t, "xml", c.Name()) if !reflect.DeepEqual("xml", c.Name()) {
t.Errorf("expected %v, got %v", "xml", c.Name())
}
} }
func TestNewClient(t *testing.T) { func TestNewClient(t *testing.T) {

@ -4,10 +4,10 @@ import (
"bytes" "bytes"
"io" "io"
nethttp "net/http" nethttp "net/http"
"reflect"
"testing" "testing"
"github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/errors"
"github.com/stretchr/testify/assert"
) )
func TestDefaultRequestDecoder(t *testing.T) { func TestDefaultRequestDecoder(t *testing.T) {
@ -22,9 +22,15 @@ func TestDefaultRequestDecoder(t *testing.T) {
B int64 `json:"b"` B int64 `json:"b"`
}{} }{}
err1 := DefaultRequestDecoder(req1, &v1) err1 := DefaultRequestDecoder(req1, &v1)
assert.Nil(t, err1) if err1 != nil {
assert.Equal(t, "1", v1.A) t.Errorf("expected no error, got %v", err1)
assert.Equal(t, int64(2), v1.B) }
if !reflect.DeepEqual("1", v1.A) {
t.Errorf("expected %v, got %v", "1", v1.A)
}
if !reflect.DeepEqual(int64(2), v1.B) {
t.Errorf("expected %v, got %v", 2, v1.B)
}
} }
type mockResponseWriter struct { type mockResponseWriter struct {
@ -60,10 +66,18 @@ func TestDefaultResponseEncoder(t *testing.T) {
v1 := &dataWithStatusCode{A: "1", B: 2} v1 := &dataWithStatusCode{A: "1", B: 2}
err := DefaultResponseEncoder(w, req1, v1) err := DefaultResponseEncoder(w, req1, v1)
assert.Nil(t, err) if err != nil {
assert.Equal(t, "application/json", w.Header().Get("Content-Type")) t.Errorf("expected no error, got %v", err)
assert.Equal(t, 200, w.StatusCode) }
assert.NotNil(t, w.Data) if !reflect.DeepEqual("application/json", w.Header().Get("Content-Type")) {
t.Errorf("expected %v, got %v", "application/json", w.Header().Get("Content-Type"))
}
if !reflect.DeepEqual(200, w.StatusCode) {
t.Errorf("expected %v, got %v", 200, w.StatusCode)
}
if w.Data == nil {
t.Errorf("expected not nil, got %v", w.Data)
}
} }
func TestDefaultResponseEncoderWithError(t *testing.T) { func TestDefaultResponseEncoderWithError(t *testing.T) {
@ -75,9 +89,15 @@ func TestDefaultResponseEncoderWithError(t *testing.T) {
se := &errors.Error{Code: 511} se := &errors.Error{Code: 511}
DefaultErrorEncoder(w, req, se) DefaultErrorEncoder(w, req, se)
assert.Equal(t, "application/json", w.Header().Get("Content-Type")) if !reflect.DeepEqual("application/json", w.Header().Get("Content-Type")) {
assert.Equal(t, 511, w.StatusCode) t.Errorf("expected %v, got %v", "application/json", w.Header().Get("Content-Type"))
assert.NotNil(t, w.Data) }
if !reflect.DeepEqual(511, w.StatusCode) {
t.Errorf("expected %v, got %v", 511, w.StatusCode)
}
if w.Data == nil {
t.Errorf("expected not nil, got %v", w.Data)
}
} }
func TestCodecForRequest(t *testing.T) { func TestCodecForRequest(t *testing.T) {
@ -88,8 +108,12 @@ func TestCodecForRequest(t *testing.T) {
req1.Header.Set("Content-Type", "application/xml") req1.Header.Set("Content-Type", "application/xml")
c, ok := CodecForRequest(req1, "Content-Type") c, ok := CodecForRequest(req1, "Content-Type")
assert.True(t, ok) if !ok {
assert.Equal(t, "xml", c.Name()) t.Errorf("expected true, got %v", ok)
}
if !reflect.DeepEqual("xml", c.Name()) {
t.Errorf("expected %v, got %v", "xml", c.Name())
}
req2 := &nethttp.Request{ req2 := &nethttp.Request{
Header: make(nethttp.Header), Header: make(nethttp.Header),
@ -98,6 +122,10 @@ func TestCodecForRequest(t *testing.T) {
req2.Header.Set("Content-Type", "blablablabla") req2.Header.Set("Content-Type", "blablablabla")
c, ok = CodecForRequest(req2, "Content-Type") c, ok = CodecForRequest(req2, "Content-Type")
assert.False(t, ok) if ok {
assert.Equal(t, "json", c.Name()) t.Errorf("expected false, got %v", ok)
}
if !reflect.DeepEqual("json", c.Name()) {
t.Errorf("expected %v, got %v", "json", c.Name())
}
} }

@ -6,10 +6,9 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func TestContextHeader(t *testing.T) { func TestContextHeader(t *testing.T) {
@ -20,7 +19,9 @@ func TestContextHeader(t *testing.T) {
w: responseWriter{}, w: responseWriter{},
} }
h := w.Header() h := w.Header()
assert.Equal(t, h, http.Header{"name": {"kratos"}}) if !reflect.DeepEqual(h, http.Header{"name": {"kratos"}}) {
t.Errorf("expected %v, got %v", http.Header{"name": {"kratos"}}, h)
}
} }
func TestContextForm(t *testing.T) { func TestContextForm(t *testing.T) {
@ -31,7 +32,9 @@ func TestContextForm(t *testing.T) {
w: responseWriter{}, w: responseWriter{},
} }
form := w.Form() form := w.Form()
assert.Equal(t, form, url.Values{}) if !reflect.DeepEqual(form, url.Values{}) {
t.Errorf("expected %v, got %v", url.Values{}, form)
}
w = wrapper{ w = wrapper{
router: nil, router: nil,
@ -40,7 +43,9 @@ func TestContextForm(t *testing.T) {
w: responseWriter{}, w: responseWriter{},
} }
form = w.Form() form = w.Form()
assert.Equal(t, form, url.Values{"name": []string{"kratos"}}) if !reflect.DeepEqual(form, url.Values{"name": {"kratos"}}) {
t.Errorf("expected %v, got %v", url.Values{"name": {"kratos"}}, form)
}
} }
func TestContextQuery(t *testing.T) { func TestContextQuery(t *testing.T) {
@ -51,7 +56,9 @@ func TestContextQuery(t *testing.T) {
w: responseWriter{}, w: responseWriter{},
} }
q := w.Query() q := w.Query()
assert.Equal(t, q, url.Values{"page": []string{"1"}}) if !reflect.DeepEqual(q, url.Values{"page": {"1"}}) {
t.Errorf("expected %v, got %v", url.Values{"page": {"1"}}, q)
}
} }
func TestContextRequest(t *testing.T) { func TestContextRequest(t *testing.T) {
@ -63,7 +70,9 @@ func TestContextRequest(t *testing.T) {
w: responseWriter{}, w: responseWriter{},
} }
res := w.Request() res := w.Request()
assert.Equal(t, res, req) if !reflect.DeepEqual(res, req) {
t.Errorf("expected %v, got %v", req, res)
}
} }
func TestContextResponse(t *testing.T) { func TestContextResponse(t *testing.T) {
@ -74,9 +83,13 @@ func TestContextResponse(t *testing.T) {
res: res, res: res,
w: responseWriter{200, res}, w: responseWriter{200, res},
} }
assert.Equal(t, w.Response(), res) if !reflect.DeepEqual(w.Response(), res) {
t.Errorf("expected %v, got %v", res, w.Response())
}
err := w.Returns(map[string]string{}, nil) err := w.Returns(map[string]string{}, nil)
assert.Nil(t, err) if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
} }
func TestContextBindQuery(t *testing.T) { func TestContextBindQuery(t *testing.T) {
@ -91,8 +104,12 @@ func TestContextBindQuery(t *testing.T) {
} }
b := BindQuery{} b := BindQuery{}
err := w.BindQuery(&b) err := w.BindQuery(&b)
assert.Nil(t, err) if err != nil {
assert.Equal(t, b, BindQuery{Page: 2}) t.Errorf("expected %v, got %v", nil, err)
}
if !reflect.DeepEqual(b, BindQuery{Page: 2}) {
t.Errorf("expected %v, got %v", BindQuery{Page: 2}, b)
}
} }
func TestContextBindForm(t *testing.T) { func TestContextBindForm(t *testing.T) {
@ -107,8 +124,12 @@ func TestContextBindForm(t *testing.T) {
} }
b := BindForm{} b := BindForm{}
err := w.BindForm(&b) err := w.BindForm(&b)
assert.Nil(t, err) if err != nil {
assert.Equal(t, b, BindForm{Page: 2}) t.Errorf("expected %v, got %v", nil, err)
}
if !reflect.DeepEqual(b, BindForm{Page: 2}) {
t.Errorf("expected %v, got %v", BindForm{Page: 2}, b)
}
} }
func TestContextResponseReturn(t *testing.T) { func TestContextResponseReturn(t *testing.T) {
@ -120,15 +141,25 @@ func TestContextResponseReturn(t *testing.T) {
w: responseWriter{}, w: responseWriter{},
} }
err := w.JSON(200, "success") err := w.JSON(200, "success")
assert.Nil(t, err) if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
err = w.XML(200, "success") err = w.XML(200, "success")
assert.Nil(t, err) if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
err = w.String(200, "success") err = w.String(200, "success")
assert.Nil(t, err) if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
err = w.Blob(200, "blob", []byte("success")) err = w.Blob(200, "blob", []byte("success"))
assert.Nil(t, err) if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
err = w.Stream(200, "stream", bytes.NewBuffer([]byte("success"))) err = w.Stream(200, "stream", bytes.NewBuffer([]byte("success")))
assert.Nil(t, err) if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
} }
func TestContextCtx(t *testing.T) { func TestContextCtx(t *testing.T) {
@ -143,13 +174,21 @@ func TestContextCtx(t *testing.T) {
w: responseWriter{}, w: responseWriter{},
} }
_, ok := w.Deadline() _, ok := w.Deadline()
assert.Equal(t, ok, true) if !ok {
t.Errorf("expected %v, got %v", true, ok)
}
done := w.Done() done := w.Done()
assert.NotNil(t, done) if done == nil {
t.Errorf("expected %v, got %v", true, ok)
}
err := w.Err() err := w.Err()
assert.Nil(t, err) if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
v := w.Value("test") v := w.Value("test")
assert.Nil(t, v) if v != nil {
t.Errorf("expected %v, got %v", nil, v)
}
w = wrapper{ w = wrapper{
router: &Router{srv: &Server{enc: DefaultResponseEncoder}}, router: &Router{srv: &Server{enc: DefaultResponseEncoder}},
@ -158,11 +197,19 @@ func TestContextCtx(t *testing.T) {
w: responseWriter{}, w: responseWriter{},
} }
_, ok = w.Deadline() _, ok = w.Deadline()
assert.Equal(t, ok, false) if ok {
t.Errorf("expected %v, got %v", false, ok)
}
done = w.Done() done = w.Done()
assert.Nil(t, done) if done != nil {
t.Errorf("expected not nil, got %v", done)
}
err = w.Err() err = w.Err()
assert.NotNil(t, err) if err == nil {
t.Errorf("expected not %v, got %v", nil, err)
}
v = w.Value("test") v = w.Value("test")
assert.Nil(t, v) if v != nil {
t.Errorf("expected %v, got %v", nil, v)
}
} }

@ -4,35 +4,55 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"reflect"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector"
"github.com/stretchr/testify/assert"
) )
func TestParseTarget(t *testing.T) { func TestParseTarget(t *testing.T) {
target, err := parseTarget("localhost:8000", true) target, err := parseTarget("localhost:8000", true)
assert.Nil(t, err) if err != nil {
assert.Equal(t, &Target{Scheme: "http", Authority: "localhost:8000"}, target) t.Errorf("expect %v, got %v", nil, err)
}
if !reflect.DeepEqual(&Target{Scheme: "http", Authority: "localhost:8000"}, target) {
t.Errorf("expect %v, got %v", &Target{Scheme: "http", Authority: "localhost:8000"}, target)
}
target, err = parseTarget("discovery:///demo", true) target, err = parseTarget("discovery:///demo", true)
assert.Nil(t, err) if err != nil {
assert.Equal(t, &Target{Scheme: "discovery", Authority: "", Endpoint: "demo"}, target) t.Errorf("expect %v, got %v", nil, err)
}
if !reflect.DeepEqual(&Target{Scheme: "discovery", Authority: "", Endpoint: "demo"}, target) {
t.Errorf("expect %v, got %v", &Target{Scheme: "discovery", Authority: "", Endpoint: "demo"}, target)
}
target, err = parseTarget("127.0.0.1:8000", true) target, err = parseTarget("127.0.0.1:8000", true)
assert.Nil(t, err) if err != nil {
assert.Equal(t, &Target{Scheme: "http", Authority: "127.0.0.1:8000"}, target) t.Errorf("expect %v, got %v", nil, err)
}
if !reflect.DeepEqual(&Target{Scheme: "http", Authority: "127.0.0.1:8000"}, target) {
t.Errorf("expect %v, got %v", &Target{Scheme: "http", Authority: "127.0.0.1:8000"}, target)
}
target, err = parseTarget("https://127.0.0.1:8000", false) target, err = parseTarget("https://127.0.0.1:8000", false)
assert.Nil(t, err) if err != nil {
assert.Equal(t, &Target{Scheme: "https", Authority: "127.0.0.1:8000"}, target) t.Errorf("expect %v, got %v", nil, err)
}
if !reflect.DeepEqual(&Target{Scheme: "https", Authority: "127.0.0.1:8000"}, target) {
t.Errorf("expect %v, got %v", &Target{Scheme: "https", Authority: "127.0.0.1:8000"}, target)
}
target, err = parseTarget("127.0.0.1:8000", false) target, err = parseTarget("127.0.0.1:8000", false)
assert.Nil(t, err) if err != nil {
assert.Equal(t, &Target{Scheme: "https", Authority: "127.0.0.1:8000"}, target) t.Errorf("expect %v, got %v", nil, err)
}
if !reflect.DeepEqual(&Target{Scheme: "https", Authority: "127.0.0.1:8000"}, target) {
t.Errorf("expect %v, got %v", &Target{Scheme: "https", Authority: "127.0.0.1:8000"}, target)
}
} }
type mockRebalancer struct{} type mockRebalancer struct{}
@ -85,7 +105,11 @@ func TestResolver(t *testing.T) {
Endpoint: "discovery://helloworld", Endpoint: "discovery://helloworld",
} }
_, err := newResolver(context.Background(), &mockDiscoverys{true}, ta, &mockRebalancer{}, false, false) _, err := newResolver(context.Background(), &mockDiscoverys{true}, ta, &mockRebalancer{}, false, false)
assert.Nil(t, err) if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}
_, err = newResolver(context.Background(), &mockDiscoverys{false}, ta, &mockRebalancer{}, true, true) _, err = newResolver(context.Background(), &mockDiscoverys{false}, ta, &mockRebalancer{}, true, true)
assert.Nil(t, err) if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}
} }

@ -6,12 +6,11 @@ import (
"fmt" "fmt"
"log" "log"
"net/http" "net/http"
"reflect"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/internal/host"
) )
@ -175,7 +174,9 @@ func testRoute(t *testing.T, srv *Server) {
func TestRouter_Group(t *testing.T) { func TestRouter_Group(t *testing.T) {
r := &Router{} r := &Router{}
rr := r.Group("a", func(http.Handler) http.Handler { return nil }) rr := r.Group("a", func(http.Handler) http.Handler { return nil })
assert.Equal(t, "a", rr.prefix) if !reflect.DeepEqual("a", rr.prefix) {
t.Errorf("expected %q, got %q", "a", rr.prefix)
}
} }
func TestHandle(t *testing.T) { func TestHandle(t *testing.T) {

@ -8,6 +8,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"reflect"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -16,7 +17,6 @@ import (
"github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/internal/host"
"github.com/stretchr/testify/assert"
) )
type testKey struct{} type testKey struct{}
@ -49,20 +49,30 @@ func TestServer(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
testHeader(t, srv) testHeader(t, srv)
testClient(t, srv) testClient(t, srv)
assert.NoError(t, srv.Stop(ctx)) if srv.Stop(ctx) != nil {
t.Errorf("expected nil got %v", srv.Stop(ctx))
}
} }
func testHeader(t *testing.T, srv *Server) { func testHeader(t *testing.T, srv *Server) {
e, err := srv.Endpoint() e, err := srv.Endpoint()
assert.NoError(t, err) if err != nil {
t.Errorf("expected nil got %v", err)
}
client, err := NewClient(context.Background(), WithEndpoint(e.Host)) client, err := NewClient(context.Background(), WithEndpoint(e.Host))
assert.NoError(t, err) if err != nil {
t.Errorf("expected nil got %v", err)
}
reqURL := fmt.Sprintf(e.String() + "/index") reqURL := fmt.Sprintf(e.String() + "/index")
req, err := http.NewRequest("GET", reqURL, nil) req, err := http.NewRequest("GET", reqURL, nil)
assert.NoError(t, err) if err != nil {
t.Errorf("expected nil got %v", err)
}
req.Header.Set("content-type", "application/grpc-web+json") req.Header.Set("content-type", "application/grpc-web+json")
resp, err := client.Do(req) resp, err := client.Do(req)
assert.NoError(t, err) if err != nil {
t.Errorf("expected nil got %v", err)
}
resp.Body.Close() resp.Body.Close()
} }
@ -163,15 +173,21 @@ func BenchmarkServer(b *testing.B) {
}() }()
time.Sleep(time.Second) time.Sleep(time.Second)
port, ok := host.Port(srv.lis) port, ok := host.Port(srv.lis)
assert.True(b, ok) if !ok {
b.Errorf("expected port got %v", srv.lis)
}
client, err := NewClient(context.Background(), WithEndpoint(fmt.Sprintf("127.0.0.1:%d", port))) client, err := NewClient(context.Background(), WithEndpoint(fmt.Sprintf("127.0.0.1:%d", port)))
assert.NoError(b, err) if err != nil {
b.Errorf("expected nil got %v", err)
}
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var res testData var res testData
err := client.Invoke(context.Background(), "POST", "/index", nil, &res) err := client.Invoke(context.Background(), "POST", "/index", nil, &res)
assert.NoError(b, err) if err != nil {
b.Errorf("expected nil got %v", err)
}
} }
_ = srv.Stop(ctx) _ = srv.Stop(ctx)
} }
@ -180,21 +196,27 @@ func TestNetwork(t *testing.T) {
o := &Server{} o := &Server{}
v := "abc" v := "abc"
Network(v)(o) Network(v)(o)
assert.Equal(t, v, o.network) if !reflect.DeepEqual(v, o.network) {
t.Errorf("expected %v got %v", v, o.network)
}
} }
func TestAddress(t *testing.T) { func TestAddress(t *testing.T) {
o := &Server{} o := &Server{}
v := "abc" v := "abc"
Address(v)(o) Address(v)(o)
assert.Equal(t, v, o.address) if !reflect.DeepEqual(v, o.address) {
t.Errorf("expected %v got %v", v, o.address)
}
} }
func TestTimeout(t *testing.T) { func TestTimeout(t *testing.T) {
o := &Server{} o := &Server{}
v := time.Duration(123) v := time.Duration(123)
Timeout(v)(o) Timeout(v)(o)
assert.Equal(t, v, o.timeout) if !reflect.DeepEqual(v, o.timeout) {
t.Errorf("expected %v got %v", v, o.timeout)
}
} }
func TestLogger(t *testing.T) { func TestLogger(t *testing.T) {
@ -207,40 +229,52 @@ func TestMiddleware(t *testing.T) {
func(middleware.Handler) middleware.Handler { return nil }, func(middleware.Handler) middleware.Handler { return nil },
} }
Middleware(v...)(o) Middleware(v...)(o)
assert.Equal(t, v, o.ms) if !reflect.DeepEqual(v, o.ms) {
t.Errorf("expected %v got %v", v, o.ms)
}
} }
func TestRequestDecoder(t *testing.T) { func TestRequestDecoder(t *testing.T) {
o := &Server{} o := &Server{}
v := func(*http.Request, interface{}) error { return nil } v := func(*http.Request, interface{}) error { return nil }
RequestDecoder(v)(o) RequestDecoder(v)(o)
assert.NotNil(t, o.dec) if o.dec == nil {
t.Errorf("expected nil got %v", o.dec)
}
} }
func TestResponseEncoder(t *testing.T) { func TestResponseEncoder(t *testing.T) {
o := &Server{} o := &Server{}
v := func(http.ResponseWriter, *http.Request, interface{}) error { return nil } v := func(http.ResponseWriter, *http.Request, interface{}) error { return nil }
ResponseEncoder(v)(o) ResponseEncoder(v)(o)
assert.NotNil(t, o.enc) if o.enc == nil {
t.Errorf("expected nil got %v", o.enc)
}
} }
func TestErrorEncoder(t *testing.T) { func TestErrorEncoder(t *testing.T) {
o := &Server{} o := &Server{}
v := func(http.ResponseWriter, *http.Request, error) {} v := func(http.ResponseWriter, *http.Request, error) {}
ErrorEncoder(v)(o) ErrorEncoder(v)(o)
assert.NotNil(t, o.ene) if o.ene == nil {
t.Errorf("expected nil got %v", o.ene)
}
} }
func TestTLSConfig(t *testing.T) { func TestTLSConfig(t *testing.T) {
o := &Server{} o := &Server{}
v := &tls.Config{} v := &tls.Config{}
TLSConfig(v)(o) TLSConfig(v)(o)
assert.Equal(t, v, o.tlsConf) if !reflect.DeepEqual(v, o.tlsConf) {
t.Errorf("expected %v got %v", v, o.tlsConf)
}
} }
func TestListener(t *testing.T) { func TestListener(t *testing.T) {
lis := &net.TCPListener{} lis := &net.TCPListener{}
s := &Server{} s := &Server{}
Listener(lis)(s) Listener(lis)(s)
assert.Equal(t, s.lis, lis) if !reflect.DeepEqual(s.lis, lis) {
t.Errorf("expected %v got %v", lis, s.lis)
}
} }

@ -3,65 +3,92 @@ package http
import ( import (
"context" "context"
"net/http" "net/http"
"reflect"
"sort"
"testing" "testing"
"github.com/go-kratos/kratos/v2/transport" "github.com/go-kratos/kratos/v2/transport"
"github.com/stretchr/testify/assert"
) )
func TestTransport_Kind(t *testing.T) { func TestTransport_Kind(t *testing.T) {
o := &Transport{} o := &Transport{}
assert.Equal(t, transport.KindHTTP, o.Kind()) if !reflect.DeepEqual(transport.KindHTTP, o.Kind()) {
t.Errorf("expect %v, got %v", transport.KindHTTP, o.Kind())
}
} }
func TestTransport_Endpoint(t *testing.T) { func TestTransport_Endpoint(t *testing.T) {
v := "hello" v := "hello"
o := &Transport{endpoint: v} o := &Transport{endpoint: v}
assert.Equal(t, v, o.Endpoint()) if !reflect.DeepEqual(v, o.Endpoint()) {
t.Errorf("expect %v, got %v", v, o.Endpoint())
}
} }
func TestTransport_Operation(t *testing.T) { func TestTransport_Operation(t *testing.T) {
v := "hello" v := "hello"
o := &Transport{operation: v} o := &Transport{operation: v}
assert.Equal(t, v, o.Operation()) if !reflect.DeepEqual(v, o.Operation()) {
t.Errorf("expect %v, got %v", v, o.Operation())
}
} }
func TestTransport_Request(t *testing.T) { func TestTransport_Request(t *testing.T) {
v := &http.Request{} v := &http.Request{}
o := &Transport{request: v} o := &Transport{request: v}
assert.Same(t, v, o.Request()) if !reflect.DeepEqual(v, o.Request()) {
t.Errorf("expect %v, got %v", v, o.Request())
}
} }
func TestTransport_RequestHeader(t *testing.T) { func TestTransport_RequestHeader(t *testing.T) {
v := headerCarrier{} v := headerCarrier{}
v.Set("a", "1") v.Set("a", "1")
o := &Transport{reqHeader: v} o := &Transport{reqHeader: v}
assert.Equal(t, "1", o.RequestHeader().Get("a")) if !reflect.DeepEqual("1", o.RequestHeader().Get("a")) {
t.Errorf("expect %v, got %v", "1", o.RequestHeader().Get("a"))
}
} }
func TestTransport_ReplyHeader(t *testing.T) { func TestTransport_ReplyHeader(t *testing.T) {
v := headerCarrier{} v := headerCarrier{}
v.Set("a", "1") v.Set("a", "1")
o := &Transport{replyHeader: v} o := &Transport{replyHeader: v}
assert.Equal(t, "1", o.ReplyHeader().Get("a")) if !reflect.DeepEqual("1", o.ReplyHeader().Get("a")) {
t.Errorf("expect %v, got %v", "1", o.ReplyHeader().Get("a"))
}
} }
func TestTransport_PathTemplate(t *testing.T) { func TestTransport_PathTemplate(t *testing.T) {
v := "template" v := "template"
o := &Transport{pathTemplate: v} o := &Transport{pathTemplate: v}
assert.Equal(t, v, o.PathTemplate()) if !reflect.DeepEqual(v, o.PathTemplate()) {
t.Errorf("expect %v, got %v", v, o.PathTemplate())
}
} }
func TestHeaderCarrier_Keys(t *testing.T) { func TestHeaderCarrier_Keys(t *testing.T) {
v := headerCarrier{} v := headerCarrier{}
v.Set("abb", "1") v.Set("abb", "1")
v.Set("bcc", "2") v.Set("bcc", "2")
assert.ElementsMatch(t, []string{"Abb", "Bcc"}, v.Keys()) want := []string{"Abb", "Bcc"}
keys := v.Keys()
sort.Slice(want, func(i, j int) bool {
return want[i] < want[j]
})
sort.Slice(keys, func(i, j int) bool {
return keys[i] < keys[j]
})
if !reflect.DeepEqual(want, keys) {
t.Errorf("expect %v, got %v", want, keys)
}
} }
func TestSetOperation(t *testing.T) { func TestSetOperation(t *testing.T) {
tr := &Transport{} tr := &Transport{}
ctx := transport.NewServerContext(context.Background(), tr) ctx := transport.NewServerContext(context.Background(), tr)
SetOperation(ctx, "kratos") SetOperation(ctx, "kratos")
assert.Equal(t, tr.operation, "kratos") if !reflect.DeepEqual(tr.operation, "kratos") {
t.Errorf("expect %v, got %v", "kratos", tr.operation)
}
} }

@ -2,9 +2,8 @@ package transport
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
// mockTransport is a gRPC transport. // mockTransport is a gRPC transport.
@ -43,13 +42,22 @@ func TestServerTransport(t *testing.T) {
ctx = NewServerContext(ctx, &mockTransport{endpoint: "test_endpoint"}) ctx = NewServerContext(ctx, &mockTransport{endpoint: "test_endpoint"})
tr, ok := FromServerContext(ctx) tr, ok := FromServerContext(ctx)
if !ok {
assert.Equal(t, true, ok) t.Errorf("expected:%v got:%v", true, ok)
assert.NotNil(t, tr) }
if tr == nil {
t.Errorf("expected:%v got:%v", nil, tr)
}
mtr, ok := tr.(*mockTransport) mtr, ok := tr.(*mockTransport)
assert.Equal(t, true, ok) if !ok {
assert.NotNil(t, mtr) t.Errorf("expected:%v got:%v", true, ok)
assert.Equal(t, mtr.endpoint, "test_endpoint") }
if mtr == nil {
t.Errorf("expected:%v got:%v", nil, mtr)
}
if !reflect.DeepEqual(mtr.endpoint, "test_endpoint") {
t.Errorf("expected:%v got:%v", "test_endpoint", mtr.endpoint)
}
} }
func TestClientTransport(t *testing.T) { func TestClientTransport(t *testing.T) {
@ -57,11 +65,20 @@ func TestClientTransport(t *testing.T) {
ctx = NewClientContext(ctx, &mockTransport{endpoint: "test_endpoint"}) ctx = NewClientContext(ctx, &mockTransport{endpoint: "test_endpoint"})
tr, ok := FromClientContext(ctx) tr, ok := FromClientContext(ctx)
if !ok {
assert.Equal(t, true, ok) t.Errorf("expected:%v got:%v", true, ok)
assert.NotNil(t, tr) }
if tr == nil {
t.Errorf("expected:%v got:%v", nil, tr)
}
mtr, ok := tr.(*mockTransport) mtr, ok := tr.(*mockTransport)
assert.Equal(t, true, ok) if !ok {
assert.NotNil(t, mtr) t.Errorf("expected:%v got:%v", true, ok)
assert.Equal(t, mtr.endpoint, "test_endpoint") }
if mtr == nil {
t.Errorf("expected:%v got:%v", nil, mtr)
}
if !reflect.DeepEqual(mtr.endpoint, "test_endpoint") {
t.Errorf("expected:%v got:%v", "test_endpoint", mtr.endpoint)
}
} }

Loading…
Cancel
Save