parent
aed6af7acc
commit
ab5152dbe1
@ -0,0 +1,68 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"log" |
||||
|
||||
"github.com/go-kratos/kratos/examples/helloworld/helloworld" |
||||
"github.com/go-kratos/kratos/v2" |
||||
"github.com/go-kratos/kratos/v2/middleware/auth/jwt" |
||||
"github.com/go-kratos/kratos/v2/transport/grpc" |
||||
"github.com/go-kratos/kratos/v2/transport/http" |
||||
jwtv4 "github.com/golang-jwt/jwt/v4" |
||||
) |
||||
|
||||
type server struct { |
||||
helloworld.UnimplementedGreeterServer |
||||
|
||||
hc helloworld.GreeterClient |
||||
} |
||||
|
||||
func (s *server) SayHello(ctx context.Context, in *helloworld.HelloRequest) (*helloworld.HelloReply, error) { |
||||
return &helloworld.HelloReply{Message: "hello from service"}, nil |
||||
} |
||||
|
||||
func main() { |
||||
testKey := "testKey" |
||||
httpSrv := http.NewServer( |
||||
http.Address(":8000"), |
||||
http.Middleware( |
||||
jwt.Server(func(token *jwtv4.Token) (interface{}, error) { |
||||
return []byte(testKey), nil |
||||
}), |
||||
), |
||||
) |
||||
grpcSrv := grpc.NewServer( |
||||
grpc.Address(":9000"), |
||||
grpc.Middleware( |
||||
jwt.Server(func(token *jwtv4.Token) (interface{}, error) { |
||||
return []byte(testKey), nil |
||||
}), |
||||
), |
||||
) |
||||
serviceTestKey := "serviceTestKey" |
||||
con, _ := grpc.DialInsecure( |
||||
context.Background(), |
||||
grpc.WithEndpoint("dns:///127.0.0.1:9001"), |
||||
grpc.WithMiddleware( |
||||
jwt.Client(func(token *jwtv4.Token) (interface{}, error) { |
||||
return []byte(serviceTestKey), nil |
||||
}), |
||||
), |
||||
) |
||||
s := &server{ |
||||
hc: helloworld.NewGreeterClient(con), |
||||
} |
||||
helloworld.RegisterGreeterServer(grpcSrv, s) |
||||
helloworld.RegisterGreeterHTTPServer(httpSrv, s) |
||||
app := kratos.New( |
||||
kratos.Name("helloworld"), |
||||
kratos.Server( |
||||
httpSrv, |
||||
grpcSrv, |
||||
), |
||||
) |
||||
if err := app.Run(); err != nil { |
||||
log.Fatal(err) |
||||
} |
||||
} |
@ -0,0 +1,150 @@ |
||||
package jwt |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"strings" |
||||
|
||||
"github.com/golang-jwt/jwt/v4" |
||||
|
||||
"github.com/go-kratos/kratos/v2/errors" |
||||
"github.com/go-kratos/kratos/v2/middleware" |
||||
"github.com/go-kratos/kratos/v2/transport" |
||||
) |
||||
|
||||
type authKey struct{} |
||||
|
||||
const ( |
||||
|
||||
// bearerWord the bearer key word for authorization
|
||||
bearerWord string = "Bearer" |
||||
|
||||
// bearerFormat authorization token format
|
||||
bearerFormat string = "Bearer %s" |
||||
|
||||
// authorizationKey holds the key used to store the JWT Token in the request header.
|
||||
authorizationKey string = "Authorization" |
||||
) |
||||
|
||||
var ( |
||||
ErrMissingJwtToken = errors.Unauthorized("UNAUTHORIZED", "JWT token is missing") |
||||
ErrMissingKeyFunc = errors.Unauthorized("UNAUTHORIZED", "keyFunc is missing") |
||||
ErrTokenInvalid = errors.Unauthorized("UNAUTHORIZED", "Token is invalid") |
||||
ErrTokenExpired = errors.Unauthorized("UNAUTHORIZED", "JWT token has expired") |
||||
ErrTokenParseFail = errors.Unauthorized("UNAUTHORIZED", "Fail to parse JWT token ") |
||||
ErrUnSupportSigningMethod = errors.Unauthorized("UNAUTHORIZED", "Wrong signing method") |
||||
ErrWrongContext = errors.Unauthorized("UNAUTHORIZED", "Wrong context for middelware") |
||||
ErrNeedTokenProvider = errors.Unauthorized("UNAUTHORIZED", "Token provider is missing") |
||||
ErrSignToken = errors.Unauthorized("UNAUTHORIZED", "Can not sign token.Is the key correct?") |
||||
ErrGetKey = errors.Unauthorized("UNAUTHORIZED", "Can not get key while signing token") |
||||
) |
||||
|
||||
// Option is jwt option.
|
||||
type Option func(*options) |
||||
|
||||
// Parser is a jwt parser
|
||||
type options struct { |
||||
signingMethod jwt.SigningMethod |
||||
claims jwt.Claims |
||||
} |
||||
|
||||
// WithSigningMethod with signing method option.
|
||||
func WithSigningMethod(method jwt.SigningMethod) Option { |
||||
return func(o *options) { |
||||
o.signingMethod = method |
||||
} |
||||
} |
||||
|
||||
// WithClaims with customer claim
|
||||
func WithClaims(claims jwt.Claims) Option { |
||||
return func(o *options) { |
||||
o.claims = claims |
||||
} |
||||
} |
||||
|
||||
// Server is a server auth middleware. Check the token and extract the info from token.
|
||||
func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware { |
||||
o := &options{ |
||||
signingMethod: jwt.SigningMethodHS256, |
||||
claims: jwt.StandardClaims{}, |
||||
} |
||||
for _, opt := range opts { |
||||
opt(o) |
||||
} |
||||
return func(handler middleware.Handler) middleware.Handler { |
||||
return func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
if header, ok := transport.FromServerContext(ctx); ok { |
||||
if keyFunc == nil { |
||||
return nil, ErrMissingKeyFunc |
||||
} |
||||
auths := strings.SplitN(header.RequestHeader().Get(authorizationKey), " ", 2) |
||||
if len(auths) != 2 || !strings.EqualFold(auths[0], bearerWord) { |
||||
return nil, ErrMissingJwtToken |
||||
} |
||||
jwtToken := auths[1] |
||||
tokenInfo, err := jwt.Parse(jwtToken, keyFunc) |
||||
if err != nil { |
||||
if ve, ok := err.(*jwt.ValidationError); ok { |
||||
if ve.Errors&jwt.ValidationErrorMalformed != 0 { |
||||
return nil, ErrTokenInvalid |
||||
} else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 { |
||||
return nil, ErrTokenExpired |
||||
} else { |
||||
return nil, ErrTokenParseFail |
||||
} |
||||
} |
||||
} else if !tokenInfo.Valid { |
||||
return nil, ErrTokenInvalid |
||||
} else if tokenInfo.Method != o.signingMethod { |
||||
return nil, ErrUnSupportSigningMethod |
||||
} |
||||
ctx = NewContext(ctx, tokenInfo.Claims) |
||||
return handler(ctx, req) |
||||
} |
||||
return nil, ErrWrongContext |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Client is a client jwt middleware.
|
||||
func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware { |
||||
o := &options{ |
||||
signingMethod: jwt.SigningMethodHS256, |
||||
claims: jwt.StandardClaims{}, |
||||
} |
||||
for _, opt := range opts { |
||||
opt(o) |
||||
} |
||||
return func(handler middleware.Handler) middleware.Handler { |
||||
return func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
if keyProvider == nil { |
||||
return nil, ErrNeedTokenProvider |
||||
} |
||||
token := jwt.NewWithClaims(o.signingMethod, o.claims) |
||||
key, err := keyProvider(token) |
||||
if err != nil { |
||||
return nil, ErrGetKey |
||||
} |
||||
tokenStr, err := token.SignedString(key) |
||||
if err != nil { |
||||
return nil, ErrSignToken |
||||
} |
||||
if clientContext, ok := transport.FromClientContext(ctx); ok { |
||||
clientContext.RequestHeader().Set(authorizationKey, fmt.Sprintf(bearerFormat, tokenStr)) |
||||
return handler(ctx, req) |
||||
} |
||||
return nil, ErrWrongContext |
||||
} |
||||
} |
||||
} |
||||
|
||||
// NewContext put auth info into context
|
||||
func NewContext(ctx context.Context, info jwt.Claims) context.Context { |
||||
return context.WithValue(ctx, authKey{}, info) |
||||
} |
||||
|
||||
// FromContext extract auth info from context
|
||||
func FromContext(ctx context.Context) (token jwt.Claims, ok bool) { |
||||
token, ok = ctx.Value(authKey{}).(jwt.Claims) |
||||
return |
||||
} |
@ -0,0 +1,331 @@ |
||||
package jwt |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"fmt" |
||||
"net/http" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/golang-jwt/jwt/v4" |
||||
|
||||
"github.com/go-kratos/kratos/v2/middleware" |
||||
"github.com/go-kratos/kratos/v2/transport" |
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
type headerCarrier http.Header |
||||
|
||||
func (hc headerCarrier) Get(key string) string { return http.Header(hc).Get(key) } |
||||
|
||||
func (hc headerCarrier) Set(key string, value string) { http.Header(hc).Set(key, value) } |
||||
|
||||
// Keys lists the keys stored in this carrier.
|
||||
func (hc headerCarrier) Keys() []string { |
||||
keys := make([]string, 0, len(hc)) |
||||
for k := range http.Header(hc) { |
||||
keys = append(keys, k) |
||||
} |
||||
return keys |
||||
} |
||||
|
||||
func newTokenHeader(headerKey string, token string) *headerCarrier { |
||||
header := &headerCarrier{} |
||||
header.Set(headerKey, token) |
||||
return header |
||||
} |
||||
|
||||
type Transport struct { |
||||
kind transport.Kind |
||||
endpoint string |
||||
operation string |
||||
reqHeader transport.Header |
||||
} |
||||
|
||||
func (tr *Transport) Kind() transport.Kind { |
||||
return tr.kind |
||||
} |
||||
|
||||
func (tr *Transport) Endpoint() string { |
||||
return tr.endpoint |
||||
} |
||||
|
||||
func (tr *Transport) Operation() string { |
||||
return tr.operation |
||||
} |
||||
|
||||
func (tr *Transport) RequestHeader() transport.Header { |
||||
return tr.reqHeader |
||||
} |
||||
|
||||
func (tr *Transport) ReplyHeader() transport.Header { |
||||
return nil |
||||
} |
||||
|
||||
func TestServer(t *testing.T) { |
||||
testKey := "testKey" |
||||
mapClaims := jwt.MapClaims{} |
||||
mapClaims["name"] = "xiaoli" |
||||
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) |
||||
token, err := claims.SignedString([]byte(testKey)) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
token = fmt.Sprintf(bearerFormat, token) |
||||
tests := []struct { |
||||
name string |
||||
ctx context.Context |
||||
signingMethod jwt.SigningMethod |
||||
exceptErr error |
||||
key string |
||||
}{ |
||||
{ |
||||
name: "normal", |
||||
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), |
||||
signingMethod: jwt.SigningMethodHS256, |
||||
exceptErr: nil, |
||||
key: testKey, |
||||
}, |
||||
{ |
||||
name: "miss token", |
||||
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: headerCarrier{}}), |
||||
signingMethod: jwt.SigningMethodHS256, |
||||
exceptErr: ErrMissingJwtToken, |
||||
key: testKey, |
||||
}, |
||||
{ |
||||
name: "token invalid", |
||||
ctx: transport.NewServerContext(context.Background(), &Transport{ |
||||
reqHeader: newTokenHeader(authorizationKey, fmt.Sprintf(bearerFormat, "12313123")), |
||||
}), |
||||
signingMethod: jwt.SigningMethodHS256, |
||||
exceptErr: ErrTokenInvalid, |
||||
key: testKey, |
||||
}, |
||||
{ |
||||
name: "method invalid", |
||||
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), |
||||
signingMethod: jwt.SigningMethodES384, |
||||
exceptErr: ErrUnSupportSigningMethod, |
||||
key: testKey, |
||||
}, |
||||
{ |
||||
name: "miss signing method", |
||||
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), |
||||
signingMethod: nil, |
||||
exceptErr: nil, |
||||
key: testKey, |
||||
}, |
||||
{ |
||||
name: "miss signing method", |
||||
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), |
||||
signingMethod: nil, |
||||
exceptErr: nil, |
||||
key: testKey, |
||||
}, |
||||
} |
||||
|
||||
for _, test := range tests { |
||||
t.Run(test.name, func(t *testing.T) { |
||||
var testToken jwt.Claims |
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
t.Log(req) |
||||
testToken, _ = FromContext(ctx) |
||||
return "reply", nil |
||||
} |
||||
var server middleware.Handler |
||||
if test.signingMethod != nil { |
||||
server = Server(func(token *jwt.Token) (interface{}, error) { |
||||
return []byte(test.key), nil |
||||
}, WithSigningMethod(test.signingMethod))(next) |
||||
} else { |
||||
server = Server(func(token *jwt.Token) (interface{}, error) { |
||||
return []byte(test.key), nil |
||||
})(next) |
||||
} |
||||
_, err2 := server(test.ctx, test.name) |
||||
assert.Equal(t, test.exceptErr, err2) |
||||
if test.exceptErr == nil { |
||||
assert.NotNil(t, testToken) |
||||
_, ok := testToken.(jwt.MapClaims) |
||||
assert.True(t, ok) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestClient(t *testing.T) { |
||||
testKey := "testKey" |
||||
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{}) |
||||
token, err := claims.SignedString([]byte(testKey)) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
tProvider := func(*jwt.Token) (interface{}, error) { |
||||
return []byte(testKey), nil |
||||
} |
||||
tests := []struct { |
||||
name string |
||||
expectError error |
||||
tokenProvider jwt.Keyfunc |
||||
}{ |
||||
{ |
||||
name: "normal", |
||||
expectError: nil, |
||||
tokenProvider: tProvider, |
||||
}, |
||||
{ |
||||
name: "miss token provider", |
||||
expectError: ErrNeedTokenProvider, |
||||
tokenProvider: nil, |
||||
}, |
||||
} |
||||
for _, test := range tests { |
||||
t.Run(test.name, func(t *testing.T) { |
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
return "reply", nil |
||||
} |
||||
handler := Client(test.tokenProvider)(next) |
||||
header := &headerCarrier{} |
||||
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") |
||||
assert.Equal(t, test.expectError, err2) |
||||
if err2 == nil { |
||||
assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestTokenExpire(t *testing.T) { |
||||
testKey := "testKey" |
||||
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ |
||||
ExpiresAt: time.Now().Add(time.Millisecond).Unix(), |
||||
}) |
||||
token, err := claims.SignedString([]byte(testKey)) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
token = fmt.Sprintf(bearerFormat, token) |
||||
time.Sleep(time.Second) |
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
t.Log(req) |
||||
return "reply", nil |
||||
} |
||||
ctx := transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}) |
||||
server := Server(func(token *jwt.Token) (interface{}, error) { |
||||
return []byte(testKey), nil |
||||
}, WithSigningMethod(jwt.SigningMethodHS256))(next) |
||||
_, err2 := server(ctx, "test expire token") |
||||
assert.Equal(t, ErrTokenExpired, err2) |
||||
} |
||||
|
||||
func TestMissingKeyFunc(t *testing.T) { |
||||
testKey := "testKey" |
||||
mapClaims := jwt.MapClaims{} |
||||
mapClaims["name"] = "xiaoli" |
||||
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) |
||||
token, err := claims.SignedString([]byte(testKey)) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
token = fmt.Sprintf(bearerFormat, token) |
||||
test := struct { |
||||
name string |
||||
ctx context.Context |
||||
signingMethod jwt.SigningMethod |
||||
exceptErr error |
||||
key string |
||||
}{ |
||||
name: "miss key", |
||||
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), |
||||
signingMethod: jwt.SigningMethodHS256, |
||||
exceptErr: ErrMissingKeyFunc, |
||||
key: "", |
||||
} |
||||
|
||||
var testToken jwt.Claims |
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
t.Log(req) |
||||
testToken, _ = FromContext(ctx) |
||||
return "reply", nil |
||||
} |
||||
server := Server(nil)(next) |
||||
_, err2 := server(test.ctx, test.name) |
||||
assert.Equal(t, test.exceptErr, err2) |
||||
if test.exceptErr == nil { |
||||
assert.NotNil(t, testToken) |
||||
} |
||||
} |
||||
|
||||
func TestClientWithClaims(t *testing.T) { |
||||
testKey := "testKey" |
||||
mapClaims := jwt.MapClaims{} |
||||
mapClaims["name"] = "xiaoli" |
||||
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) |
||||
token, err := claims.SignedString([]byte(testKey)) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
tProvider := func(*jwt.Token) (interface{}, error) { |
||||
return []byte(testKey), nil |
||||
} |
||||
test := struct { |
||||
name string |
||||
expectError error |
||||
tokenProvider jwt.Keyfunc |
||||
}{ |
||||
name: "normal", |
||||
expectError: nil, |
||||
tokenProvider: tProvider, |
||||
} |
||||
|
||||
t.Run(test.name, func(t *testing.T) { |
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
return "reply", nil |
||||
} |
||||
handler := Client(test.tokenProvider, WithClaims(mapClaims))(next) |
||||
header := &headerCarrier{} |
||||
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") |
||||
assert.Equal(t, test.expectError, err2) |
||||
if err2 == nil { |
||||
assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) |
||||
} |
||||
}) |
||||
} |
||||
|
||||
func TestClientMissKey(t *testing.T) { |
||||
testKey := "testKey" |
||||
mapClaims := jwt.MapClaims{} |
||||
mapClaims["name"] = "xiaoli" |
||||
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) |
||||
token, err := claims.SignedString([]byte(testKey)) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
tProvider := func(*jwt.Token) (interface{}, error) { |
||||
return nil, errors.New("some error") |
||||
} |
||||
test := struct { |
||||
name string |
||||
expectError error |
||||
tokenProvider jwt.Keyfunc |
||||
}{ |
||||
name: "normal", |
||||
expectError: ErrGetKey, |
||||
tokenProvider: tProvider, |
||||
} |
||||
|
||||
t.Run(test.name, func(t *testing.T) { |
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
return "reply", nil |
||||
} |
||||
handler := Client(test.tokenProvider, WithClaims(mapClaims))(next) |
||||
header := &headerCarrier{} |
||||
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") |
||||
assert.Equal(t, test.expectError, err2) |
||||
if err2 == nil { |
||||
assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) |
||||
} |
||||
}) |
||||
} |
Loading…
Reference in new issue