|
|
|
@ -4,8 +4,11 @@ import ( |
|
|
|
|
"context" |
|
|
|
|
"errors" |
|
|
|
|
"fmt" |
|
|
|
|
"math/rand" |
|
|
|
|
"net/http" |
|
|
|
|
"reflect" |
|
|
|
|
"strconv" |
|
|
|
|
"sync" |
|
|
|
|
"testing" |
|
|
|
|
"time" |
|
|
|
|
|
|
|
|
@ -63,6 +66,107 @@ func (tr *Transport) ReplyHeader() transport.Header { |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type CustomerClaims struct { |
|
|
|
|
Name string `json:"name"` |
|
|
|
|
jwt.RegisteredClaims |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestJWTServerParse(t *testing.T) { |
|
|
|
|
var ( |
|
|
|
|
errConcurrentWrite = errors.New("concurrent write claims") |
|
|
|
|
errParseClaims = errors.New("bad result, token claims is not CustomerClaims") |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
testKey := "testKey" |
|
|
|
|
tests := []struct { |
|
|
|
|
name string |
|
|
|
|
token func() string |
|
|
|
|
claims func() jwt.Claims |
|
|
|
|
exceptErr error |
|
|
|
|
key string |
|
|
|
|
goroutineNum int |
|
|
|
|
}{ |
|
|
|
|
{ |
|
|
|
|
name: "normal", |
|
|
|
|
token: func() string { |
|
|
|
|
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &CustomerClaims{}).SignedString([]byte(testKey)) |
|
|
|
|
if err != nil { |
|
|
|
|
panic(err) |
|
|
|
|
} |
|
|
|
|
return fmt.Sprintf(bearerFormat, token) |
|
|
|
|
}, |
|
|
|
|
claims: func() jwt.Claims { |
|
|
|
|
return &CustomerClaims{} |
|
|
|
|
}, |
|
|
|
|
exceptErr: nil, |
|
|
|
|
key: testKey, |
|
|
|
|
goroutineNum: 1, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "concurrent request", |
|
|
|
|
token: func() string { |
|
|
|
|
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &CustomerClaims{ |
|
|
|
|
Name: strconv.Itoa(rand.Int()), |
|
|
|
|
}).SignedString([]byte(testKey)) |
|
|
|
|
if err != nil { |
|
|
|
|
panic(err) |
|
|
|
|
} |
|
|
|
|
return fmt.Sprintf(bearerFormat, token) |
|
|
|
|
}, |
|
|
|
|
claims: func() jwt.Claims { |
|
|
|
|
return &CustomerClaims{} |
|
|
|
|
}, |
|
|
|
|
exceptErr: nil, |
|
|
|
|
key: testKey, |
|
|
|
|
goroutineNum: 10000, |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
|
|
|
|
testToken, _ := FromContext(ctx) |
|
|
|
|
var name string |
|
|
|
|
if customerClaims, ok := testToken.(*CustomerClaims); ok { |
|
|
|
|
name = customerClaims.Name |
|
|
|
|
} else { |
|
|
|
|
return nil, errParseClaims |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// mock biz
|
|
|
|
|
time.Sleep(100 * time.Millisecond) |
|
|
|
|
|
|
|
|
|
if customerClaims, ok := testToken.(*CustomerClaims); ok { |
|
|
|
|
if name != customerClaims.Name { |
|
|
|
|
return nil, errConcurrentWrite |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
return nil, errParseClaims |
|
|
|
|
} |
|
|
|
|
return "reply", nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for _, test := range tests { |
|
|
|
|
t.Run(test.name, func(t *testing.T) { |
|
|
|
|
server := Server( |
|
|
|
|
func(token *jwt.Token) (interface{}, error) { return []byte(testKey), nil }, |
|
|
|
|
WithClaims(test.claims), |
|
|
|
|
)(next) |
|
|
|
|
wg := sync.WaitGroup{} |
|
|
|
|
for i := 0; i < test.goroutineNum; i++ { |
|
|
|
|
wg.Add(1) |
|
|
|
|
go func() { |
|
|
|
|
defer wg.Done() |
|
|
|
|
ctx := transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, test.token())}) |
|
|
|
|
_, err2 := server(ctx, test.name) |
|
|
|
|
if !errors.Is(test.exceptErr, err2) { |
|
|
|
|
t.Errorf("except error %v, but got %v", test.exceptErr, err2) |
|
|
|
|
} |
|
|
|
|
}() |
|
|
|
|
} |
|
|
|
|
wg.Wait() |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestServer(t *testing.T) { |
|
|
|
|
testKey := "testKey" |
|
|
|
|
mapClaims := jwt.MapClaims{} |
|
|
|
@ -279,6 +383,7 @@ func TestClientWithClaims(t *testing.T) { |
|
|
|
|
testKey := "testKey" |
|
|
|
|
mapClaims := jwt.MapClaims{} |
|
|
|
|
mapClaims["name"] = "xiaoli" |
|
|
|
|
mapClaimsFunc := func() jwt.Claims { return mapClaims } |
|
|
|
|
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) |
|
|
|
|
token, err := claims.SignedString([]byte(testKey)) |
|
|
|
|
if err != nil { |
|
|
|
@ -301,7 +406,7 @@ func TestClientWithClaims(t *testing.T) { |
|
|
|
|
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
|
|
|
|
return "reply", nil |
|
|
|
|
} |
|
|
|
|
handler := Client(test.tokenProvider, WithClaims(mapClaims))(next) |
|
|
|
|
handler := Client(test.tokenProvider, WithClaims(mapClaimsFunc))(next) |
|
|
|
|
header := &headerCarrier{} |
|
|
|
|
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") |
|
|
|
|
if !errors.Is(test.expectError, err2) { |
|
|
|
@ -319,6 +424,7 @@ func TestClientWithHeader(t *testing.T) { |
|
|
|
|
testKey := "testKey" |
|
|
|
|
mapClaims := jwt.MapClaims{} |
|
|
|
|
mapClaims["name"] = "xiaoli" |
|
|
|
|
mapClaimsFunc := func() jwt.Claims { return mapClaims } |
|
|
|
|
tokenHeader := map[string]interface{}{ |
|
|
|
|
"test": "test", |
|
|
|
|
} |
|
|
|
@ -336,7 +442,7 @@ func TestClientWithHeader(t *testing.T) { |
|
|
|
|
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
|
|
|
|
return "reply", nil |
|
|
|
|
} |
|
|
|
|
handler := Client(tProvider, WithClaims(mapClaims), WithTokenHeader(tokenHeader))(next) |
|
|
|
|
handler := Client(tProvider, WithClaims(mapClaimsFunc), WithTokenHeader(tokenHeader))(next) |
|
|
|
|
header := &headerCarrier{} |
|
|
|
|
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") |
|
|
|
|
if err2 != nil { |
|
|
|
@ -351,6 +457,7 @@ func TestClientMissKey(t *testing.T) { |
|
|
|
|
testKey := "testKey" |
|
|
|
|
mapClaims := jwt.MapClaims{} |
|
|
|
|
mapClaims["name"] = "xiaoli" |
|
|
|
|
mapClaimsFunc := func() jwt.Claims { return mapClaims } |
|
|
|
|
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) |
|
|
|
|
token, err := claims.SignedString([]byte(testKey)) |
|
|
|
|
if err != nil { |
|
|
|
@ -373,7 +480,7 @@ func TestClientMissKey(t *testing.T) { |
|
|
|
|
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
|
|
|
|
return "reply", nil |
|
|
|
|
} |
|
|
|
|
handler := Client(test.tokenProvider, WithClaims(mapClaims))(next) |
|
|
|
|
handler := Client(test.tokenProvider, WithClaims(mapClaimsFunc))(next) |
|
|
|
|
header := &headerCarrier{} |
|
|
|
|
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") |
|
|
|
|
if !errors.Is(test.expectError, err2) { |
|
|
|
|