From ab5152dbe1ca0a277389ebdf7904bda10f8b53af Mon Sep 17 00:00:00 2001 From: Casper-Mars <50834595+Casper-Mars@users.noreply.github.com> Date: Fri, 10 Sep 2021 10:49:18 +0800 Subject: [PATCH] feat(middleware/auth): add auth middleware (#1274) * add auth middleware --- contrib/config/apollo/go.sum | 1 + contrib/config/kubernetes/go.sum | 1 + contrib/config/nacos/go.sum | 1 + contrib/log/fluent/go.sum | 1 + contrib/log/zap/go.sum | 1 + contrib/metrics/datadog/go.sum | 1 + contrib/metrics/prometheus/go.sum | 1 + contrib/registry/consul/go.sum | 1 + contrib/registry/etcd/go.sum | 1 + contrib/registry/kubernetes/go.sum | 1 + contrib/registry/nacos/go.sum | 1 + contrib/registry/zookeeper/go.sum | 1 + examples/auth/jwt/main.go | 68 ++++++ examples/go.mod | 1 + examples/go.sum | 3 + go.mod | 1 + go.sum | 2 + middleware/auth/jwt/jwt.go | 150 +++++++++++++ middleware/auth/jwt/jwt_test.go | 331 +++++++++++++++++++++++++++++ 19 files changed, 568 insertions(+) create mode 100644 examples/auth/jwt/main.go create mode 100644 middleware/auth/jwt/jwt.go create mode 100644 middleware/auth/jwt/jwt_test.go diff --git a/contrib/config/apollo/go.sum b/contrib/config/apollo/go.sum index 27bdce2fd..95613eb9c 100644 --- a/contrib/config/apollo/go.sum +++ b/contrib/config/apollo/go.sum @@ -66,6 +66,7 @@ github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIh github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= diff --git a/contrib/config/kubernetes/go.sum b/contrib/config/kubernetes/go.sum index 7bbbb3aeb..c892b6fa7 100644 --- a/contrib/config/kubernetes/go.sum +++ b/contrib/config/kubernetes/go.sum @@ -81,6 +81,7 @@ github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvSc github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= diff --git a/contrib/config/nacos/go.sum b/contrib/config/nacos/go.sum index 4a2549774..269ca9362 100644 --- a/contrib/config/nacos/go.sum +++ b/contrib/config/nacos/go.sum @@ -32,6 +32,7 @@ github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiU github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s= diff --git a/contrib/log/fluent/go.sum b/contrib/log/fluent/go.sum index 177b61275..cef40a501 100644 --- a/contrib/log/fluent/go.sum +++ b/contrib/log/fluent/go.sum @@ -25,6 +25,7 @@ github.com/go-kratos/aegis v0.1.1/go.mod h1:jYeSQ3Gesba478zEnujOiG5QdsyF3Xk/8owF github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/contrib/log/zap/go.sum b/contrib/log/zap/go.sum index 2604a6670..a2bf8ea16 100644 --- a/contrib/log/zap/go.sum +++ b/contrib/log/zap/go.sum @@ -25,6 +25,7 @@ github.com/go-kratos/aegis v0.1.1/go.mod h1:jYeSQ3Gesba478zEnujOiG5QdsyF3Xk/8owF github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/contrib/metrics/datadog/go.sum b/contrib/metrics/datadog/go.sum index 157486255..a928fe28a 100644 --- a/contrib/metrics/datadog/go.sum +++ b/contrib/metrics/datadog/go.sum @@ -27,6 +27,7 @@ github.com/go-kratos/aegis v0.1.1/go.mod h1:jYeSQ3Gesba478zEnujOiG5QdsyF3Xk/8owF github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/contrib/metrics/prometheus/go.sum b/contrib/metrics/prometheus/go.sum index 713369514..8afc05ccd 100644 --- a/contrib/metrics/prometheus/go.sum +++ b/contrib/metrics/prometheus/go.sum @@ -81,6 +81,7 @@ github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFG github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= diff --git a/contrib/registry/consul/go.sum b/contrib/registry/consul/go.sum index 5f040e99e..86ffd694e 100644 --- a/contrib/registry/consul/go.sum +++ b/contrib/registry/consul/go.sum @@ -32,6 +32,7 @@ github.com/go-kratos/aegis v0.1.1/go.mod h1:jYeSQ3Gesba478zEnujOiG5QdsyF3Xk/8owF github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/contrib/registry/etcd/go.sum b/contrib/registry/etcd/go.sum index 9d13581fa..60e22bba4 100644 --- a/contrib/registry/etcd/go.sum +++ b/contrib/registry/etcd/go.sum @@ -51,6 +51,7 @@ github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/contrib/registry/kubernetes/go.sum b/contrib/registry/kubernetes/go.sum index ad4c58162..4f374a164 100644 --- a/contrib/registry/kubernetes/go.sum +++ b/contrib/registry/kubernetes/go.sum @@ -78,6 +78,7 @@ github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvSc github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= diff --git a/contrib/registry/nacos/go.sum b/contrib/registry/nacos/go.sum index 0e7ca4a63..41bd7d4bf 100644 --- a/contrib/registry/nacos/go.sum +++ b/contrib/registry/nacos/go.sum @@ -32,6 +32,7 @@ github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiU github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s= diff --git a/contrib/registry/zookeeper/go.sum b/contrib/registry/zookeeper/go.sum index 2f8dd5727..c2ca4cd38 100644 --- a/contrib/registry/zookeeper/go.sum +++ b/contrib/registry/zookeeper/go.sum @@ -23,6 +23,7 @@ github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvSc github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= github.com/go-zookeeper/zk v1.0.2 h1:4mx0EYENAdX/B/rbunjlt5+4RTA/a9SMHBRuSKdGxPM= github.com/go-zookeeper/zk v1.0.2/go.mod h1:nOB03cncLtlp4t+UAkGSV+9beXP/akpekBwL+UX1Qcw= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/examples/auth/jwt/main.go b/examples/auth/jwt/main.go new file mode 100644 index 000000000..b1cd22eb0 --- /dev/null +++ b/examples/auth/jwt/main.go @@ -0,0 +1,68 @@ +package main + +import ( + "context" + "log" + + "github.com/go-kratos/kratos/examples/helloworld/helloworld" + "github.com/go-kratos/kratos/v2" + "github.com/go-kratos/kratos/v2/middleware/auth/jwt" + "github.com/go-kratos/kratos/v2/transport/grpc" + "github.com/go-kratos/kratos/v2/transport/http" + jwtv4 "github.com/golang-jwt/jwt/v4" +) + +type server struct { + helloworld.UnimplementedGreeterServer + + hc helloworld.GreeterClient +} + +func (s *server) SayHello(ctx context.Context, in *helloworld.HelloRequest) (*helloworld.HelloReply, error) { + return &helloworld.HelloReply{Message: "hello from service"}, nil +} + +func main() { + testKey := "testKey" + httpSrv := http.NewServer( + http.Address(":8000"), + http.Middleware( + jwt.Server(func(token *jwtv4.Token) (interface{}, error) { + return []byte(testKey), nil + }), + ), + ) + grpcSrv := grpc.NewServer( + grpc.Address(":9000"), + grpc.Middleware( + jwt.Server(func(token *jwtv4.Token) (interface{}, error) { + return []byte(testKey), nil + }), + ), + ) + serviceTestKey := "serviceTestKey" + con, _ := grpc.DialInsecure( + context.Background(), + grpc.WithEndpoint("dns:///127.0.0.1:9001"), + grpc.WithMiddleware( + jwt.Client(func(token *jwtv4.Token) (interface{}, error) { + return []byte(serviceTestKey), nil + }), + ), + ) + s := &server{ + hc: helloworld.NewGreeterClient(con), + } + helloworld.RegisterGreeterServer(grpcSrv, s) + helloworld.RegisterGreeterHTTPServer(httpSrv, s) + app := kratos.New( + kratos.Name("helloworld"), + kratos.Server( + httpSrv, + grpcSrv, + ), + ) + if err := app.Run(); err != nil { + log.Fatal(err) + } +} diff --git a/examples/go.mod b/examples/go.mod index be2d8b25d..54730a135 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -19,6 +19,7 @@ require ( github.com/go-redis/redis/extra/redisotel v0.3.0 github.com/go-redis/redis/v8 v8.11.2 github.com/go-sql-driver/mysql v1.6.0 + github.com/golang-jwt/jwt/v4 v4.0.0 github.com/google/wire v0.5.0 github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 diff --git a/examples/go.sum b/examples/go.sum index ad5c117c3..e58693d99 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -209,7 +209,10 @@ github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zV github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20210429001901-424d2337a529 h1:2voWjNECnrZRbfwXxHB1/j8wa6xdKn85B5NzgVL/pTU= github.com/golang/glog v0.0.0-20210429001901-424d2337a529/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= diff --git a/go.mod b/go.mod index bcfcee567..c5d51e85e 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/fsnotify/fsnotify v1.4.9 github.com/go-kratos/aegis v0.1.1 github.com/go-playground/form/v4 v4.2.0 + github.com/golang-jwt/jwt/v4 v4.0.0 github.com/google/uuid v1.3.0 github.com/gorilla/mux v1.8.0 github.com/imdario/mergo v0.3.12 diff --git a/go.sum b/go.sum index bde536a8c..91d27434e 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBY github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/form/v4 v4.2.0 h1:N1wh+Goz61e6w66vo8vJkQt+uwZSoLz50kZPJWR8eic= github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= +github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/middleware/auth/jwt/jwt.go b/middleware/auth/jwt/jwt.go new file mode 100644 index 000000000..23b648d54 --- /dev/null +++ b/middleware/auth/jwt/jwt.go @@ -0,0 +1,150 @@ +package jwt + +import ( + "context" + "fmt" + "strings" + + "github.com/golang-jwt/jwt/v4" + + "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport" +) + +type authKey struct{} + +const ( + + // bearerWord the bearer key word for authorization + bearerWord string = "Bearer" + + // bearerFormat authorization token format + bearerFormat string = "Bearer %s" + + // authorizationKey holds the key used to store the JWT Token in the request header. + authorizationKey string = "Authorization" +) + +var ( + ErrMissingJwtToken = errors.Unauthorized("UNAUTHORIZED", "JWT token is missing") + ErrMissingKeyFunc = errors.Unauthorized("UNAUTHORIZED", "keyFunc is missing") + ErrTokenInvalid = errors.Unauthorized("UNAUTHORIZED", "Token is invalid") + ErrTokenExpired = errors.Unauthorized("UNAUTHORIZED", "JWT token has expired") + ErrTokenParseFail = errors.Unauthorized("UNAUTHORIZED", "Fail to parse JWT token ") + ErrUnSupportSigningMethod = errors.Unauthorized("UNAUTHORIZED", "Wrong signing method") + ErrWrongContext = errors.Unauthorized("UNAUTHORIZED", "Wrong context for middelware") + ErrNeedTokenProvider = errors.Unauthorized("UNAUTHORIZED", "Token provider is missing") + ErrSignToken = errors.Unauthorized("UNAUTHORIZED", "Can not sign token.Is the key correct?") + ErrGetKey = errors.Unauthorized("UNAUTHORIZED", "Can not get key while signing token") +) + +// Option is jwt option. +type Option func(*options) + +// Parser is a jwt parser +type options struct { + signingMethod jwt.SigningMethod + claims jwt.Claims +} + +// WithSigningMethod with signing method option. +func WithSigningMethod(method jwt.SigningMethod) Option { + return func(o *options) { + o.signingMethod = method + } +} + +// WithClaims with customer claim +func WithClaims(claims jwt.Claims) Option { + return func(o *options) { + o.claims = claims + } +} + +// Server is a server auth middleware. Check the token and extract the info from token. +func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware { + o := &options{ + signingMethod: jwt.SigningMethodHS256, + claims: jwt.StandardClaims{}, + } + for _, opt := range opts { + opt(o) + } + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (interface{}, error) { + if header, ok := transport.FromServerContext(ctx); ok { + if keyFunc == nil { + return nil, ErrMissingKeyFunc + } + auths := strings.SplitN(header.RequestHeader().Get(authorizationKey), " ", 2) + if len(auths) != 2 || !strings.EqualFold(auths[0], bearerWord) { + return nil, ErrMissingJwtToken + } + jwtToken := auths[1] + tokenInfo, err := jwt.Parse(jwtToken, keyFunc) + if err != nil { + if ve, ok := err.(*jwt.ValidationError); ok { + if ve.Errors&jwt.ValidationErrorMalformed != 0 { + return nil, ErrTokenInvalid + } else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 { + return nil, ErrTokenExpired + } else { + return nil, ErrTokenParseFail + } + } + } else if !tokenInfo.Valid { + return nil, ErrTokenInvalid + } else if tokenInfo.Method != o.signingMethod { + return nil, ErrUnSupportSigningMethod + } + ctx = NewContext(ctx, tokenInfo.Claims) + return handler(ctx, req) + } + return nil, ErrWrongContext + } + } +} + +// Client is a client jwt middleware. +func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware { + o := &options{ + signingMethod: jwt.SigningMethodHS256, + claims: jwt.StandardClaims{}, + } + for _, opt := range opts { + opt(o) + } + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (interface{}, error) { + if keyProvider == nil { + return nil, ErrNeedTokenProvider + } + token := jwt.NewWithClaims(o.signingMethod, o.claims) + key, err := keyProvider(token) + if err != nil { + return nil, ErrGetKey + } + tokenStr, err := token.SignedString(key) + if err != nil { + return nil, ErrSignToken + } + if clientContext, ok := transport.FromClientContext(ctx); ok { + clientContext.RequestHeader().Set(authorizationKey, fmt.Sprintf(bearerFormat, tokenStr)) + return handler(ctx, req) + } + return nil, ErrWrongContext + } + } +} + +// NewContext put auth info into context +func NewContext(ctx context.Context, info jwt.Claims) context.Context { + return context.WithValue(ctx, authKey{}, info) +} + +// FromContext extract auth info from context +func FromContext(ctx context.Context) (token jwt.Claims, ok bool) { + token, ok = ctx.Value(authKey{}).(jwt.Claims) + return +} diff --git a/middleware/auth/jwt/jwt_test.go b/middleware/auth/jwt/jwt_test.go new file mode 100644 index 000000000..627b585e2 --- /dev/null +++ b/middleware/auth/jwt/jwt_test.go @@ -0,0 +1,331 @@ +package jwt + +import ( + "context" + "errors" + "fmt" + "net/http" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + + "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport" + "github.com/stretchr/testify/assert" +) + +type headerCarrier http.Header + +func (hc headerCarrier) Get(key string) string { return http.Header(hc).Get(key) } + +func (hc headerCarrier) Set(key string, value string) { http.Header(hc).Set(key, value) } + +// Keys lists the keys stored in this carrier. +func (hc headerCarrier) Keys() []string { + keys := make([]string, 0, len(hc)) + for k := range http.Header(hc) { + keys = append(keys, k) + } + return keys +} + +func newTokenHeader(headerKey string, token string) *headerCarrier { + header := &headerCarrier{} + header.Set(headerKey, token) + return header +} + +type Transport struct { + kind transport.Kind + endpoint string + operation string + reqHeader transport.Header +} + +func (tr *Transport) Kind() transport.Kind { + return tr.kind +} + +func (tr *Transport) Endpoint() string { + return tr.endpoint +} + +func (tr *Transport) Operation() string { + return tr.operation +} + +func (tr *Transport) RequestHeader() transport.Header { + return tr.reqHeader +} + +func (tr *Transport) ReplyHeader() transport.Header { + return nil +} + +func TestServer(t *testing.T) { + testKey := "testKey" + mapClaims := jwt.MapClaims{} + mapClaims["name"] = "xiaoli" + claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) + token, err := claims.SignedString([]byte(testKey)) + if err != nil { + panic(err) + } + token = fmt.Sprintf(bearerFormat, token) + tests := []struct { + name string + ctx context.Context + signingMethod jwt.SigningMethod + exceptErr error + key string + }{ + { + name: "normal", + ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), + signingMethod: jwt.SigningMethodHS256, + exceptErr: nil, + key: testKey, + }, + { + name: "miss token", + ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: headerCarrier{}}), + signingMethod: jwt.SigningMethodHS256, + exceptErr: ErrMissingJwtToken, + key: testKey, + }, + { + name: "token invalid", + ctx: transport.NewServerContext(context.Background(), &Transport{ + reqHeader: newTokenHeader(authorizationKey, fmt.Sprintf(bearerFormat, "12313123")), + }), + signingMethod: jwt.SigningMethodHS256, + exceptErr: ErrTokenInvalid, + key: testKey, + }, + { + name: "method invalid", + ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), + signingMethod: jwt.SigningMethodES384, + exceptErr: ErrUnSupportSigningMethod, + key: testKey, + }, + { + name: "miss signing method", + ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), + signingMethod: nil, + exceptErr: nil, + key: testKey, + }, + { + name: "miss signing method", + ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), + signingMethod: nil, + exceptErr: nil, + key: testKey, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var testToken jwt.Claims + next := func(ctx context.Context, req interface{}) (interface{}, error) { + t.Log(req) + testToken, _ = FromContext(ctx) + return "reply", nil + } + var server middleware.Handler + if test.signingMethod != nil { + server = Server(func(token *jwt.Token) (interface{}, error) { + return []byte(test.key), nil + }, WithSigningMethod(test.signingMethod))(next) + } else { + server = Server(func(token *jwt.Token) (interface{}, error) { + return []byte(test.key), nil + })(next) + } + _, err2 := server(test.ctx, test.name) + assert.Equal(t, test.exceptErr, err2) + if test.exceptErr == nil { + assert.NotNil(t, testToken) + _, ok := testToken.(jwt.MapClaims) + assert.True(t, ok) + } + }) + } +} + +func TestClient(t *testing.T) { + testKey := "testKey" + claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{}) + token, err := claims.SignedString([]byte(testKey)) + if err != nil { + panic(err) + } + tProvider := func(*jwt.Token) (interface{}, error) { + return []byte(testKey), nil + } + tests := []struct { + name string + expectError error + tokenProvider jwt.Keyfunc + }{ + { + name: "normal", + expectError: nil, + tokenProvider: tProvider, + }, + { + name: "miss token provider", + expectError: ErrNeedTokenProvider, + tokenProvider: nil, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return "reply", nil + } + handler := Client(test.tokenProvider)(next) + header := &headerCarrier{} + _, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") + assert.Equal(t, test.expectError, err2) + if err2 == nil { + assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) + } + }) + } +} + +func TestTokenExpire(t *testing.T) { + testKey := "testKey" + claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ + ExpiresAt: time.Now().Add(time.Millisecond).Unix(), + }) + token, err := claims.SignedString([]byte(testKey)) + if err != nil { + panic(err) + } + token = fmt.Sprintf(bearerFormat, token) + time.Sleep(time.Second) + next := func(ctx context.Context, req interface{}) (interface{}, error) { + t.Log(req) + return "reply", nil + } + ctx := transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}) + server := Server(func(token *jwt.Token) (interface{}, error) { + return []byte(testKey), nil + }, WithSigningMethod(jwt.SigningMethodHS256))(next) + _, err2 := server(ctx, "test expire token") + assert.Equal(t, ErrTokenExpired, err2) +} + +func TestMissingKeyFunc(t *testing.T) { + testKey := "testKey" + mapClaims := jwt.MapClaims{} + mapClaims["name"] = "xiaoli" + claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) + token, err := claims.SignedString([]byte(testKey)) + if err != nil { + panic(err) + } + token = fmt.Sprintf(bearerFormat, token) + test := struct { + name string + ctx context.Context + signingMethod jwt.SigningMethod + exceptErr error + key string + }{ + name: "miss key", + ctx: transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, token)}), + signingMethod: jwt.SigningMethodHS256, + exceptErr: ErrMissingKeyFunc, + key: "", + } + + var testToken jwt.Claims + next := func(ctx context.Context, req interface{}) (interface{}, error) { + t.Log(req) + testToken, _ = FromContext(ctx) + return "reply", nil + } + server := Server(nil)(next) + _, err2 := server(test.ctx, test.name) + assert.Equal(t, test.exceptErr, err2) + if test.exceptErr == nil { + assert.NotNil(t, testToken) + } +} + +func TestClientWithClaims(t *testing.T) { + testKey := "testKey" + mapClaims := jwt.MapClaims{} + mapClaims["name"] = "xiaoli" + claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) + token, err := claims.SignedString([]byte(testKey)) + if err != nil { + panic(err) + } + tProvider := func(*jwt.Token) (interface{}, error) { + return []byte(testKey), nil + } + test := struct { + name string + expectError error + tokenProvider jwt.Keyfunc + }{ + name: "normal", + expectError: nil, + tokenProvider: tProvider, + } + + t.Run(test.name, func(t *testing.T) { + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return "reply", nil + } + handler := Client(test.tokenProvider, WithClaims(mapClaims))(next) + header := &headerCarrier{} + _, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") + assert.Equal(t, test.expectError, err2) + if err2 == nil { + assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) + } + }) +} + +func TestClientMissKey(t *testing.T) { + testKey := "testKey" + mapClaims := jwt.MapClaims{} + mapClaims["name"] = "xiaoli" + claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims) + token, err := claims.SignedString([]byte(testKey)) + if err != nil { + panic(err) + } + tProvider := func(*jwt.Token) (interface{}, error) { + return nil, errors.New("some error") + } + test := struct { + name string + expectError error + tokenProvider jwt.Keyfunc + }{ + name: "normal", + expectError: ErrGetKey, + tokenProvider: tProvider, + } + + t.Run(test.name, func(t *testing.T) { + next := func(ctx context.Context, req interface{}) (interface{}, error) { + return "reply", nil + } + handler := Client(test.tokenProvider, WithClaims(mapClaims))(next) + header := &headerCarrier{} + _, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok") + assert.Equal(t, test.expectError, err2) + if err2 == nil { + assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey)) + } + }) +}