tests(coverage): Increase middleware tests coverage (#2165)

* tests(coverage): Increase tests coverage

* Lint fix
pull/2175/head
darkweak 2 years ago committed by GitHub
parent dec323113f
commit c9fbb27b5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 79
      middleware/circuitbreaker/circuitbreaker_test.go
  2. 23
      middleware/logging/logging_test.go
  3. 11
      middleware/metadata/metadata_test.go
  4. 149
      middleware/metrics/metrics_test.go
  5. 64
      middleware/ratelimit/ratelimit_test.go
  6. 6
      middleware/recovery/recovery_test.go
  7. 16
      middleware/selector/selector_test.go
  8. 18
      middleware/tracing/metadata_test.go
  9. 91
      middleware/tracing/span_test.go
  10. 2
      middleware/tracing/statsHandler.go
  11. 81
      middleware/tracing/statsHandler_test.go
  12. 54
      middleware/tracing/tracer_test.go
  13. 11
      middleware/tracing/tracing_test.go

@ -0,0 +1,79 @@
package circuitbreaker
import (
"context"
"errors"
"testing"
kratos_errors "github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/internal/group"
"github.com/go-kratos/kratos/v2/transport"
)
type transportMock struct {
kind transport.Kind
endpoint string
operation string
}
type circuitBreakerMock struct {
err error
}
func (tr *transportMock) Kind() transport.Kind {
return tr.kind
}
func (tr *transportMock) Endpoint() string {
return tr.endpoint
}
func (tr *transportMock) Operation() string {
return tr.operation
}
func (tr *transportMock) RequestHeader() transport.Header {
return nil
}
func (tr *transportMock) ReplyHeader() transport.Header {
return nil
}
func (c *circuitBreakerMock) Allow() error { return c.err }
func (c *circuitBreakerMock) MarkSuccess() {}
func (c *circuitBreakerMock) MarkFailed() {}
func Test_WithGroup(t *testing.T) {
o := options{
group: group.NewGroup(func() interface{} {
return ""
}),
}
WithGroup(nil)(&o)
if o.group != nil {
t.Error("The group property must be updated to nil.")
}
}
func Test_Server(t *testing.T) {
nextValid := func(ctx context.Context, req interface{}) (interface{}, error) {
return "Hello valid", nil
}
nextInvalid := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, kratos_errors.InternalServer("", "")
}
ctx := transport.NewClientContext(context.Background(), &transportMock{})
_, _ = Client(func(o *options) {
o.group = group.NewGroup(func() interface{} {
return &circuitBreakerMock{err: errors.New("circuitbreaker error")}
})
})(nextValid)(ctx, nil)
_, _ = Client(func(_ *options) {})(nextValid)(ctx, nil)
_, _ = Client(func(_ *options) {})(nextInvalid)(ctx, nil)
}

@ -97,3 +97,26 @@ func TestHTTP(t *testing.T) {
})
}
}
type (
dummy struct {
field string
}
dummyStringer struct {
field string
}
)
func (d *dummyStringer) String() string {
return "my value"
}
func Test_extractArgs(t *testing.T) {
if extractArgs(&dummyStringer{field: ""}) != "my value" {
t.Errorf(`The stringified dummyStringer structure must be equal to "my value", %v given`, extractArgs(&dummyStringer{field: ""}))
}
if extractArgs(&dummy{field: "value"}) != "&{field:value}" {
t.Errorf(`The stringified dummy structure must be equal to "&{field:value}", %v given`, extractArgs(&dummy{field: "value"}))
}
}

@ -126,3 +126,14 @@ func TestClient(t *testing.T) {
t.Fatalf("want foo got %v", reply)
}
}
func Test_WithPropagatedPrefix(t *testing.T) {
o := &options{
prefix: []string{"override"},
}
WithPropagatedPrefix("something", "another")(o)
if len(o.prefix) != 2 {
t.Error("The prefix must be overrided.")
}
}

