feat(middleware/auth/jwt): add customer header (#1752)

pull/1755/head
Casper-Mars 3 years ago committed by GitHub
parent 1c3185f9e5
commit 76ab0baa56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 40
      middleware/auth/jwt/jwt.go
  2. 28
      middleware/auth/jwt/jwt_test.go

@ -22,21 +22,24 @@ const (
// bearerFormat authorization token format // bearerFormat authorization token format
bearerFormat string = "Bearer %s" bearerFormat string = "Bearer %s"
// authorizationKey holds the key used to store the JWT Token in the request header. // authorizationKey holds the key used to store the JWT Token in the request tokenHeader.
authorizationKey string = "Authorization" authorizationKey string = "Authorization"
// reason holds the error reason.
reason string = "UNAUTHORIZED"
) )
var ( var (
ErrMissingJwtToken = errors.Unauthorized("UNAUTHORIZED", "JWT token is missing") ErrMissingJwtToken = errors.Unauthorized(reason, "JWT token is missing")
ErrMissingKeyFunc = errors.Unauthorized("UNAUTHORIZED", "keyFunc is missing") ErrMissingKeyFunc = errors.Unauthorized(reason, "keyFunc is missing")
ErrTokenInvalid = errors.Unauthorized("UNAUTHORIZED", "Token is invalid") ErrTokenInvalid = errors.Unauthorized(reason, "Token is invalid")
ErrTokenExpired = errors.Unauthorized("UNAUTHORIZED", "JWT token has expired") ErrTokenExpired = errors.Unauthorized(reason, "JWT token has expired")
ErrTokenParseFail = errors.Unauthorized("UNAUTHORIZED", "Fail to parse JWT token ") ErrTokenParseFail = errors.Unauthorized(reason, "Fail to parse JWT token ")
ErrUnSupportSigningMethod = errors.Unauthorized("UNAUTHORIZED", "Wrong signing method") ErrUnSupportSigningMethod = errors.Unauthorized(reason, "Wrong signing method")
ErrWrongContext = errors.Unauthorized("UNAUTHORIZED", "Wrong context for middleware") ErrWrongContext = errors.Unauthorized(reason, "Wrong context for middleware")
ErrNeedTokenProvider = errors.Unauthorized("UNAUTHORIZED", "Token provider is missing") ErrNeedTokenProvider = errors.Unauthorized(reason, "Token provider is missing")
ErrSignToken = errors.Unauthorized("UNAUTHORIZED", "Can not sign token.Is the key correct?") ErrSignToken = errors.Unauthorized(reason, "Can not sign token.Is the key correct?")
ErrGetKey = errors.Unauthorized("UNAUTHORIZED", "Can not get key while signing token") ErrGetKey = errors.Unauthorized(reason, "Can not get key while signing token")
) )
// Option is jwt option. // Option is jwt option.
@ -46,6 +49,7 @@ type Option func(*options)
type options struct { type options struct {
signingMethod jwt.SigningMethod signingMethod jwt.SigningMethod
claims jwt.Claims claims jwt.Claims
tokenHeader map[string]interface{}
} }
// WithSigningMethod with signing method option. // WithSigningMethod with signing method option.
@ -62,6 +66,13 @@ func WithClaims(claims jwt.Claims) Option {
} }
} }
// WithTokenHeader withe customer tokenHeader for client side
func WithTokenHeader(header map[string]interface{}) Option {
return func(o *options) {
o.tokenHeader = header
}
}
// Server is a server auth middleware. Check the token and extract the info from token. // Server is a server auth middleware. Check the token and extract the info from token.
func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware { func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware {
o := &options{ o := &options{
@ -93,7 +104,7 @@ func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware {
return nil, ErrTokenParseFail return nil, ErrTokenParseFail
} }
} }
return nil, errors.Unauthorized("UNAUTHORIZED", err.Error()) return nil, errors.Unauthorized(reason, err.Error())
} else if !tokenInfo.Valid { } else if !tokenInfo.Valid {
return nil, ErrTokenInvalid return nil, ErrTokenInvalid
} else if tokenInfo.Method != o.signingMethod { } else if tokenInfo.Method != o.signingMethod {
@ -122,6 +133,11 @@ func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware {
return nil, ErrNeedTokenProvider return nil, ErrNeedTokenProvider
} }
token := jwt.NewWithClaims(o.signingMethod, o.claims) token := jwt.NewWithClaims(o.signingMethod, o.claims)
if o.tokenHeader != nil {
for k, v := range o.tokenHeader {
token.Header[k] = v
}
}
key, err := keyProvider(token) key, err := keyProvider(token)
if err != nil { if err != nil {
return nil, ErrGetKey return nil, ErrGetKey

@ -294,6 +294,34 @@ func TestClientWithClaims(t *testing.T) {
}) })
} }
func TestClientWithHeader(t *testing.T) {
testKey := "testKey"
mapClaims := jwt.MapClaims{}
mapClaims["name"] = "xiaoli"
tokenHeader := map[string]interface{}{
"test": "test",
}
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims)
for k, v := range tokenHeader {
claims.Header[k] = v
}
token, err := claims.SignedString([]byte(testKey))
if err != nil {
panic(err)
}
tProvider := func(*jwt.Token) (interface{}, error) {
return []byte(testKey), nil
}
next := func(ctx context.Context, req interface{}) (interface{}, error) {
return "reply", nil
}
handler := Client(tProvider, WithClaims(mapClaims), WithTokenHeader(tokenHeader))(next)
header := &headerCarrier{}
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
assert.Equal(t, nil, err2)
assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey))
}
func TestClientMissKey(t *testing.T) { func TestClientMissKey(t *testing.T) {
testKey := "testKey" testKey := "testKey"
mapClaims := jwt.MapClaims{} mapClaims := jwt.MapClaims{}

Loading…
Cancel
Save