feat(middleware/auth): add auth middleware (#1274)

* add auth middleware
pull/1461/head
Casper-Mars 3 years ago committed by GitHub
parent aed6af7acc
commit ab5152dbe1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      contrib/config/apollo/go.sum
  2. 1
      contrib/config/kubernetes/go.sum
  3. 1
      contrib/config/nacos/go.sum
  4. 1
      contrib/log/fluent/go.sum
  5. 1
      contrib/log/zap/go.sum
  6. 1
      contrib/metrics/datadog/go.sum
  7. 1
      contrib/metrics/prometheus/go.sum
  8. 1
      contrib/registry/consul/go.sum
  9. 1
      contrib/registry/etcd/go.sum
  10. 1
      contrib/registry/kubernetes/go.sum
  11. 1
      contrib/registry/nacos/go.sum
  12. 1
      contrib/registry/zookeeper/go.sum
  13. 68
      examples/auth/jwt/main.go
  14. 1
      examples/go.mod
  15. 3
      examples/go.sum
  16. 1
      go.mod
  17. 2
      go.sum
  18. 150
      middleware/auth/jwt/jwt.go
  19. 331
      middleware/auth/jwt/jwt_test.go

@ -66,6 +66,7 @@ github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIh
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=

@ -81,6 +81,7 @@ github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvSc
github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=

@ -32,6 +32,7 @@ github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiU
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U=
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s=

@ -25,6 +25,7 @@ github.com/go-kratos/aegis v0.1.1/go.mod h1:jYeSQ3Gesba478zEnujOiG5QdsyF3Xk/8owF
github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=

@ -25,6 +25,7 @@ github.com/go-kratos/aegis v0.1.1/go.mod h1:jYeSQ3Gesba478zEnujOiG5QdsyF3Xk/8owF
github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=

@ -27,6 +27,7 @@ github.com/go-kratos/aegis v0.1.1/go.mod h1:jYeSQ3Gesba478zEnujOiG5QdsyF3Xk/8owF
github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=

@ -81,6 +81,7 @@ github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFG
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=

@ -32,6 +32,7 @@ github.com/go-kratos/aegis v0.1.1/go.mod h1:jYeSQ3Gesba478zEnujOiG5QdsyF3Xk/8owF
github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=

@ -51,6 +51,7 @@ github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=

@ -78,6 +78,7 @@ github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvSc
github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=

@ -32,6 +32,7 @@ github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiU
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U=
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s=

@ -23,6 +23,7 @@ github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvSc
github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U=
github.com/go-zookeeper/zk v1.0.2 h1:4mx0EYENAdX/B/rbunjlt5+4RTA/a9SMHBRuSKdGxPM=
github.com/go-zookeeper/zk v1.0.2/go.mod h1:nOB03cncLtlp4t+UAkGSV+9beXP/akpekBwL+UX1Qcw=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=

@ -0,0 +1,68 @@
package main
import (
"context"
"log"
"github.com/go-kratos/kratos/examples/helloworld/helloworld"
"github.com/go-kratos/kratos/v2"
"github.com/go-kratos/kratos/v2/middleware/auth/jwt"
"github.com/go-kratos/kratos/v2/transport/grpc"
"github.com/go-kratos/kratos/v2/transport/http"
jwtv4 "github.com/golang-jwt/jwt/v4"
)
type server struct {
helloworld.UnimplementedGreeterServer
hc helloworld.GreeterClient
}
func (s *server) SayHello(ctx context.Context, in *helloworld.HelloRequest) (*helloworld.HelloReply, error) {
return &helloworld.HelloReply{Message: "hello from service"}, nil
}
func main() {
testKey := "testKey"
httpSrv := http.NewServer(
http.Address(":8000"),
http.Middleware(
jwt.Server(func(token *jwtv4.Token) (interface{}, error) {
return []byte(testKey), nil
}),
),
)
grpcSrv := grpc.NewServer(
grpc.Address(":9000"),
grpc.Middleware(
jwt.Server(func(token *jwtv4.Token) (interface{}, error) {
return []byte(testKey), nil
}),
),
)
serviceTestKey := "serviceTestKey"
con, _ := grpc.DialInsecure(
context.Background(),
grpc.WithEndpoint("dns:///127.0.0.1:9001"),
grpc.WithMiddleware(
jwt.Client(func(token *jwtv4.Token) (interface{}, error) {
return []byte(serviceTestKey), nil
}),
),
)
s := &server{
hc: helloworld.NewGreeterClient(con),
}
helloworld.RegisterGreeterServer(grpcSrv, s)
helloworld.RegisterGreeterHTTPServer(httpSrv, s)
app := kratos.New(
kratos.Name("helloworld"),
kratos.Server(
httpSrv,
grpcSrv,
),
)
if err := app.Run(); err != nil {
log.Fatal(err)
}
}

@ -19,6 +19,7 @@ require (
github.com/go-redis/redis/extra/redisotel v0.3.0
github.com/go-redis/redis/v8 v8.11.2
github.com/go-sql-driver/mysql v1.6.0
github.com/golang-jwt/jwt/v4 v4.0.0
github.com/google/wire v0.5.0
github.com/gorilla/handlers v1.5.1
github.com/gorilla/mux v1.8.0

@ -209,7 +209,10 @@ github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zV
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/glog v0.0.0-20210429001901-424d2337a529 h1:2voWjNECnrZRbfwXxHB1/j8wa6xdKn85B5NzgVL/pTU=
github.com/golang/glog v0.0.0-20210429001901-424d2337a529/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=

@ -6,6 +6,7 @@ require (
github.com/fsnotify/fsnotify v1.4.9
github.com/go-kratos/aegis v0.1.1
github.com/go-playground/form/v4 v4.2.0
github.com/golang-jwt/jwt/v4 v4.0.0
github.com/google/uuid v1.3.0
github.com/gorilla/mux v1.8.0
github.com/imdario/mergo v0.3.12

@ -28,6 +28,8 @@ github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBY
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/form/v4 v4.2.0 h1:N1wh+Goz61e6w66vo8vJkQt+uwZSoLz50kZPJWR8eic=
github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U=
github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=

@ -0,0 +1,150 @@
package jwt
import (
"context"
"fmt"
"strings"
"github.com/golang-jwt/jwt/v4"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
)
type authKey struct{}
const (
// bearerWord the bearer key word for authorization
bearerWord string = "Bearer"
// bearerFormat authorization token format
bearerFormat string = "Bearer %s"
// authorizationKey holds the key used to store the JWT Token in the request header.
authorizationKey string = "Authorization"
)
var (
ErrMissingJwtToken = errors.Unauthorized("UNAUTHORIZED", "JWT token is missing")
ErrMissingKeyFunc = errors.Unauthorized("UNAUTHORIZED", "keyFunc is missing")
ErrTokenInvalid = errors.Unauthorized("UNAUTHORIZED", "Token is invalid")
ErrTokenExpired = errors.Unauthorized("UNAUTHORIZED", "JWT token has expired")
ErrTokenParseFail = errors.Unauthorized("UNAUTHORIZED", "Fail to parse JWT token ")
ErrUnSupportSigningMethod = errors.Unauthorized("UNAUTHORIZED", "Wrong signing method")
ErrWrongContext = errors.Unauthorized("UNAUTHORIZED", "Wrong context for middelware")
ErrNeedTokenProvider = errors.Unauthorized("UNAUTHORIZED", "Token provider is missing")
ErrSignToken = errors.Unauthorized("UNAUTHORIZED", "Can not sign token.Is the key correct?")
ErrGetKey = errors.Unauthorized("UNAUTHORIZED", "Can not get key while signing token")
)
// Option is jwt option.
type Option func(*options)
// Parser is a jwt parser
type options struct {
signingMethod jwt.SigningMethod
claims jwt.Claims
}
// WithSigningMethod with signing method option.
func WithSigningMethod(method jwt.SigningMethod) Option {
return func(o *options) {
o.signingMethod = method
}
}
// WithClaims with customer claim
func WithClaims(claims jwt.Claims) Option {
return func(o *options) {
o.claims = claims
}
}
// Server is a server auth middleware. Check the token and extract the info from token.
func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware {
o := &options{
signingMethod: jwt.SigningMethodHS256,
claims: jwt.StandardClaims{},
}
for _, opt := range opts {
opt(o)
}
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (interface{}, error) {
if header, ok := transport.FromServerContext(ctx); ok {
if keyFunc == nil {
return nil, ErrMissingKeyFunc
}
auths := strings.SplitN(header.RequestHeader().Get(authorizationKey), " ", 2)
if len(auths) != 2 || !strings.EqualFold(auths[0], bearerWord) {
return nil, ErrMissingJwtToken
}
jwtToken := auths[1]
tokenInfo, err := jwt.Parse(jwtToken, keyFunc)
if err != nil {
if ve, ok := err.(*jwt.ValidationError); ok {
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
return nil, ErrTokenInvalid
} else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 {
return nil, ErrTokenExpired
} else {
return nil, ErrTokenParseFail
}
}
} else if !tokenInfo.Valid {
return nil, ErrTokenInvalid
} else if tokenInfo.Method != o.signingMethod {
return nil, ErrUnSupportSigningMethod
}
ctx = NewContext(ctx, tokenInfo.Claims)
return handler(ctx, req)
}
return nil, ErrWrongContext
}
}
}
// Client is a client jwt middleware.
func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware {
o := &options{
signingMethod: jwt.SigningMethodHS256,
claims: jwt.StandardClaims{},
}
for _, opt := range opts {
opt(o)
}
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (interface{}, error) {
if keyProvider == nil {
return nil, ErrNeedTokenProvider
}
token := jwt.NewWithClaims(o.signingMethod, o.claims)
key, err := keyProvider(token)
if err != nil {
return nil, ErrGetKey
}
tokenStr, err := token.SignedString(key)
if err != nil {
return nil, ErrSignToken
}
if clientContext, ok := transport.FromClientContext(ctx); ok {
clientContext.RequestHeader().Set(authorizationKey, fmt.Sprintf(bearerFormat, tokenStr))
return handler(ctx, req)
}
return nil, ErrWrongContext
}
}
}
// NewContext put auth info into context
func NewContext(ctx context.Context, info jwt.Claims) context.Context {
return context.WithValue(ctx, authKey{}, info)
}
// FromContext extract auth info from context
func FromContext(ctx context.Context) (token jwt.Claims, ok bool) {
token, ok = ctx.Value(authKey{}).(jwt.Claims)
return
}

