feat[middleware]: grpc server 单飞

基于singleflight库,这个库的主要作用就是将一组相同的请求合并成一个请求,实际上只会去请求一次,然后对所有的请求返回相同的结果。
pull/2639/head
renjun 2 years ago
parent b242403bc1
commit 8988d1444f
  1. 46
      middleware/singleflight/singleflight.go
  2. 132
      middleware/singleflight/singleflight_test.go

@ -0,0 +1,46 @@
package singleflight
import (
"context"
"fmt"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/middleware"
"golang.org/x/sync/singleflight"
)
var singleflightGroup singleflight.Group
// 单飞 middleware.
/*
//只在grpc服务端下使用,通过传入op名称,使用单飞:
singleflight.SingleFlight(
"/service.test1/GetCityName",
"/service.test2/GetAllCityName",
)
*/
func SingleFlight(ops... string) middleware.Middleware {
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
if tr, ok := transport.FromServerContext(ctx); ok {
if Contains(ops,tr.Operation()){
cacheKey:=fmt.Sprintf("%s %s",tr.Operation(),req)
reply, err, _ = singleflightGroup.Do(cacheKey, func() (interface{}, error) {
return handler(ctx, req)
})
return reply,err
}
}
return handler(ctx, req)
}
}
}
func Contains(elems []string, elem string) bool {
for _, e := range elems {
if elem == e {
return true
}
}
return false
}

@ -0,0 +1,132 @@
package singleflight
import (
"context"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/grpc"
"sync"
"testing"
"time"
"xiaozhu/pkg/convert"
"github.com/go-kratos/kratos/v2/middleware"
)
type testVali struct {
in string
out int
}
type Transport2 struct {
grpc.Transport
op string
}
func (t Transport2) Operation()string{
return t.op
}
//测试使用单飞时
func TestUse(t *testing.T) {
var mu sync.Mutex
var callNum int
var mock middleware.Handler = func(ctx context.Context, req interface{}) (interface{}, error) {
in := req.(testVali)
mu.Lock()
callNum++
mu.Unlock()
time.Sleep(1*time.Second)
return in.out, nil
}
tests := []testVali{
{"1",1},
{"2",2},
{"2",2},
{"3",3},
{"3",3},
{"3",3},
}
var wg sync.WaitGroup
for _, test := range tests {
wg.Add(1)
go func(te testVali) {
t.Run(te.in, func(t *testing.T) {
v := SingleFlight("test")(mock)//注册
tr := &Transport2{op:"test"}
ctx:=transport.NewServerContext(context.Background(),tr)
re, err := v(ctx, te)
if err!=nil{
t.Error(err)
}
if re!=te.out{
t.Errorf("err: %v",te)
}
wg.Done()
})
}(test)
}
wg.Wait()
//最后计算总调用次数
t.Run("callNum", func(t *testing.T) {
if callNum!=3{
t.Errorf("callNum err: %v",callNum)
}
})
}
//测试不使用单飞时
func TestNoUse(t *testing.T) {
var mu sync.Mutex
var callNum int
var mock middleware.Handler = func(ctx context.Context, req interface{}) (interface{}, error) {
in := req.(testVali)
mu.Lock()
callNum++
mu.Unlock()
time.Sleep(1*time.Second)
return convert.ToInt(in.in), nil
}
tests := []testVali{
{"1",1},
{"2",2},
{"2",2},
{"3",3},
{"3",3},
{"3",3},
}
var wg sync.WaitGroup
for _, test := range tests {
wg.Add(1)
go func(te testVali) {
t.Run(te.in, func(t *testing.T) {
v := SingleFlight()(mock)//移除注册
tr := &Transport2{op:"test"}
ctx:=transport.NewServerContext(context.Background(),tr)
re, err := v(ctx, te)
if err!=nil{
t.Error(err)
}
if re!=te.out{
t.Errorf("err: %v",te)
}
wg.Done()
})
}(test)
}
wg.Wait()
//最后计算总调用次数
t.Run("callNum", func(t *testing.T) {
if callNum!=6{
t.Errorf("callNum err: %v",callNum)
}
})
}
Loading…
Cancel
Save