You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
kratos/middleware/singleflight/singleflight_test.go

130 lines
2.3 KiB

package singleflight
import (
"context"
"sync"
"testing"
"time"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/grpc"
)
// testVali is one table-driven case shared by the tests in this file.
type testVali struct {
	// in is used as the sub-test name.
	in string
	// out is the value the mock handler returns and the result the test expects.
	out int
}
// Transport2 is a test transport that embeds grpc.Transport and carries its
// own operation name.
type Transport2 struct {
	// Embedded gRPC transport.
	grpc.Transport
	// op is the operation name returned by Operation.
	op string
}
// Operation returns the operation name stored in op, taking precedence over
// any Operation method promoted from the embedded grpc.Transport.
func (tr Transport2) Operation() string {
	name := tr.op
	return name
}
// TestUse exercises the middleware when the "test" operation is registered
// with SingleFlight. Six requests covering three distinct cases are sent
// concurrently, and the test expects the wrapped handler to be invoked only
// three times.
func TestUse(t *testing.T) {
	var mu sync.Mutex
	var callNum int
	// mock counts every invocation that reaches the handler (guarded by mu)
	// and sleeps so that the concurrent requests overlap in time.
	var mock middleware.Handler = func(ctx context.Context, req interface{}) (interface{}, error) {
		in := req.(testVali)
		mu.Lock()
		callNum++
		mu.Unlock()
		time.Sleep(1 * time.Second)
		return in.out, nil
	}
	tests := []testVali{
		{"1", 1},
		{"2", 2},
		{"2", 2},
		{"3", 3},
		{"3", 3},
		{"3", 3},
	}
	var wg sync.WaitGroup
	for _, test := range tests {
		wg.Add(1)
		go func(te testVali) {
			// Signal completion from the goroutine rather than from inside the
			// sub-test body: t.Run does not invoke the closure when the
			// sub-test is filtered out by -run, and wg.Wait below would then
			// block forever.
			defer wg.Done()
			t.Run(te.in, func(t *testing.T) {
				v := SingleFlight("test")(mock) // register the "test" operation
				tr := &Transport2{op: "test"}
				ctx := transport.NewServerContext(context.Background(), tr)
				re, err := v(ctx, te)
				if err != nil {
					t.Error(err)
				}
				if re != te.out {
					t.Errorf("err: %v", te)
				}
			})
		}(test)
	}
	// t.Run returns only after its sub-test has finished, so once Wait returns
	// every handler call has been counted.
	wg.Wait()
	// Finally, check the total number of handler invocations.
	t.Run("callNum", func(t *testing.T) {
		if callNum != 3 {
			t.Errorf("callNum err: %v", callNum)
		}
	})
}
// TestNoUse exercises the middleware when SingleFlight is built without any
// registered operation. Six requests are sent concurrently, and the test
// expects the wrapped handler to be invoked once per request, six times in
// total.
func TestNoUse(t *testing.T) {
	var mu sync.Mutex
	var callNum int
	// mock counts every invocation that reaches the handler (guarded by mu)
	// and sleeps so that the concurrent requests overlap in time.
	var mock middleware.Handler = func(ctx context.Context, req interface{}) (interface{}, error) {
		in := req.(testVali)
		mu.Lock()
		callNum++
		mu.Unlock()
		time.Sleep(1 * time.Second)
		return in.out, nil
	}
	tests := []testVali{
		{"1", 1},
		{"2", 2},
		{"2", 2},
		{"3", 3},
		{"3", 3},
		{"3", 3},
	}
	var wg sync.WaitGroup
	for _, test := range tests {
		wg.Add(1)
		go func(te testVali) {
			// Signal completion from the goroutine rather than from inside the
			// sub-test body: t.Run does not invoke the closure when the
			// sub-test is filtered out by -run, and wg.Wait below would then
			// block forever.
			defer wg.Done()
			t.Run(te.in, func(t *testing.T) {
				v := SingleFlight()(mock) // no operation registered
				tr := &Transport2{op: "test"}
				ctx := transport.NewServerContext(context.Background(), tr)
				re, err := v(ctx, te)
				if err != nil {
					t.Error(err)
				}
				if re != te.out {
					t.Errorf("err: %v", te)
				}
			})
		}(test)
	}
	// t.Run returns only after its sub-test has finished, so once Wait returns
	// every handler call has been counted.
	wg.Wait()
	// Finally, check the total number of handler invocations.
	t.Run("callNum", func(t *testing.T) {
		if callNum != 6 {
			t.Errorf("callNum err: %v", callNum)
		}
	})
}