@ -2,20 +2,155 @@ package metrics
import (
"context"
"errors"
"testing"
"github.com/go-kratos/kratos/v2/metrics"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/http"
)
func TestMetrics(t *testing.T) {
next := func(ctx context.Context, req interface{}) (interface{}, error) {
return req.(string) + "https://go-kratos.dev", nil
type (
mockCounter struct {
lvs []string
value float64
}
mockObserver struct {
lvs []string
value float64
}
_, err := Server()(next)(context.Background(), "test:")
)
func (m *mockCounter) With(lvs ...string) metrics.Counter {
return m
}
func (m *mockCounter) Inc() {
m.value += 1.0
}
func (m *mockCounter) Add(delta float64) {
m.value += delta
}
func (m *mockObserver) With(lvs ...string) metrics.Observer {
return m
}
func (m *mockObserver) Observe(delta float64) {
m.value += delta
}
func TestWithRequests(t *testing.T) {
mock := mockCounter{
lvs: []string{"Initial"},
value: 1.23,
}
o := options{
requests: &mock,
}
WithRequests(&mock)(&o)
if _, ok := o.requests.(*mockCounter); !ok {
t.Errorf(`The type of the option requests property must be of "mockCounter", %T given.`, o.requests)
}
counter := o.requests.(*mockCounter)
if len(counter.lvs) != 1 || counter.lvs[0] != "Initial" {
t.Errorf(`The given counter lvs must have only one element equal to "Initial", %v given`, counter.lvs)
}
if counter.value != 1.23 {
t.Errorf(`The given counter value must be equal to 1.23, %v given`, counter.value)
}
}
func TestWithSeconds(t *testing.T) {
mock := mockObserver{
lvs: []string{"Initial"},
value: 1.23,
}
o := options{
seconds: &mock,
}
WithSeconds(&mock)(&o)
if _, ok := o.seconds.(*mockObserver); !ok {
t.Errorf(`The type of the option requests property must be of "mockObserver", %T given.`, o.requests)
}
observer := o.seconds.(*mockObserver)
if len(observer.lvs) != 1 || observer.lvs[0] != "Initial" {
t.Errorf(`The given observer lvs must have only one element equal to "Initial", %v given`, observer.lvs)
}
if observer.value != 1.23 {
t.Errorf(`The given observer value must be equal to 1.23, %v given`, observer.value)
}
}
func TestServer(t *testing.T) {
e := errors.New("got an error")
nextError := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, e
}
nextValid := func(ctx context.Context, req interface{}) (interface{}, error) {
return "Hello valid", nil
}
_, err := Server()(nextError)(context.Background(), "test:")
if err != e {
t.Error("The given error mismatch the expected.")
}
res, err := Server(func(o *options) {
o.requests = &mockCounter{
lvs: []string{"Initial"},
value: 1.23,
}
o.seconds = &mockObserver{
lvs: []string{"Initial"},
value: 1.23,
}
})(nextValid)(transport.NewServerContext(context.Background(), &http.Transport{}), "test:")
if err != nil {
t.Errorf("expect %v, got %v", nil, err)
t.Error("The server must not throw an error.")
}
if res != "Hello valid" {
t.Error(`The server must return a "Hello valid" response.`)
}
}
func TestClient(t *testing.T) {
e := errors.New("got an error")
nextError := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, e
}
nextValid := func(ctx context.Context, req interface{}) (interface{}, error) {
return "Hello valid", nil
}
_, err = Client()(next)(context.Background(), "test:")
_, err := Client()(nextError)(context.Background(), "test:")
if err != e {
t.Error("The given error mismatch the expected.")
}
res, err := Client(func(o *options) {
o.requests = &mockCounter{
lvs: []string{"Initial"},
value: 1.23,
}
o.seconds = &mockObserver{
lvs: []string{"Initial"},
value: 1.23,
}
})(nextValid)(transport.NewClientContext(context.Background(), &http.Transport{}), "test:")
if err != nil {
t.Errorf("expect %v, got %v", nil, err)
t.Error("The server must not throw an error.")
}
if res != "Hello valid" {
t.Error(`The server must return a "Hello valid" response.`)
}
}

