diff --git a/middleware/singleflight/singleflight.go b/middleware/singleflight/singleflight.go new file mode 100644 index 000000000..8b7b9d370 --- /dev/null +++ b/middleware/singleflight/singleflight.go @@ -0,0 +1,46 @@ +package singleflight + +import ( + "context" + "fmt" + "github.com/go-kratos/kratos/v2/transport" + + "github.com/go-kratos/kratos/v2/middleware" + "golang.org/x/sync/singleflight" +) + +var singleflightGroup singleflight.Group + +// 单飞 middleware. +/* + //只在grpc服务端下使用,通过传入op名称,使用单飞: + singleflight.SingleFlight( + "/service.test1/GetCityName", + "/service.test2/GetAllCityName", + ) +*/ +func SingleFlight(ops... string) middleware.Middleware { + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (reply interface{}, err error) { + if tr, ok := transport.FromServerContext(ctx); ok { + if Contains(ops,tr.Operation()){ + cacheKey:=fmt.Sprintf("%s %s",tr.Operation(),req) + reply, err, _ = singleflightGroup.Do(cacheKey, func() (interface{}, error) { + return handler(ctx, req) + }) + return reply,err + } + } + return handler(ctx, req) + } + } +} + +func Contains(elems []string, elem string) bool { + for _, e := range elems { + if elem == e { + return true + } + } + return false +} \ No newline at end of file diff --git a/middleware/singleflight/singleflight_test.go b/middleware/singleflight/singleflight_test.go new file mode 100644 index 000000000..146f34f54 --- /dev/null +++ b/middleware/singleflight/singleflight_test.go @@ -0,0 +1,132 @@ +package singleflight + +import ( + "context" + "github.com/go-kratos/kratos/v2/transport" + "github.com/go-kratos/kratos/v2/transport/grpc" + "sync" + "testing" + "time" + "xiaozhu/pkg/convert" + + "github.com/go-kratos/kratos/v2/middleware" +) + +type testVali struct { + in string + out int +} + +type Transport2 struct { + grpc.Transport + op string +} + +func (t Transport2) Operation()string{ + return t.op +} + +//测试使用单飞时 +func TestUse(t *testing.T) { + var mu sync.Mutex + var callNum int + + var mock middleware.Handler = func(ctx context.Context, req interface{}) (interface{}, error) { + in := req.(testVali) + mu.Lock() + callNum++ + mu.Unlock() + time.Sleep(1*time.Second) + return in.out, nil + } + + tests := []testVali{ + {"1",1}, + {"2",2}, + {"2",2}, + {"3",3}, + {"3",3}, + {"3",3}, + } + var wg sync.WaitGroup + for _, test := range tests { + wg.Add(1) + go func(te testVali) { + t.Run(te.in, func(t *testing.T) { + v := SingleFlight("test")(mock)//注册 + tr := &Transport2{op:"test"} + ctx:=transport.NewServerContext(context.Background(),tr) + re, err := v(ctx, te) + if err!=nil{ + t.Error(err) + } + if re!=te.out{ + t.Errorf("err: %v",te) + } + wg.Done() + }) + }(test) + } + + wg.Wait() + + //最后计算总调用次数 + t.Run("callNum", func(t *testing.T) { + if callNum!=3{ + t.Errorf("callNum err: %v",callNum) + } + }) +} + + +//测试不使用单飞时 +func TestNoUse(t *testing.T) { + var mu sync.Mutex + var callNum int + + var mock middleware.Handler = func(ctx context.Context, req interface{}) (interface{}, error) { + in := req.(testVali) + mu.Lock() + callNum++ + mu.Unlock() + time.Sleep(1*time.Second) + return convert.ToInt(in.in), nil + } + + tests := []testVali{ + {"1",1}, + {"2",2}, + {"2",2}, + {"3",3}, + {"3",3}, + {"3",3}, + } + var wg sync.WaitGroup + for _, test := range tests { + wg.Add(1) + go func(te testVali) { + t.Run(te.in, func(t *testing.T) { + v := SingleFlight()(mock)//移除注册 + tr := &Transport2{op:"test"} + ctx:=transport.NewServerContext(context.Background(),tr) + re, err := v(ctx, te) + if err!=nil{ + t.Error(err) + } + if re!=te.out{ + t.Errorf("err: %v",te) + } + wg.Done() + }) + }(test) + } + + wg.Wait() + + //最后计算总调用次数 + t.Run("callNum", func(t *testing.T) { + if callNum!=6{ + t.Errorf("callNum err: %v",callNum) + } + }) +} \ No newline at end of file