fix: format code

pull/2639/head
renjun 2 years ago
parent 212ea7f3f7
commit 3433ff76a8
  1. 2
      middleware/singleflight/singleflight.go
  2. 80
      middleware/singleflight/singleflight_test.go

@ -3,9 +3,9 @@ package singleflight
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
"golang.org/x/sync/singleflight" "golang.org/x/sync/singleflight"
) )

@ -2,12 +2,13 @@ package singleflight
import ( import (
"context" "context"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/grpc"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/grpc"
) )
type testVali struct { type testVali struct {
@ -20,13 +21,13 @@ type Transport2 struct {
op string op string
} }
func (t Transport2) Operation()string{ func (t Transport2) Operation() string {
return t.op return t.op
} }
//测试使用单飞时 // 测试使用单飞时
func TestUse(t *testing.T) { func TestUse(t *testing.T) {
var mu sync.Mutex var mu sync.Mutex
var callNum int var callNum int
var mock middleware.Handler = func(ctx context.Context, req interface{}) (interface{}, error) { var mock middleware.Handler = func(ctx context.Context, req interface{}) (interface{}, error) {
@ -34,32 +35,32 @@ func TestUse(t *testing.T) {
mu.Lock() mu.Lock()
callNum++ callNum++
mu.Unlock() mu.Unlock()
time.Sleep(1*time.Second) time.Sleep(1 * time.Second)
return in.out, nil return in.out, nil
} }
tests := []testVali{ tests := []testVali{
{"1",1}, {"1", 1},
{"2",2}, {"2", 2},
{"2",2}, {"2", 2},
{"3",3}, {"3", 3},
{"3",3}, {"3", 3},
{"3",3}, {"3", 3},
} }
var wg sync.WaitGroup var wg sync.WaitGroup
for _, test := range tests { for _, test := range tests {
wg.Add(1) wg.Add(1)
go func(te testVali) { go func(te testVali) {
t.Run(te.in, func(t *testing.T) { t.Run(te.in, func(t *testing.T) {
v := SingleFlight("test")(mock)//注册 v := SingleFlight("test")(mock) //注册
tr := &Transport2{op:"test"} tr := &Transport2{op: "test"}
ctx:=transport.NewServerContext(context.Background(),tr) ctx := transport.NewServerContext(context.Background(), tr)
re, err := v(ctx, te) re, err := v(ctx, te)
if err!=nil{ if err != nil {
t.Error(err) t.Error(err)
} }
if re!=te.out{ if re != te.out {
t.Errorf("err: %v",te) t.Errorf("err: %v", te)
} }
wg.Done() wg.Done()
}) })
@ -70,16 +71,15 @@ func TestUse(t *testing.T) {
//最后计算总调用次数 //最后计算总调用次数
t.Run("callNum", func(t *testing.T) { t.Run("callNum", func(t *testing.T) {
if callNum!=3{ if callNum != 3 {
t.Errorf("callNum err: %v",callNum) t.Errorf("callNum err: %v", callNum)
} }
}) })
} }
// 测试不使用单飞时
//测试不使用单飞时
func TestNoUse(t *testing.T) { func TestNoUse(t *testing.T) {
var mu sync.Mutex var mu sync.Mutex
var callNum int var callNum int
var mock middleware.Handler = func(ctx context.Context, req interface{}) (interface{}, error) { var mock middleware.Handler = func(ctx context.Context, req interface{}) (interface{}, error) {
@ -87,32 +87,32 @@ func TestNoUse(t *testing.T) {
mu.Lock() mu.Lock()
callNum++ callNum++
mu.Unlock() mu.Unlock()
time.Sleep(1*time.Second) time.Sleep(1 * time.Second)
return in.out, nil return in.out, nil
} }
tests := []testVali{ tests := []testVali{
{"1",1}, {"1", 1},
{"2",2}, {"2", 2},
{"2",2}, {"2", 2},
{"3",3}, {"3", 3},
{"3",3}, {"3", 3},
{"3",3}, {"3", 3},
} }
var wg sync.WaitGroup var wg sync.WaitGroup
for _, test := range tests { for _, test := range tests {
wg.Add(1) wg.Add(1)
go func(te testVali) { go func(te testVali) {
t.Run(te.in, func(t *testing.T) { t.Run(te.in, func(t *testing.T) {
v := SingleFlight()(mock)//移除注册 v := SingleFlight()(mock) //移除注册
tr := &Transport2{op:"test"} tr := &Transport2{op: "test"}
ctx:=transport.NewServerContext(context.Background(),tr) ctx := transport.NewServerContext(context.Background(), tr)
re, err := v(ctx, te) re, err := v(ctx, te)
if err!=nil{ if err != nil {
t.Error(err) t.Error(err)
} }
if re!=te.out{ if re != te.out {
t.Errorf("err: %v",te) t.Errorf("err: %v", te)
} }
wg.Done() wg.Done()
}) })
@ -123,8 +123,8 @@ func TestNoUse(t *testing.T) {
//最后计算总调用次数 //最后计算总调用次数
t.Run("callNum", func(t *testing.T) { t.Run("callNum", func(t *testing.T) {
if callNum!=6{ if callNum != 6 {
t.Errorf("callNum err: %v",callNum) t.Errorf("callNum err: %v", callNum)
} }
}) })
} }

Loading…
Cancel
Save