@ -0,0 +1,331 @@
package jwt
import (
"context"
"errors"
"fmt"
"net/http"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
"github.com/stretchr/testify/assert"
)
type headerCarrier http.Header
func (hc headerCarrier) Get(key string) string { return http.Header(hc).Get(key) }
func (hc headerCarrier) Set(key string, value string) { http.Header(hc).Set(key, value) }
// Keys lists the keys stored in this carrier.
func (hc headerCarrier) Keys() []string {
keys := make([]string, 0, len(hc))
for k := range http.Header(hc) {
keys = append(keys, k)
}
return keys
}
func newTokenHeader(headerKey string, token string) *headerCarrier {
header := &headerCarrier{}
header.Set(headerKey, token)
return header
}
type Transport struct {
kind transport.Kind
endpoint string
operation string
reqHeader transport.Header
}
func (tr *Transport) Kind() transport.Kind {
return tr.kind
}
func (tr *Transport) Endpoint() string {
return tr.endpoint
}
func (tr *Transport) Operation() string {
return tr.operation
}
func (tr *Transport) RequestHeader() transport.Header {
return tr.reqHeader
}
func (tr *Transport) ReplyHeader() transport.Header {
return nil
}
func TestServer(t *testing.T) {
testKey := "testKey"
mapClaims := jwt.MapClaims{}
mapClaims["name"] = "xiaoli"
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims)
token, err := claims.SignedString([]byte(testKey))
if err != nil {
panic(err)
}
token = fmt.Sprintf(bearerFormat, token)
tests := []struct {
name string
ctx context.Context
signingMethod jwt.SigningMethod
exceptErr error
key string
}{
{
name: "normal",
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}),
signingMethod: jwt.SigningMethodHS256,
exceptErr: nil,
key: testKey,
},
{
name: "miss token",
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: headerCarrier{}}),
signingMethod: jwt.SigningMethodHS256,
exceptErr: ErrMissingJwtToken,
key: testKey,
},
{
name: "token invalid",
ctx: transport.NewServerContext(context.Background(), &Transport{
reqHeader: newTokenHeader(authorizationKey, fmt.Sprintf(bearerFormat, "12313123")),
}),
signingMethod: jwt.SigningMethodHS256,
exceptErr: ErrTokenInvalid,
key: testKey,
},
{
name: "method invalid",
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}),
signingMethod: jwt.SigningMethodES384,
exceptErr: ErrUnSupportSigningMethod,
key: testKey,
},
{
name: "miss signing method",
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}),
signingMethod: nil,
exceptErr: nil,
key: testKey,
},
{
name: "miss signing method",
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}),
signingMethod: nil,
exceptErr: nil,
key: testKey,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var testToken jwt.Claims
next := func(ctx context.Context, req interface{}) (interface{}, error) {
t.Log(req)
testToken, _ = FromContext(ctx)
return "reply", nil
}
var server middleware.Handler
if test.signingMethod != nil {
server = Server(func(token *jwt.Token) (interface{}, error) {
return []byte(test.key), nil
}, WithSigningMethod(test.signingMethod))(next)
} else {
server = Server(func(token *jwt.Token) (interface{}, error) {
return []byte(test.key), nil
})(next)
}
_, err2 := server(test.ctx, test.name)
assert.Equal(t, test.exceptErr, err2)
if test.exceptErr == nil {
assert.NotNil(t, testToken)
_, ok := testToken.(jwt.MapClaims)
assert.True(t, ok)
}
})
}
}
func TestClient(t *testing.T) {
testKey := "testKey"
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{})
token, err := claims.SignedString([]byte(testKey))
if err != nil {
panic(err)
}
tProvider := func(*jwt.Token) (interface{}, error) {
return []byte(testKey), nil
}
tests := []struct {
name string
expectError error
tokenProvider jwt.Keyfunc
}{
{
name: "normal",
expectError: nil,
tokenProvider: tProvider,
},
{
name: "miss token provider",
expectError: ErrNeedTokenProvider,
tokenProvider: nil,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
next := func(ctx context.Context, req interface{}) (interface{}, error) {
return "reply", nil
}
handler := Client(test.tokenProvider)(next)
header := &headerCarrier{}
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
assert.Equal(t, test.expectError, err2)
if err2 == nil {
assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey))
}
})
}
}
func TestTokenExpire(t *testing.T) {
testKey := "testKey"
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{
ExpiresAt: time.Now().Add(time.Millisecond).Unix(),
})
token, err := claims.SignedString([]byte(testKey))
if err != nil {
panic(err)
}
token = fmt.Sprintf(bearerFormat, token)
time.Sleep(time.Second)
next := func(ctx context.Context, req interface{}) (interface{}, error) {
t.Log(req)
return "reply", nil
}
ctx := transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)})
server := Server(func(token *jwt.Token) (interface{}, error) {
return []byte(testKey), nil
}, WithSigningMethod(jwt.SigningMethodHS256))(next)
_, err2 := server(ctx, "test expire token")
assert.Equal(t, ErrTokenExpired, err2)
}
func TestMissingKeyFunc(t *testing.T) {
testKey := "testKey"
mapClaims := jwt.MapClaims{}
mapClaims["name"] = "xiaoli"
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims)
token, err := claims.SignedString([]byte(testKey))
if err != nil {
panic(err)
}
token = fmt.Sprintf(bearerFormat, token)
test := struct {
name string
ctx context.Context
signingMethod jwt.SigningMethod
exceptErr error
key string
}{
name: "miss key",
ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}),
signingMethod: jwt.SigningMethodHS256,
exceptErr: ErrMissingKeyFunc,
key: "",
}
var testToken jwt.Claims
next := func(ctx context.Context, req interface{}) (interface{}, error) {
t.Log(req)
testToken, _ = FromContext(ctx)
return "reply", nil
}
server := Server(nil)(next)
_, err2 := server(test.ctx, test.name)
assert.Equal(t, test.exceptErr, err2)
if test.exceptErr == nil {
assert.NotNil(t, testToken)
}
}
func TestClientWithClaims(t *testing.T) {
testKey := "testKey"
mapClaims := jwt.MapClaims{}
mapClaims["name"] = "xiaoli"
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims)
token, err := claims.SignedString([]byte(testKey))
if err != nil {
panic(err)
}
tProvider := func(*jwt.Token) (interface{}, error) {
return []byte(testKey), nil
}
test := struct {
name string
expectError error
tokenProvider jwt.Keyfunc
}{
name: "normal",
expectError: nil,
tokenProvider: tProvider,
}
t.Run(test.name, func(t *testing.T) {
next := func(ctx context.Context, req interface{}) (interface{}, error) {
return "reply", nil
}
handler := Client(test.tokenProvider, WithClaims(mapClaims))(next)
header := &headerCarrier{}
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
assert.Equal(t, test.expectError, err2)
if err2 == nil {
assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey))
}
})
}
func TestClientMissKey(t *testing.T) {
testKey := "testKey"
mapClaims := jwt.MapClaims{}
mapClaims["name"] = "xiaoli"
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims)
token, err := claims.SignedString([]byte(testKey))
if err != nil {
panic(err)
}
tProvider := func(*jwt.Token) (interface{}, error) {
return nil, errors.New("some error")
}
test := struct {
name string
expectError error
tokenProvider jwt.Keyfunc
}{
name: "normal",
expectError: ErrGetKey,
tokenProvider: tProvider,
}
t.Run(test.name, func(t *testing.T) {
next := func(ctx context.Context, req interface{}) (interface{}, error) {
return "reply", nil
}
handler := Client(test.tokenProvider, WithClaims(mapClaims))(next)
header := &headerCarrier{}
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
assert.Equal(t, test.expectError, err2)
if err2 == nil {
assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey))
}
})
}
Loading…
Cancel
Save