fix(jwt): parse server custom claims (#1817)

* fix(jwt): parse server custom claims

* fix(jwt): parse server custom claims & use factory pattern


Co-authored-by: 王真 <zhen.wang@yo-star.com>
pull/1833/head
Zhen Wang 3 years ago committed by GitHub
parent 85800cedb9
commit 4dadafff90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 24
      middleware/auth/jwt/jwt.go
  2. 113
      middleware/auth/jwt/jwt_test.go

@ -48,7 +48,7 @@ type Option func(*options)
// Parser is a jwt parser
type options struct {
signingMethod jwt.SigningMethod
claims jwt.Claims
claims func() jwt.Claims
tokenHeader map[string]interface{}
}
@ -60,9 +60,11 @@ func WithSigningMethod(method jwt.SigningMethod) Option {
}
// WithClaims with customer claim
func WithClaims(claims jwt.Claims) Option {
// If you use it in Server, f needs to return a new jwt.Claims object each time to avoid concurrent write problems
// If you use it in Client, f only needs to return a single object to provide performance
func WithClaims(f func() jwt.Claims) Option {
return func(o *options) {
o.claims = claims
o.claims = f
}
}
@ -77,7 +79,6 @@ func WithTokenHeader(header map[string]interface{}) Option {
func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware {
o := &options{
signingMethod: jwt.SigningMethodHS256,
claims: jwt.RegisteredClaims{},
}
for _, opt := range opts {
opt(o)
@ -93,7 +94,15 @@ func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware {
return nil, ErrMissingJwtToken
}
jwtToken := auths[1]
tokenInfo, err := jwt.Parse(jwtToken, keyFunc)
var (
tokenInfo *jwt.Token
err error
)
if o.claims != nil {
tokenInfo, err = jwt.ParseWithClaims(jwtToken, o.claims(), keyFunc)
} else {
tokenInfo, err = jwt.Parse(jwtToken, keyFunc)
}
if err != nil {
if ve, ok := err.(*jwt.ValidationError); ok {
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
@ -120,9 +129,10 @@ func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware {
// Client is a client jwt middleware.
func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware {
claims := jwt.RegisteredClaims{}
o := &options{
signingMethod: jwt.SigningMethodHS256,
claims: jwt.RegisteredClaims{},
claims: func() jwt.Claims { return claims },
}
for _, opt := range opts {
opt(o)
@ -132,7 +142,7 @@ func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware {
if keyProvider == nil {
return nil, ErrNeedTokenProvider
}
token := jwt.NewWithClaims(o.signingMethod, o.claims)
token := jwt.NewWithClaims(o.signingMethod, o.claims())
if o.tokenHeader != nil {
for k, v := range o.tokenHeader {
token.Header[k] = v

@ -4,8 +4,11 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"net/http"
"reflect"
"strconv"
"sync"
"testing"
"time"
@ -63,6 +66,107 @@ func (tr *Transport) ReplyHeader() transport.Header {
return nil
}
type CustomerClaims struct {
Name string `json:"name"`
jwt.RegisteredClaims
}
func TestJWTServerParse(t *testing.T) {
var (
errConcurrentWrite = errors.New("concurrent write claims")
errParseClaims = errors.New("bad result, token claims is not CustomerClaims")
)
testKey := "testKey"
tests := []struct {
name string
token func() string
claims func() jwt.Claims
exceptErr error
key string
goroutineNum int
}{
{
name: "normal",
token: func() string {
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &CustomerClaims{}).SignedString([]byte(testKey))
if err != nil {
panic(err)
}
return fmt.Sprintf(bearerFormat, token)
},
claims: func() jwt.Claims {
return &CustomerClaims{}
},
exceptErr: nil,
key: testKey,
goroutineNum: 1,
},
{
name: "concurrent request",
token: func() string {
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &CustomerClaims{
Name: strconv.Itoa(rand.Int()),
}).SignedString([]byte(testKey))
if err != nil {
panic(err)
}
return fmt.Sprintf(bearerFormat, token)
},
claims: func() jwt.Claims {
return &CustomerClaims{}
},
exceptErr: nil,
key: testKey,
goroutineNum: 10000,
},
}
next := func(ctx context.Context, req interface{}) (interface{}, error) {
testToken, _ := FromContext(ctx)
var name string
if customerClaims, ok := testToken.(*CustomerClaims); ok {
name = customerClaims.Name
} else {
return nil, errParseClaims
}
// mock biz
time.Sleep(100 * time.Millisecond)
if customerClaims, ok := testToken.(*CustomerClaims); ok {
if name != customerClaims.Name {
return nil, errConcurrentWrite
}
} else {
return nil, errParseClaims
}
return "reply", nil
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := Server(
func(token *jwt.Token) (interface{}, error) { return []byte(testKey), nil },
WithClaims(test.claims),
)(next)
wg := sync.WaitGroup{}
for i := 0; i < test.goroutineNum; i++ {
wg.Add(1)
go func() {
defer wg.Done()
ctx := transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, test.token())})
_, err2 := server(ctx, test.name)
if !errors.Is(test.exceptErr, err2) {
t.Errorf("except error %v, but got %v", test.exceptErr, err2)
}
}()
}
wg.Wait()
})
}
}
func TestServer(t *testing.T) {
testKey := "testKey"
mapClaims := jwt.MapClaims{}
@ -279,6 +383,7 @@ func TestClientWithClaims(t *testing.T) {
testKey := "testKey"
mapClaims := jwt.MapClaims{}
mapClaims["name"] = "xiaoli"
mapClaimsFunc := func() jwt.Claims { return mapClaims }
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims)
token, err := claims.SignedString([]byte(testKey))
if err != nil {
@ -301,7 +406,7 @@ func TestClientWithClaims(t *testing.T) {
next := func(ctx context.Context, req interface{}) (interface{}, error) {
return "reply", nil
}
handler := Client(test.tokenProvider, WithClaims(mapClaims))(next)
handler := Client(test.tokenProvider, WithClaims(mapClaimsFunc))(next)
header := &headerCarrier{}
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
if !errors.Is(test.expectError, err2) {
@ -319,6 +424,7 @@ func TestClientWithHeader(t *testing.T) {
testKey := "testKey"
mapClaims := jwt.MapClaims{}
mapClaims["name"] = "xiaoli"
mapClaimsFunc := func() jwt.Claims { return mapClaims }
tokenHeader := map[string]interface{}{
"test": "test",
}
@ -336,7 +442,7 @@ func TestClientWithHeader(t *testing.T) {
next := func(ctx context.Context, req interface{}) (interface{}, error) {
return "reply", nil
}
handler := Client(tProvider, WithClaims(mapClaims), WithTokenHeader(tokenHeader))(next)
handler := Client(tProvider, WithClaims(mapClaimsFunc), WithTokenHeader(tokenHeader))(next)
header := &headerCarrier{}
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
if err2 != nil {
@ -351,6 +457,7 @@ func TestClientMissKey(t *testing.T) {
testKey := "testKey"
mapClaims := jwt.MapClaims{}
mapClaims["name"] = "xiaoli"
mapClaimsFunc := func() jwt.Claims { return mapClaims }
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims)
token, err := claims.SignedString([]byte(testKey))
if err != nil {
@ -373,7 +480,7 @@ func TestClientMissKey(t *testing.T) {
next := func(ctx context.Context, req interface{}) (interface{}, error) {
return "reply", nil
}
handler := Client(test.tokenProvider, WithClaims(mapClaims))(next)
handler := Client(test.tokenProvider, WithClaims(mapClaimsFunc))(next)
header := &headerCarrier{}
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
if !errors.Is(test.expectError, err2) {

Loading…
Cancel
Save