From 4dadafff90258bf7ecae616cb1453f0c130e53d0 Mon Sep 17 00:00:00 2001 From: Zhen Wang <37008932+CryBecase@users.noreply.github.com> Date: Tue, 22 Feb 2022 14:06:36 +0800 Subject: [PATCH] fix(jwt): parse server custom claims (#1817) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(jwt): parse server custom claims * fix(jwt): parse server custom claims & use factory pattern Co-authored-by: 王真 --- middleware/auth/jwt/jwt.go | 24 +++++-- middleware/auth/jwt/jwt_test.go | 113 +++++++++++++++++++++++++++++++- 2 files changed, 127 insertions(+), 10 deletions(-) diff --git a/middleware/auth/jwt/jwt.go b/middleware/auth/jwt/jwt.go index c7a7f78f4..60dad2db2 100644 --- a/middleware/auth/jwt/jwt.go +++ b/middleware/auth/jwt/jwt.go @@ -48,7 +48,7 @@ type Option func(*options) // Parser is a jwt parser type options struct { signingMethod jwt.SigningMethod - claims jwt.Claims + claims func() jwt.Claims tokenHeader map[string]interface{} } @@ -60,9 +60,11 @@ func WithSigningMethod(method jwt.SigningMethod) Option { } // WithClaims with customer claim -func WithClaims(claims jwt.Claims) Option { +// If you use it in Server, f needs to return a new jwt.Claims object each time to avoid concurrent write problems +// If you use it in Client, f only needs to return a single object to provide performance +func WithClaims(f func() jwt.Claims) Option { return func(o *options) { - o.claims = claims + o.claims = f } } @@ -77,7 +79,6 @@ func WithTokenHeader(header map[string]interface{}) Option { func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware { o := &options{ signingMethod: jwt.SigningMethodHS256, - claims: jwt.RegisteredClaims{}, } for _, opt := range opts { opt(o) @@ -93,7 +94,15 @@ func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware { return nil, ErrMissingJwtToken } jwtToken := auths[1] - tokenInfo, err := jwt.Parse(jwtToken, keyFunc) + var ( + tokenInfo *jwt.Token + err error + ) + if o.claims != nil { + tokenInfo, err = jwt.ParseWithClaims(jwtToken, o.claims(), keyFunc) + } else { + tokenInfo, err = jwt.Parse(jwtToken, keyFunc) + } if err != nil { if ve, ok := err.(*jwt.ValidationError); ok { if ve.Errors&jwt.ValidationErrorMalformed != 0 { @@ -120,9 +129,10 @@ func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware { // Client is a client jwt middleware. func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware { + claims := jwt.RegisteredClaims{} o := &options{ signingMethod: jwt.SigningMethodHS256, - claims: jwt.RegisteredClaims{}, + claims: func() jwt.Claims { return claims }, } for _, opt := range opts { opt(o) @@ -132,7 +142,7 @@ func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware { if keyProvider == nil { return nil, ErrNeedTokenProvider } - token := jwt.NewWithClaims(o.signingMethod, o.claims) + token := jwt.NewWithClaims(o.signingMethod, o.claims()) if o.tokenHeader != nil { for k, v := range o.tokenHeader { token.Header[k] = v diff --git a/middleware/auth/jwt/jwt_test.go b/middleware/auth/jwt/jwt_test.go index 7944d503a..188744b01 100644 --- a/middleware/auth/jwt/jwt_test.go +++ b/middleware/auth/jwt/jwt_test.go @@ -4,8 +4,11 @@ import ( "context" "errors" "fmt" + "math/rand" "net/http" "reflect" + "strconv" + "sync" "testing" "time" @@ -63,6 +66,107 @@ func (tr *Transport) ReplyHeader() transport.Header { return nil } +type CustomerClaims struct { + Name string `json:"name"` + jwt.RegisteredClaims +} + +func TestJWTServerParse(t *testing.T) { + var ( + errConcurrentWrite = errors.New("concurrent write claims") + errParseClaims = errors.New("bad result, token claims is not CustomerClaims") + ) + + testKey := "testKey" + tests := []struct { + name string + token func() string + claims func() jwt.Claims + exceptErr error + key string + goroutineNum int + }{ + { + name: "normal", + token: func() string { + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &CustomerClaims{}).SignedString([]byte(testKey)) + if err != nil { + panic(err) + } + return fmt.Sprintf(bearerFormat, token) + }, + claims: func() jwt.Claims { + return &CustomerClaims{} + }, + exceptErr: nil, + key: testKey, + goroutineNum: 1, + }, + { + name: "concurrent request", + token: func() string { + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &CustomerClaims{ + Name: strconv.Itoa(rand.Int()), + }).SignedString([]byte(testKey)) + if err != nil { + panic(err) + } + return fmt.Sprintf(bearerFormat, token) + }, + claims: func() jwt.Claims { + return &CustomerClaims{} + }, + exceptErr: nil, + key: testKey, + goroutineNum: 10000, + }, + } + + next := func(ctx context.Context, req interface{}) (interface{}, error) { + testToken, _ := FromContext(ctx) + var name string + if customerClaims, ok := testToken.(*CustomerClaims); ok { + name = customerClaims.Name + } else { + return nil, errParseClaims + } + + // mock biz + time.Sleep(100 * time.Millisecond) + + if customerClaims, ok := testToken.(*CustomerClaims); ok { + if name != customerClaims.Name { + return nil, errConcurrentWrite + } + } else { + return nil, errParseClaims + } + return "reply", nil + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + server := Server( + func(token *jwt.Token) (interface{}, error) { return []byte(testKey), nil }, + WithClaims(test.claims), + )(next) + wg := sync.WaitGroup{} + for i := 0; i < test.goroutineNum; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx := transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, test.token())}) + _, err2 := server(ctx, test.name) + if !errors.Is(test.exceptErr, err2) { + t.Errorf("except error %v, but got %v", test.exceptErr, err2) + } + }() + } + wg.Wait() + }) + } +} + func TestServer(t *testing.T) { testKey := "testKey" mapClaims := jwt.MapClaims{} @@ -279,6 +383,7 @@ func TestClientWithClaims(t *testing.T) { testKey := "testKey" mapClaims := jwt.MapClaims{} mapClaims["name"] = "xiaoli" + mapClaimsFunc := func() jwt.Claims { return mapClaims } claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) token, err := claims.SignedString([]byte(testKey)) if err != nil { @@ -301,7 +406,7 @@ func TestClientWithClaims(t *testing.T) { next := func(ctx context.Context, req interface{}) (interface{}, error) { return "reply", nil } - handler := Client(test.tokenProvider, WithClaims(mapClaims))(next) + handler := Client(test.tokenProvider, WithClaims(mapClaimsFunc))(next) header := &headerCarrier{} _, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") if !errors.Is(test.expectError, err2) { @@ -319,6 +424,7 @@ func TestClientWithHeader(t *testing.T) { testKey := "testKey" mapClaims := jwt.MapClaims{} mapClaims["name"] = "xiaoli" + mapClaimsFunc := func() jwt.Claims { return mapClaims } tokenHeader := map[string]interface{}{ "test": "test", } @@ -336,7 +442,7 @@ func TestClientWithHeader(t *testing.T) { next := func(ctx context.Context, req interface{}) (interface{}, error) { return "reply", nil } - handler := Client(tProvider, WithClaims(mapClaims), WithTokenHeader(tokenHeader))(next) + handler := Client(tProvider, WithClaims(mapClaimsFunc), WithTokenHeader(tokenHeader))(next) header := &headerCarrier{} _, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") if err2 != nil { @@ -351,6 +457,7 @@ func TestClientMissKey(t *testing.T) { testKey := "testKey" mapClaims := jwt.MapClaims{} mapClaims["name"] = "xiaoli" + mapClaimsFunc := func() jwt.Claims { return mapClaims } claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) token, err := claims.SignedString([]byte(testKey)) if err != nil { @@ -373,7 +480,7 @@ func TestClientMissKey(t *testing.T) { next := func(ctx context.Context, req interface{}) (interface{}, error) { return "reply", nil } - handler := Client(test.tokenProvider, WithClaims(mapClaims))(next) + handler := Client(test.tokenProvider, WithClaims(mapClaimsFunc))(next) header := &headerCarrier{} _, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") if !errors.Is(test.expectError, err2) {