@ -0,0 +1,64 @@
package ratelimit
import (
"context"
"errors"
"testing"
"github.com/go-kratos/aegis/ratelimit"
)
type (
ratelimitMock struct {
reached bool
}
ratelimitReachedMock struct {
reached bool
}
)
func (r *ratelimitMock) Allow() (ratelimit.DoneFunc, error) {
return func(_ ratelimit.DoneInfo) {
r.reached = true
}, nil
}
func (r *ratelimitReachedMock) Allow() (ratelimit.DoneFunc, error) {
return func(_ ratelimit.DoneInfo) {
r.reached = true
}, errors.New("errored")
}
func Test_WithLimiter(t *testing.T) {
o := options{
limiter: &ratelimitMock{},
}
WithLimiter(nil)(&o)
if o.limiter != nil {
t.Error("The limiter property must be updated.")
}
}
func Test_Server(t *testing.T) {
nextValid := func(ctx context.Context, req interface{}) (interface{}, error) {
return "Hello valid", nil
}
rlm := &ratelimitMock{}
rlrm := &ratelimitReachedMock{}
_, _ = Server(func(o *options) {
o.limiter = rlm
})(nextValid)(context.Background(), nil)
if !rlm.reached {
t.Error("The ratelimit must run the done function.")
}
_, _ = Server(func(o *options) {
o.limiter = rlrm
})(nextValid)(context.Background(), nil)
if rlrm.reached {
t.Error("The ratelimit must not run the done function and should be denied.")
}
}

@ -6,6 +6,7 @@ import (
"testing"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/log"
)
func TestOnce(t *testing.T) {
@ -34,3 +35,8 @@ func TestNotPanic(t *testing.T) {
t.Errorf("e isn't nil")
}
}
// Deprecated: Remove this test with WithLogger method.
func TestWithLogger(t *testing.T) {
_ = WithLogger(log.DefaultLogger)
}

