From 76ab0baa561fc0e812e4a6b38af63271eb307791 Mon Sep 17 00:00:00 2001 From: Casper-Mars <50834595+Casper-Mars@users.noreply.github.com> Date: Fri, 7 Jan 2022 22:15:47 +0800 Subject: [PATCH] feat(middleware/auth/jwt): add customer header (#1752) --- middleware/auth/jwt/jwt.go | 40 +++++++++++++++++++++++---------- middleware/auth/jwt/jwt_test.go | 28 +++++++++++++++++++++++ 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/middleware/auth/jwt/jwt.go b/middleware/auth/jwt/jwt.go index 5c77dc227..59d670ceb 100644 --- a/middleware/auth/jwt/jwt.go +++ b/middleware/auth/jwt/jwt.go @@ -22,21 +22,24 @@ const ( // bearerFormat authorization token format bearerFormat string = "Bearer %s" - // authorizationKey holds the key used to store the JWT Token in the request header. + // authorizationKey holds the key used to store the JWT Token in the request tokenHeader. authorizationKey string = "Authorization" + + // reason holds the error reason. + reason string = "UNAUTHORIZED" ) var ( - ErrMissingJwtToken = errors.Unauthorized("UNAUTHORIZED", "JWT token is missing") - ErrMissingKeyFunc = errors.Unauthorized("UNAUTHORIZED", "keyFunc is missing") - ErrTokenInvalid = errors.Unauthorized("UNAUTHORIZED", "Token is invalid") - ErrTokenExpired = errors.Unauthorized("UNAUTHORIZED", "JWT token has expired") - ErrTokenParseFail = errors.Unauthorized("UNAUTHORIZED", "Fail to parse JWT token ") - ErrUnSupportSigningMethod = errors.Unauthorized("UNAUTHORIZED", "Wrong signing method") - ErrWrongContext = errors.Unauthorized("UNAUTHORIZED", "Wrong context for middleware") - ErrNeedTokenProvider = errors.Unauthorized("UNAUTHORIZED", "Token provider is missing") - ErrSignToken = errors.Unauthorized("UNAUTHORIZED", "Can not sign token.Is the key correct?") - ErrGetKey = errors.Unauthorized("UNAUTHORIZED", "Can not get key while signing token") + ErrMissingJwtToken = errors.Unauthorized(reason, "JWT token is missing") + ErrMissingKeyFunc = errors.Unauthorized(reason, "keyFunc is missing") + ErrTokenInvalid = errors.Unauthorized(reason, "Token is invalid") + ErrTokenExpired = errors.Unauthorized(reason, "JWT token has expired") + ErrTokenParseFail = errors.Unauthorized(reason, "Fail to parse JWT token ") + ErrUnSupportSigningMethod = errors.Unauthorized(reason, "Wrong signing method") + ErrWrongContext = errors.Unauthorized(reason, "Wrong context for middleware") + ErrNeedTokenProvider = errors.Unauthorized(reason, "Token provider is missing") + ErrSignToken = errors.Unauthorized(reason, "Can not sign token.Is the key correct?") + ErrGetKey = errors.Unauthorized(reason, "Can not get key while signing token") ) // Option is jwt option. @@ -46,6 +49,7 @@ type Option func(*options) type options struct { signingMethod jwt.SigningMethod claims jwt.Claims + tokenHeader map[string]interface{} } // WithSigningMethod with signing method option. @@ -62,6 +66,13 @@ func WithClaims(claims jwt.Claims) Option { } } +// WithTokenHeader withe customer tokenHeader for client side +func WithTokenHeader(header map[string]interface{}) Option { + return func(o *options) { + o.tokenHeader = header + } +} + // Server is a server auth middleware. Check the token and extract the info from token. func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware { o := &options{ @@ -93,7 +104,7 @@ func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware { return nil, ErrTokenParseFail } } - return nil, errors.Unauthorized("UNAUTHORIZED", err.Error()) + return nil, errors.Unauthorized(reason, err.Error()) } else if !tokenInfo.Valid { return nil, ErrTokenInvalid } else if tokenInfo.Method != o.signingMethod { @@ -122,6 +133,11 @@ func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware { return nil, ErrNeedTokenProvider } token := jwt.NewWithClaims(o.signingMethod, o.claims) + if o.tokenHeader != nil { + for k, v := range o.tokenHeader { + token.Header[k] = v + } + } key, err := keyProvider(token) if err != nil { return nil, ErrGetKey diff --git a/middleware/auth/jwt/jwt_test.go b/middleware/auth/jwt/jwt_test.go index 627b585e2..3f02cacc6 100644 --- a/middleware/auth/jwt/jwt_test.go +++ b/middleware/auth/jwt/jwt_test.go @@ -294,6 +294,34 @@ func TestClientWithClaims(t *testing.T) { }) } +func TestClientWithHeader(t *testing.T) { + testKey := "testKey" + mapClaims := jwt.MapClaims{} + mapClaims["name"] = "xiaoli" + tokenHeader := map[string]interface{}{ + "test": "test", + } + claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) + for k, v := range tokenHeader { + claims.Header[k] = v + } + token, err := claims.SignedString([]byte(testKey)) + if err != nil { + panic(err) + } + tProvider := func(*jwt.Token) (interface{}, error) { + return []byte(testKey), nil + } + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return "reply", nil + } + handler := Client(tProvider, WithClaims(mapClaims), WithTokenHeader(tokenHeader))(next) + header := &headerCarrier{} + _, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") + assert.Equal(t, nil, err2) + assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) +} + func TestClientMissKey(t *testing.T) { testKey := "testKey" mapClaims := jwt.MapClaims{}