diff --git a/internal/context/context_test.go b/internal/context/context_test.go new file mode 100644 index 000000000..563c8c694 --- /dev/null +++ b/internal/context/context_test.go @@ -0,0 +1,98 @@ +package context + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestContext(t *testing.T) { + ctx1 := context.WithValue(context.Background(), "go-kratos", "https://github.com/go-kratos/") + ctx2 := context.WithValue(context.Background(), "kratos", "https://go-kratos.dev/") + + ctx, cancel := Merge(ctx1, ctx2) + defer cancel() + + got := ctx.Value("go-kratos") + value1, ok := got.(string) + assert.Equal(t, ok, true) + assert.Equal(t, value1, "https://github.com/go-kratos/") + // + got2 := ctx.Value("kratos") + value2, ok := got2.(string) + assert.Equal(t, ok, true) + assert.Equal(t, value2, "https://go-kratos.dev/") + + t.Log(value1) + t.Log(value2) +} + +func TestErr(t *testing.T) { + ctx1, cancel := context.WithTimeout(context.Background(), time.Microsecond) + defer cancel() + time.Sleep(time.Millisecond) + + ctx, cancel := Merge(ctx1, context.Background()) + defer cancel() + + assert.Equal(t, ctx.Err(), context.DeadlineExceeded) +} + +func TestDone(t *testing.T) { + ctx1, cancel := context.WithCancel(context.Background()) + defer cancel() + + ctx, cancel := Merge(ctx1, context.Background()) + go func() { + time.Sleep(time.Millisecond * 50) + cancel() + }() + + assert.Equal(t, <-ctx.Done(), struct{}{}) +} + +func TestFinish(t *testing.T) { + mc := &mergeCtx{ + parent1: context.Background(), + parent2: context.Background(), + done: make(chan struct{}), + cancelCh: make(chan struct{}), + } + err := mc.finish(context.DeadlineExceeded) + assert.Equal(t, err, context.DeadlineExceeded) + assert.Equal(t, mc.doneMark, uint32(1)) + assert.Equal(t, <-mc.done, struct{}{}) +} + +func TestWait(t *testing.T) { + ctx1, cancel := context.WithCancel(context.Background()) + + mc := &mergeCtx{ + parent1: ctx1, + parent2: context.Background(), + done: make(chan struct{}), + cancelCh: make(chan struct{}), + } + go func() { + time.Sleep(time.Millisecond * 50) + cancel() + }() + + mc.wait() + t.Log(mc.doneErr) + assert.Equal(t, mc.doneErr, context.Canceled) +} + +func TestCancel(t *testing.T) { + mc := &mergeCtx{ + parent1: context.Background(), + parent2: context.Background(), + done: make(chan struct{}), + cancelCh: make(chan struct{}), + } + mc.cancel() + + assert.Equal(t, <-mc.cancelCh, struct{}{}) +} diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go index 61a2ef36f..55439b16f 100644 --- a/metadata/metadata_test.go +++ b/metadata/metadata_test.go @@ -238,3 +238,48 @@ func TestMergeToClientContext(t *testing.T) { }) } } + +func TestMetadata_Range(t *testing.T) { + md := Metadata{"kratos": "kratos", "https://go-kratos.dev/": "https://go-kratos.dev/", "go-kratos": "go-kratos"} + var tmp = Metadata{} + md.Range(func(k, v string) bool { + if k == "https://go-kratos.dev/" || k == "kratos" { + tmp[k] = v + } + return true + }) + if !reflect.DeepEqual(tmp, Metadata{"https://go-kratos.dev/": "https://go-kratos.dev/", "kratos": "kratos"}) { + t.Errorf("metadata = %v, want %v", tmp, Metadata{"kratos": "kratos"}) + } +} + +func TestMetadata_Clone(t *testing.T) { + tests := []struct { + name string + m Metadata + want Metadata + }{ + { + name: "kratos", + m: Metadata{"kratos": "kratos", "https://go-kratos.dev/": "https://go-kratos.dev/", "go-kratos": "go-kratos"}, + want: Metadata{"kratos": "kratos", "https://go-kratos.dev/": "https://go-kratos.dev/", "go-kratos": "go-kratos"}, + }, + { + name: "go", + m: Metadata{"language": "golang"}, + want: Metadata{"language": "golang"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.m.Clone() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Clone() = %v, want %v", got, tt.want) + } + got["kratos"] = "go" + if reflect.DeepEqual(got, tt.want) { + t.Errorf("want got != want got %v want %v", got, tt.want) + } + }) + } +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 000000000..ca5ff4c5a --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,54 @@ +package middleware + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +var i int + +func TestChain(t *testing.T) { + next := func(ctx context.Context, req interface{}) (interface{}, error) { + t.Log(req) + i += 10 + return "reply", nil + } + + got, err := Chain(test1Middleware, test2Middleware, test3Middleware)(next)(context.Background(), "hello kratos!") + assert.Nil(t, err) + assert.Equal(t, got, "reply") + assert.Equal(t, i, 16) +} + +func test1Middleware(handler Handler) Handler { + return func(ctx context.Context, req interface{}) (reply interface{}, err error) { + fmt.Println("test1 before") + i++ + reply, err = handler(ctx, req) + fmt.Println("test1 after") + return + } +} + +func test2Middleware(handler Handler) Handler { + return func(ctx context.Context, req interface{}) (reply interface{}, err error) { + fmt.Println("test2 before") + i += 2 + reply, err = handler(ctx, req) + fmt.Println("test2 after") + return + } +} + +func test3Middleware(handler Handler) Handler { + return func(ctx context.Context, req interface{}) (reply interface{}, err error) { + fmt.Println("test3 before") + i += 3 + reply, err = handler(ctx, req) + fmt.Println("test3 after") + return + } +}