diff --git a/middleware/circuitbreaker/circuitbreaker_test.go b/middleware/circuitbreaker/circuitbreaker_test.go new file mode 100644 index 000000000..b7c2553ce --- /dev/null +++ b/middleware/circuitbreaker/circuitbreaker_test.go @@ -0,0 +1,79 @@ +package circuitbreaker + +import ( + "context" + "errors" + "testing" + + kratos_errors "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/internal/group" + "github.com/go-kratos/kratos/v2/transport" +) + +type transportMock struct { + kind transport.Kind + endpoint string + operation string +} + +type circuitBreakerMock struct { + err error +} + +func (tr *transportMock) Kind() transport.Kind { + return tr.kind +} + +func (tr *transportMock) Endpoint() string { + return tr.endpoint +} + +func (tr *transportMock) Operation() string { + return tr.operation +} + +func (tr *transportMock) RequestHeader() transport.Header { + return nil +} + +func (tr *transportMock) ReplyHeader() transport.Header { + return nil +} + +func (c *circuitBreakerMock) Allow() error { return c.err } +func (c *circuitBreakerMock) MarkSuccess() {} +func (c *circuitBreakerMock) MarkFailed() {} + +func Test_WithGroup(t *testing.T) { + o := options{ + group: group.NewGroup(func() interface{} { + return "" + }), + } + + WithGroup(nil)(&o) + if o.group != nil { + t.Error("The group property must be updated to nil.") + } +} + +func Test_Server(t *testing.T) { + nextValid := func(ctx context.Context, req interface{}) (interface{}, error) { + return "Hello valid", nil + } + nextInvalid := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, kratos_errors.InternalServer("", "") + } + + ctx := transport.NewClientContext(context.Background(), &transportMock{}) + + _, _ = Client(func(o *options) { + o.group = group.NewGroup(func() interface{} { + return &circuitBreakerMock{err: errors.New("circuitbreaker error")} + }) + })(nextValid)(ctx, nil) + + _, _ = Client(func(_ *options) {})(nextValid)(ctx, nil) + + _, _ = Client(func(_ *options) {})(nextInvalid)(ctx, nil) +} diff --git a/middleware/logging/logging_test.go b/middleware/logging/logging_test.go index 0427d11e8..7242f8e04 100644 --- a/middleware/logging/logging_test.go +++ b/middleware/logging/logging_test.go @@ -97,3 +97,26 @@ func TestHTTP(t *testing.T) { }) } } + +type ( + dummy struct { + field string + } + dummyStringer struct { + field string + } +) + +func (d *dummyStringer) String() string { + return "my value" +} + +func Test_extractArgs(t *testing.T) { + if extractArgs(&dummyStringer{field: ""}) != "my value" { + t.Errorf(`The stringified dummyStringer structure must be equal to "my value", %v given`, extractArgs(&dummyStringer{field: ""})) + } + + if extractArgs(&dummy{field: "value"}) != "&{field:value}" { + t.Errorf(`The stringified dummy structure must be equal to "&{field:value}", %v given`, extractArgs(&dummy{field: "value"})) + } +} diff --git a/middleware/metadata/metadata_test.go b/middleware/metadata/metadata_test.go index c286eedbd..d0190b672 100644 --- a/middleware/metadata/metadata_test.go +++ b/middleware/metadata/metadata_test.go @@ -126,3 +126,14 @@ func TestClient(t *testing.T) { t.Fatalf("want foo got %v", reply) } } + +func Test_WithPropagatedPrefix(t *testing.T) { + o := &options{ + prefix: []string{"override"}, + } + WithPropagatedPrefix("something", "another")(o) + + if len(o.prefix) != 2 { + t.Error("The prefix must be overrided.") + } +} diff --git a/middleware/metrics/metrics_test.go b/middleware/metrics/metrics_test.go index c695f2757..9cf261cee 100644 --- a/middleware/metrics/metrics_test.go +++ b/middleware/metrics/metrics_test.go @@ -2,20 +2,155 @@ package metrics import ( "context" + "errors" "testing" + + "github.com/go-kratos/kratos/v2/metrics" + "github.com/go-kratos/kratos/v2/transport" + "github.com/go-kratos/kratos/v2/transport/http" ) -func TestMetrics(t *testing.T) { - next := func(ctx context.Context, req interface{}) (interface{}, error) { - return req.(string) + "https://go-kratos.dev", nil +type ( + mockCounter struct { + lvs []string + value float64 + } + mockObserver struct { + lvs []string + value float64 } - _, err := Server()(next)(context.Background(), "test:") +) + +func (m *mockCounter) With(lvs ...string) metrics.Counter { + return m +} + +func (m *mockCounter) Inc() { + m.value += 1.0 +} + +func (m *mockCounter) Add(delta float64) { + m.value += delta +} + +func (m *mockObserver) With(lvs ...string) metrics.Observer { + return m +} + +func (m *mockObserver) Observe(delta float64) { + m.value += delta +} + +func TestWithRequests(t *testing.T) { + mock := mockCounter{ + lvs: []string{"Initial"}, + value: 1.23, + } + o := options{ + requests: &mock, + } + + WithRequests(&mock)(&o) + + if _, ok := o.requests.(*mockCounter); !ok { + t.Errorf(`The type of the option requests property must be of "mockCounter", %T given.`, o.requests) + } + + counter := o.requests.(*mockCounter) + + if len(counter.lvs) != 1 || counter.lvs[0] != "Initial" { + t.Errorf(`The given counter lvs must have only one element equal to "Initial", %v given`, counter.lvs) + } + if counter.value != 1.23 { + t.Errorf(`The given counter value must be equal to 1.23, %v given`, counter.value) + } +} + +func TestWithSeconds(t *testing.T) { + mock := mockObserver{ + lvs: []string{"Initial"}, + value: 1.23, + } + o := options{ + seconds: &mock, + } + + WithSeconds(&mock)(&o) + + if _, ok := o.seconds.(*mockObserver); !ok { + t.Errorf(`The type of the option requests property must be of "mockObserver", %T given.`, o.requests) + } + + observer := o.seconds.(*mockObserver) + + if len(observer.lvs) != 1 || observer.lvs[0] != "Initial" { + t.Errorf(`The given observer lvs must have only one element equal to "Initial", %v given`, observer.lvs) + } + if observer.value != 1.23 { + t.Errorf(`The given observer value must be equal to 1.23, %v given`, observer.value) + } +} + +func TestServer(t *testing.T) { + e := errors.New("got an error") + nextError := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, e + } + nextValid := func(ctx context.Context, req interface{}) (interface{}, error) { + return "Hello valid", nil + } + + _, err := Server()(nextError)(context.Background(), "test:") + if err != e { + t.Error("The given error mismatch the expected.") + } + + res, err := Server(func(o *options) { + o.requests = &mockCounter{ + lvs: []string{"Initial"}, + value: 1.23, + } + o.seconds = &mockObserver{ + lvs: []string{"Initial"}, + value: 1.23, + } + })(nextValid)(transport.NewServerContext(context.Background(), &http.Transport{}), "test:") if err != nil { - t.Errorf("expect %v, got %v", nil, err) + t.Error("The server must not throw an error.") + } + if res != "Hello valid" { + t.Error(`The server must return a "Hello valid" response.`) + } +} + +func TestClient(t *testing.T) { + e := errors.New("got an error") + nextError := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, e + } + nextValid := func(ctx context.Context, req interface{}) (interface{}, error) { + return "Hello valid", nil } - _, err = Client()(next)(context.Background(), "test:") + _, err := Client()(nextError)(context.Background(), "test:") + if err != e { + t.Error("The given error mismatch the expected.") + } + + res, err := Client(func(o *options) { + o.requests = &mockCounter{ + lvs: []string{"Initial"}, + value: 1.23, + } + o.seconds = &mockObserver{ + lvs: []string{"Initial"}, + value: 1.23, + } + })(nextValid)(transport.NewClientContext(context.Background(), &http.Transport{}), "test:") if err != nil { - t.Errorf("expect %v, got %v", nil, err) + t.Error("The server must not throw an error.") + } + if res != "Hello valid" { + t.Error(`The server must return a "Hello valid" response.`) } } diff --git a/middleware/ratelimit/ratelimit_test.go b/middleware/ratelimit/ratelimit_test.go new file mode 100644 index 000000000..2e4e5bde3 --- /dev/null +++ b/middleware/ratelimit/ratelimit_test.go @@ -0,0 +1,64 @@ +package ratelimit + +import ( + "context" + "errors" + "testing" + + "github.com/go-kratos/aegis/ratelimit" +) + +type ( + ratelimitMock struct { + reached bool + } + ratelimitReachedMock struct { + reached bool + } +) + +func (r *ratelimitMock) Allow() (ratelimit.DoneFunc, error) { + return func(_ ratelimit.DoneInfo) { + r.reached = true + }, nil +} + +func (r *ratelimitReachedMock) Allow() (ratelimit.DoneFunc, error) { + return func(_ ratelimit.DoneInfo) { + r.reached = true + }, errors.New("errored") +} + +func Test_WithLimiter(t *testing.T) { + o := options{ + limiter: &ratelimitMock{}, + } + + WithLimiter(nil)(&o) + if o.limiter != nil { + t.Error("The limiter property must be updated.") + } +} + +func Test_Server(t *testing.T) { + nextValid := func(ctx context.Context, req interface{}) (interface{}, error) { + return "Hello valid", nil + } + + rlm := &ratelimitMock{} + rlrm := &ratelimitReachedMock{} + + _, _ = Server(func(o *options) { + o.limiter = rlm + })(nextValid)(context.Background(), nil) + if !rlm.reached { + t.Error("The ratelimit must run the done function.") + } + + _, _ = Server(func(o *options) { + o.limiter = rlrm + })(nextValid)(context.Background(), nil) + if rlrm.reached { + t.Error("The ratelimit must not run the done function and should be denied.") + } +} diff --git a/middleware/recovery/recovery_test.go b/middleware/recovery/recovery_test.go index cf4a7fbb3..985cb19b0 100644 --- a/middleware/recovery/recovery_test.go +++ b/middleware/recovery/recovery_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/log" ) func TestOnce(t *testing.T) { @@ -34,3 +35,8 @@ func TestNotPanic(t *testing.T) { t.Errorf("e isn't nil") } } + +// Deprecated: Remove this test with WithLogger method. +func TestWithLogger(t *testing.T) { + _ = WithLogger(log.DefaultLogger) +} diff --git a/middleware/selector/selector_test.go b/middleware/selector/selector_test.go index afde95302..667d2755f 100644 --- a/middleware/selector/selector_test.go +++ b/middleware/selector/selector_test.go @@ -2,7 +2,6 @@ package selector import ( "context" - "fmt" "reflect" "strings" "testing" @@ -243,9 +242,20 @@ func TestHeaderFunc(t *testing.T) { func testMiddleware(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (reply interface{}, err error) { - fmt.Println("before") reply, err = handler(ctx, req) - fmt.Println("after") return } } + +func Test_RegexMatch(t *testing.T) { + if regexMatch("^\b(?", "something") { + t.Error("The invalid regex must not match.") + } +} + +func Test_matches(t *testing.T) { + b := Builder{} + if b.matches(context.Background(), func(_ context.Context) (transport.Transporter, bool) { return nil, false }) { + t.Error("The matches method must return false.") + } +} diff --git a/middleware/tracing/metadata_test.go b/middleware/tracing/metadata_test.go index a4f45ed69..09446c2a8 100644 --- a/middleware/tracing/metadata_test.go +++ b/middleware/tracing/metadata_test.go @@ -51,9 +51,10 @@ func TestMetadata_Extract(t *testing.T) { carrier propagation.TextMapCarrier } tests := []struct { - name string - args args - want string + name string + args args + want string + crash bool }{ { name: "https://go-kratos.dev", @@ -71,6 +72,14 @@ func TestMetadata_Extract(t *testing.T) { }, want: "https://github.com/go-kratos/kratos", }, + { + name: "https://github.com/go-kratos/kratos", + args: args{ + parent: metadata.NewServerContext(context.Background(), metadata.Metadata{}), + carrier: propagation.HeaderCarrier{"X-Md-Service-Name": nil}, + }, + crash: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -78,6 +87,9 @@ func TestMetadata_Extract(t *testing.T) { ctx := b.Extract(tt.args.parent, tt.args.carrier) md, ok := metadata.FromServerContext(ctx) if !ok { + if tt.crash { + return + } t.Errorf("expect %v, got %v", true, ok) } if !reflect.DeepEqual(md.Get(serviceHeader), tt.want) { diff --git a/middleware/tracing/span_test.go b/middleware/tracing/span_test.go index c930154f9..aad4614fc 100644 --- a/middleware/tracing/span_test.go +++ b/middleware/tracing/span_test.go @@ -1,11 +1,19 @@ package tracing import ( + "context" + "net" + "net/http" "reflect" "testing" + "github.com/go-kratos/kratos/v2/internal/testdata/binding" + "github.com/go-kratos/kratos/v2/metadata" + "github.com/go-kratos/kratos/v2/transport" "go.opentelemetry.io/otel/attribute" semconv "go.opentelemetry.io/otel/semconv/v1.4.0" + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc/peer" ) func Test_parseFullMethod(t *testing.T) { @@ -142,6 +150,18 @@ func Test_parseTarget(t *testing.T) { wantAddress: "hello", wantErr: false, }, + { + name: "empty", + endpoint: "%%", + wantAddress: "", + wantErr: true, + }, + { + name: "invalid path", + endpoint: "//%2F/#%2Fanother", + wantAddress: "", + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -156,3 +176,74 @@ func Test_parseTarget(t *testing.T) { }) } } + +func Test_setServerSpan(t *testing.T) { + ctx := context.Background() + _, span := trace.NewNoopTracerProvider().Tracer("Tracer").Start(ctx, "Spanname") + + // Handle without Transport context + setServerSpan(ctx, span, nil) + + // Handle with proto message + m := &binding.HelloRequest{} + setServerSpan(ctx, span, m) + + // Handle with metadata context + ctx = metadata.NewServerContext(ctx, metadata.New()) + setServerSpan(ctx, span, m) + + // Handle with KindHTTP transport context + mt := &mockTransport{ + kind: transport.KindHTTP, + } + mt.request, _ = http.NewRequest(http.MethodGet, "/endpoint", nil) + ctx = transport.NewServerContext(ctx, mt) + setServerSpan(ctx, span, m) + + // Handle with KindGRPC transport context + mt.kind = transport.KindGRPC + ctx = transport.NewServerContext(ctx, mt) + ip, _ := net.ResolveIPAddr("ip", "1.1.1.1") + ctx = peer.NewContext(ctx, &peer.Peer{ + Addr: ip, + }) + setServerSpan(ctx, span, m) +} + +func Test_setClientSpan(t *testing.T) { + ctx := context.Background() + _, span := trace.NewNoopTracerProvider().Tracer("Tracer").Start(ctx, "Spanname") + + // Handle without Transport context + setClientSpan(ctx, span, nil) + + // Handle with proto message + m := &binding.HelloRequest{} + setClientSpan(ctx, span, m) + + // Handle with metadata context + ctx = metadata.NewClientContext(ctx, metadata.New()) + setClientSpan(ctx, span, m) + + // Handle with KindHTTP transport context + mt := &mockTransport{ + kind: transport.KindHTTP, + } + mt.request, _ = http.NewRequest(http.MethodGet, "/endpoint", nil) + mt.request.Host = "MyServer" + ctx = transport.NewClientContext(ctx, mt) + setClientSpan(ctx, span, m) + + // Handle with KindGRPC transport context + mt.kind = transport.KindGRPC + ctx = transport.NewClientContext(ctx, mt) + ip, _ := net.ResolveIPAddr("ip", "1.1.1.1") + ctx = peer.NewContext(ctx, &peer.Peer{ + Addr: ip, + }) + setClientSpan(ctx, span, m) + + // Handle without Host request + ctx = transport.NewClientContext(ctx, mt) + setClientSpan(ctx, span, m) +} diff --git a/middleware/tracing/statsHandler.go b/middleware/tracing/statsHandler.go index 11b12679d..b9ee3fd1e 100644 --- a/middleware/tracing/statsHandler.go +++ b/middleware/tracing/statsHandler.go @@ -2,6 +2,7 @@ package tracing import ( "context" + "fmt" "go.opentelemetry.io/otel/trace" "google.golang.org/grpc/peer" @@ -13,6 +14,7 @@ type ClientHandler struct{} // HandleConn exists to satisfy gRPC stats.Handler. func (c *ClientHandler) HandleConn(ctx context.Context, cs stats.ConnStats) { + fmt.Println("Handle connection.") } // TagConn exists to satisfy gRPC stats.Handler. diff --git a/middleware/tracing/statsHandler_test.go b/middleware/tracing/statsHandler_test.go new file mode 100644 index 000000000..d8a38dcd1 --- /dev/null +++ b/middleware/tracing/statsHandler_test.go @@ -0,0 +1,81 @@ +package tracing + +import ( + "context" + "net" + "testing" + + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/stats" +) + +type ctxKey string + +const testKey ctxKey = "MY_TEST_KEY" + +func Test_Client_HandleConn(t *testing.T) { + (&ClientHandler{}).HandleConn(context.Background(), nil) +} + +func Test_Client_TagConn(t *testing.T) { + client := &ClientHandler{} + ctx := context.WithValue(context.Background(), testKey, 123) + + if client.TagConn(ctx, nil).Value(testKey) != 123 { + t.Errorf(`The context value must be 123 for the "MY_KEY_TEST" key, %v given.`, client.TagConn(ctx, nil).Value(testKey)) + } +} + +func Test_Client_TagRPC(t *testing.T) { + client := &ClientHandler{} + ctx := context.WithValue(context.Background(), testKey, 123) + + if client.TagRPC(ctx, nil).Value(testKey) != 123 { + t.Errorf(`The context value must be 123 for the "MY_KEY_TEST" key, %v given.`, client.TagConn(ctx, nil).Value(testKey)) + } +} + +type ( + mockSpan struct { + trace.Span + mockSpanCtx *trace.SpanContext + } +) + +func (m *mockSpan) SpanContext() trace.SpanContext { + return *m.mockSpanCtx +} + +func Test_Client_HandleRPC(t *testing.T) { + client := &ClientHandler{} + ctx := context.Background() + rs := stats.OutHeader{} + + // Handle stats.RPCStats is not type of stats.OutHeader case + client.HandleRPC(context.TODO(), nil) + + // Handle context doesn't have the peerkey filled with a Peer instance + client.HandleRPC(ctx, &rs) + + // Handle context with the peerkey filled with a Peer instance + ip, _ := net.ResolveIPAddr("ip", "1.1.1.1") + ctx = peer.NewContext(ctx, &peer.Peer{ + Addr: ip, + }) + client.HandleRPC(ctx, &rs) + + // Handle context with Span + _, span := trace.NewNoopTracerProvider().Tracer("Tracer").Start(ctx, "Spanname") + spanCtx := trace.SpanContext{} + spanID := [8]byte{12, 12, 12, 12, 12, 12, 12, 12} + traceID := [16]byte{12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12} + spanCtx = spanCtx.WithTraceID(traceID) + spanCtx = spanCtx.WithSpanID(spanID) + mSpan := mockSpan{ + Span: span, + mockSpanCtx: &spanCtx, + } + ctx = trace.ContextWithSpan(ctx, &mSpan) + client.HandleRPC(ctx, &rs) +} diff --git a/middleware/tracing/tracer_test.go b/middleware/tracing/tracer_test.go new file mode 100644 index 000000000..fbd51c070 --- /dev/null +++ b/middleware/tracing/tracer_test.go @@ -0,0 +1,54 @@ +package tracing + +import ( + "context" + "errors" + "testing" + + "github.com/go-kratos/kratos/v2/internal/testdata/binding" + "go.opentelemetry.io/otel/trace" +) + +func Test_NewTracer(t *testing.T) { + tracer := NewTracer(trace.SpanKindClient, func(o *options) { + o.tracerProvider = trace.NewNoopTracerProvider() + }) + + if tracer.kind != trace.SpanKindClient { + t.Errorf("The tracer kind must be equal to trace.SpanKindClient, %v given.", tracer.kind) + } + + defer func() { + if recover() == nil { + t.Error("The NewTracer with an invalid SpanKindMustCrash must panic") + } + }() + _ = NewTracer(666, func(o *options) { + o.tracerProvider = trace.NewNoopTracerProvider() + }) +} + +func Test_Tracer_End(t *testing.T) { + tracer := NewTracer(trace.SpanKindClient, func(o *options) { + o.tracerProvider = trace.NewNoopTracerProvider() + }) + ctx, span := trace.NewNoopTracerProvider().Tracer("noop").Start(context.Background(), "noopSpan") + + // Handle with error case + tracer.End(ctx, span, nil, errors.New("dummy error")) + + // Handle without error case + tracer.End(ctx, span, nil, nil) + + m := &binding.HelloRequest{} + + // Handle the trace KindServer + tracer = NewTracer(trace.SpanKindServer, func(o *options) { + o.tracerProvider = trace.NewNoopTracerProvider() + }) + tracer.End(ctx, span, m, nil) + tracer = NewTracer(trace.SpanKindClient, func(o *options) { + o.tracerProvider = trace.NewNoopTracerProvider() + }) + tracer.End(ctx, span, m, nil) +} diff --git a/middleware/tracing/tracing_test.go b/middleware/tracing/tracing_test.go index d5fc0ed66..8de2ceea4 100644 --- a/middleware/tracing/tracing_test.go +++ b/middleware/tracing/tracing_test.go @@ -42,6 +42,7 @@ type mockTransport struct { endpoint string operation string header headerCarrier + request *http.Request } func (tr *mockTransport) Kind() transport.Kind { return tr.kind } @@ -49,6 +50,16 @@ func (tr *mockTransport) Endpoint() string { return tr.endpoint } func (tr *mockTransport) Operation() string { return tr.operation } func (tr *mockTransport) RequestHeader() transport.Header { return tr.header } func (tr *mockTransport) ReplyHeader() transport.Header { return tr.header } +func (tr *mockTransport) Request() *http.Request { + if tr.request == nil { + rq, _ := http.NewRequest(http.MethodGet, "/endpoint", nil) + + return rq + } + + return tr.request +} +func (tr *mockTransport) PathTemplate() string { return "" } func TestTracer(t *testing.T) { carrier := headerCarrier{}