@ -2,7 +2,6 @@ package selector
import (
"context"
"fmt"
"reflect"
"strings"
"testing"
@ -243,9 +242,20 @@ func TestHeaderFunc(t *testing.T) {
func testMiddleware(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
fmt.Println("before")
reply, err = handler(ctx, req)
fmt.Println("after")
return
}
}
func Test_RegexMatch(t *testing.T) {
if regexMatch("^\b(?", "something") {
t.Error("The invalid regex must not match.")
}
}
func Test_matches(t *testing.T) {
b := Builder{}
if b.matches(context.Background(), func(_ context.Context) (transport.Transporter, bool) { return nil, false }) {
t.Error("The matches method must return false.")
}
}

@ -51,9 +51,10 @@ func TestMetadata_Extract(t *testing.T) {
carrier propagation.TextMapCarrier
}
tests := []struct {
name string
args args
want string
name string
args args
want string
crash bool
}{
{
name: "https://go-kratos.dev",
@ -71,6 +72,14 @@ func TestMetadata_Extract(t *testing.T) {
},
want: "https://github.com/go-kratos/kratos",
},
{
name: "https://github.com/go-kratos/kratos",
args: args{
parent: metadata.NewServerContext(context.Background(), metadata.Metadata{}),
carrier: propagation.HeaderCarrier{"X-Md-Service-Name": nil},
},
crash: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -78,6 +87,9 @@ func TestMetadata_Extract(t *testing.T) {
ctx := b.Extract(tt.args.parent, tt.args.carrier)
md, ok := metadata.FromServerContext(ctx)
if !ok {
if tt.crash {
return
}
t.Errorf("expect %v, got %v", true, ok)
}
if !reflect.DeepEqual(md.Get(serviceHeader), tt.want) {

@ -1,11 +1,19 @@
package tracing
import (
"context"
"net"
"net/http"
"reflect"
"testing"
"github.com/go-kratos/kratos/v2/internal/testdata/binding"
"github.com/go-kratos/kratos/v2/metadata"
"github.com/go-kratos/kratos/v2/transport"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/peer"
)
func Test_parseFullMethod(t *testing.T) {
@ -142,6 +150,18 @@ func Test_parseTarget(t *testing.T) {
wantAddress: "hello",
wantErr: false,
},
{
name: "empty",
endpoint: "%%",
wantAddress: "",
wantErr: true,
},
{
name: "invalid path",
endpoint: "//%2F/#%2Fanother",
wantAddress: "",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -156,3 +176,74 @@ func Test_parseTarget(t *testing.T) {
})
}
}
func Test_setServerSpan(t *testing.T) {
ctx := context.Background()
_, span := trace.NewNoopTracerProvider().Tracer("Tracer").Start(ctx, "Spanname")
// Handle without Transport context
setServerSpan(ctx, span, nil)
// Handle with proto message
m := &binding.HelloRequest{}
setServerSpan(ctx, span, m)
// Handle with metadata context
ctx = metadata.NewServerContext(ctx, metadata.New())
setServerSpan(ctx, span, m)
// Handle with KindHTTP transport context
mt := &mockTransport{
kind: transport.KindHTTP,
}
mt.request, _ = http.NewRequest(http.MethodGet, "/endpoint", nil)
ctx = transport.NewServerContext(ctx, mt)
setServerSpan(ctx, span, m)
// Handle with KindGRPC transport context
mt.kind = transport.KindGRPC
ctx = transport.NewServerContext(ctx, mt)
ip, _ := net.ResolveIPAddr("ip", "1.1.1.1")
ctx = peer.NewContext(ctx, &peer.Peer{
Addr: ip,
})
setServerSpan(ctx, span, m)
}
func Test_setClientSpan(t *testing.T) {
ctx := context.Background()
_, span := trace.NewNoopTracerProvider().Tracer("Tracer").Start(ctx, "Spanname")
// Handle without Transport context
setClientSpan(ctx, span, nil)
// Handle with proto message
m := &binding.HelloRequest{}
setClientSpan(ctx, span, m)
// Handle with metadata context
ctx = metadata.NewClientContext(ctx, metadata.New())
setClientSpan(ctx, span, m)
// Handle with KindHTTP transport context
mt := &mockTransport{
kind: transport.KindHTTP,
}
mt.request, _ = http.NewRequest(http.MethodGet, "/endpoint", nil)
mt.request.Host = "MyServer"
ctx = transport.NewClientContext(ctx, mt)
setClientSpan(ctx, span, m)
// Handle with KindGRPC transport context
mt.kind = transport.KindGRPC
ctx = transport.NewClientContext(ctx, mt)
ip, _ := net.ResolveIPAddr("ip", "1.1.1.1")
ctx = peer.NewContext(ctx, &peer.Peer{
Addr: ip,
})
setClientSpan(ctx, span, m)
// Handle without Host request
ctx = transport.NewClientContext(ctx, mt)
setClientSpan(ctx, span, m)
}

@ -2,6 +2,7 @@ package tracing
import (
"context"
"fmt"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/peer"
@ -13,6 +14,7 @@ type ClientHandler struct{}
// HandleConn exists to satisfy gRPC stats.Handler.
func (c *ClientHandler) HandleConn(ctx context.Context, cs stats.ConnStats) {
fmt.Println("Handle connection.")
}
// TagConn exists to satisfy gRPC stats.Handler.

@ -0,0 +1,81 @@
package tracing
import (
"context"
"net"
"testing"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
)
type ctxKey string
const testKey ctxKey = "MY_TEST_KEY"
func Test_Client_HandleConn(t *testing.T) {
(&ClientHandler{}).HandleConn(context.Background(), nil)
}
func Test_Client_TagConn(t *testing.T) {
client := &ClientHandler{}
ctx := context.WithValue(context.Background(), testKey, 123)
if client.TagConn(ctx, nil).Value(testKey) != 123 {
t.Errorf(`The context value must be 123 for the "MY_KEY_TEST" key, %v given.`, client.TagConn(ctx, nil).Value(testKey))
}
}
func Test_Client_TagRPC(t *testing.T) {
client := &ClientHandler{}
ctx := context.WithValue(context.Background(), testKey, 123)
if client.TagRPC(ctx, nil).Value(testKey) != 123 {
t.Errorf(`The context value must be 123 for the "MY_KEY_TEST" key, %v given.`, client.TagConn(ctx, nil).Value(testKey))
}
}
type (
mockSpan struct {
trace.Span
mockSpanCtx *trace.SpanContext
}
)
func (m *mockSpan) SpanContext() trace.SpanContext {
return *m.mockSpanCtx
}
func Test_Client_HandleRPC(t *testing.T) {
client := &ClientHandler{}
ctx := context.Background()
rs := stats.OutHeader{}
// Handle stats.RPCStats is not type of stats.OutHeader case
client.HandleRPC(context.TODO(), nil)
// Handle context doesn't have the peerkey filled with a Peer instance
client.HandleRPC(ctx, &rs)
// Handle context with the peerkey filled with a Peer instance
ip, _ := net.ResolveIPAddr("ip", "1.1.1.1")
ctx = peer.NewContext(ctx, &peer.Peer{
Addr: ip,
})
client.HandleRPC(ctx, &rs)
// Handle context with Span
_, span := trace.NewNoopTracerProvider().Tracer("Tracer").Start(ctx, "Spanname")
spanCtx := trace.SpanContext{}
spanID := [8]byte{12, 12, 12, 12, 12, 12, 12, 12}
traceID := [16]byte{12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12}
spanCtx = spanCtx.WithTraceID(traceID)
spanCtx = spanCtx.WithSpanID(spanID)
mSpan := mockSpan{
Span: span,
mockSpanCtx: &spanCtx,
}
ctx = trace.ContextWithSpan(ctx, &mSpan)
client.HandleRPC(ctx, &rs)
}

@ -0,0 +1,54 @@
package tracing
import (
"context"
"errors"
"testing"
"github.com/go-kratos/kratos/v2/internal/testdata/binding"
"go.opentelemetry.io/otel/trace"
)
func Test_NewTracer(t *testing.T) {
tracer := NewTracer(trace.SpanKindClient, func(o *options) {
o.tracerProvider = trace.NewNoopTracerProvider()
})
if tracer.kind != trace.SpanKindClient {
t.Errorf("The tracer kind must be equal to trace.SpanKindClient, %v given.", tracer.kind)
}
defer func() {
if recover() == nil {
t.Error("The NewTracer with an invalid SpanKindMustCrash must panic")
}
}()
_ = NewTracer(666, func(o *options) {
o.tracerProvider = trace.NewNoopTracerProvider()
})
}
func Test_Tracer_End(t *testing.T) {
tracer := NewTracer(trace.SpanKindClient, func(o *options) {
o.tracerProvider = trace.NewNoopTracerProvider()
})
ctx, span := trace.NewNoopTracerProvider().Tracer("noop").Start(context.Background(), "noopSpan")
// Handle with error case
tracer.End(ctx, span, nil, errors.New("dummy error"))
// Handle without error case
tracer.End(ctx, span, nil, nil)
m := &binding.HelloRequest{}
// Handle the trace KindServer
tracer = NewTracer(trace.SpanKindServer, func(o *options) {
o.tracerProvider = trace.NewNoopTracerProvider()
})
tracer.End(ctx, span, m, nil)
tracer = NewTracer(trace.SpanKindClient, func(o *options) {
o.tracerProvider = trace.NewNoopTracerProvider()
})
tracer.End(ctx, span, m, nil)
}

@ -42,6 +42,7 @@ type mockTransport struct {
endpoint string
operation string
header headerCarrier
request *http.Request
}
func (tr *mockTransport) Kind() transport.Kind { return tr.kind }
@ -49,6 +50,16 @@ func (tr *mockTransport) Endpoint() string { return tr.endpoint }
func (tr *mockTransport) Operation() string { return tr.operation }
func (tr *mockTransport) RequestHeader() transport.Header { return tr.header }
func (tr *mockTransport) ReplyHeader() transport.Header { return tr.header }
func (tr *mockTransport) Request() *http.Request {
if tr.request == nil {
rq, _ := http.NewRequest(http.MethodGet, "/endpoint", nil)
return rq
}
return tr.request
}
func (tr *mockTransport) PathTemplate() string { return "" }
func TestTracer(t *testing.T) {
carrier := headerCarrier{}

Loading…
Cancel
Save