You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
154 lines
3.5 KiB
154 lines
3.5 KiB
6 years ago
|
package auth
|
||
|
|
||
|
import (
	"crypto/subtle"

	"github.com/bilibili/kratos/pkg/ecode"
	bm "github.com/bilibili/kratos/pkg/net/http/blademaster"
	"github.com/bilibili/kratos/pkg/net/metadata"
)
|
||
|
|
||
|
// Config holds the options of the authorization middleware.
type Config struct {
	// DisableCSRF, when true, skips the CSRF token check that cookie-based
	// authentication performs on POST requests.
	DisableCSRF bool
}
|
||
|
|
||
|
// Auth is the authorization middleware
|
||
|
type Auth struct {
|
||
|
conf *Config
|
||
|
}
|
||
|
|
||
|
// authFunc will return mid and error by given context
|
||
|
type authFunc func(*bm.Context) (int64, error)
|
||
|
|
||
|
var _defaultConf = &Config{
|
||
|
DisableCSRF: false,
|
||
|
}
|
||
|
|
||
|
// New is used to create an authorization middleware
|
||
|
func New(conf *Config) *Auth {
|
||
|
if conf == nil {
|
||
|
conf = _defaultConf
|
||
|
}
|
||
|
auth := &Auth{
|
||
|
conf: conf,
|
||
|
}
|
||
|
return auth
|
||
|
}
|
||
|
|
||
|
// User is used to mark path as access required.
|
||
|
// If `access_token` is exist in request form, it will using mobile access policy.
|
||
|
// Otherwise to web access policy.
|
||
|
func (a *Auth) User(ctx *bm.Context) {
|
||
|
req := ctx.Request
|
||
|
if req.Form.Get("access_token") == "" {
|
||
|
a.UserWeb(ctx)
|
||
|
return
|
||
|
}
|
||
|
a.UserMobile(ctx)
|
||
|
}
|
||
|
|
||
|
// UserWeb is used to mark path as web access required.
|
||
|
func (a *Auth) UserWeb(ctx *bm.Context) {
|
||
|
a.midAuth(ctx, a.authCookie)
|
||
|
}
|
||
|
|
||
|
// UserMobile is used to mark path as mobile access required.
|
||
|
func (a *Auth) UserMobile(ctx *bm.Context) {
|
||
|
a.midAuth(ctx, a.authToken)
|
||
|
}
|
||
|
|
||
|
// Guest is used to mark path as guest policy.
|
||
|
// If `access_token` is exist in request form, it will using mobile access policy.
|
||
|
// Otherwise to web access policy.
|
||
|
func (a *Auth) Guest(ctx *bm.Context) {
|
||
|
req := ctx.Request
|
||
|
if req.Form.Get("access_token") == "" {
|
||
|
a.GuestWeb(ctx)
|
||
|
return
|
||
|
}
|
||
|
a.GuestMobile(ctx)
|
||
|
}
|
||
|
|
||
|
// GuestWeb is used to mark path as web guest policy.
|
||
|
func (a *Auth) GuestWeb(ctx *bm.Context) {
|
||
|
a.guestAuth(ctx, a.authCookie)
|
||
|
}
|
||
|
|
||
|
// GuestMobile is used to mark path as mobile guest policy.
|
||
|
func (a *Auth) GuestMobile(ctx *bm.Context) {
|
||
|
a.guestAuth(ctx, a.authToken)
|
||
|
}
|
||
|
|
||
|
// authToken is used to authorize request by token
|
||
|
func (a *Auth) authToken(ctx *bm.Context) (int64, error) {
|
||
|
req := ctx.Request
|
||
|
key := req.Form.Get("access_token")
|
||
|
if key == "" {
|
||
|
return 0, ecode.Unauthorized
|
||
|
}
|
||
|
// NOTE: 请求登录鉴权服务接口,拿到对应的用户id
|
||
|
var mid int64
|
||
|
// TODO: get mid from some code
|
||
|
return mid, nil
|
||
|
}
|
||
|
|
||
|
// authCookie authenticates a request by its "SESSION" cookie.
//
// It returns ecode.Unauthorized when the cookie is missing or empty. For POST
// requests, unless a.conf disables it, it also requires the `csrf` form value
// to match the token expected for the user. A mismatch returns
// ecode.AccessDenied, so callers can tell a rejected (possibly forged)
// request apart from one that simply is not logged in.
// The user and CSRF lookups are still placeholders (see the TODOs).
func (a *Auth) authCookie(ctx *bm.Context) (int64, error) {
	req := ctx.Request
	session, err := req.Cookie("SESSION")
	if err != nil || session.Value == "" {
		// Cookie only fails with http.ErrNoCookie; an empty value carries no
		// session either.
		return 0, ecode.Unauthorized
	}
	// NOTE: call the login/authentication service with session.Value to
	// obtain the corresponding user id.
	var mid int64
	// TODO: get mid from some code

	if a.conf != nil && !a.conf.DisableCSRF && req.Method == "POST" {
		// NOTE: when CSRF checking is enabled, fetch the csrf token associated
		// with this user from the CSRF service.
		var csrf string // TODO: get csrf from some code

		// Parse the form only when the token is actually needed. An empty
		// client token is never accepted, so the check fails closed while the
		// server-side token is unset. The comparison is constant-time so that
		// response timing does not reveal how much of the token matched.
		clientCsrf := req.FormValue("csrf")
		if clientCsrf == "" || subtle.ConstantTimeCompare([]byte(clientCsrf), []byte(csrf)) != 1 {
			return 0, ecode.AccessDenied
		}
	}
	return mid, nil
}
|
||
|
|
||
|
func (a *Auth) midAuth(ctx *bm.Context, auth authFunc) {
|
||
|
mid, err := auth(ctx)
|
||
|
if err != nil {
|
||
|
ctx.JSON(nil, err)
|
||
|
ctx.Abort()
|
||
|
return
|
||
|
}
|
||
|
setMid(ctx, mid)
|
||
|
}
|
||
|
|
||
|
func (a *Auth) guestAuth(ctx *bm.Context, auth authFunc) {
|
||
|
mid, err := auth(ctx)
|
||
|
// no error happened and mid is valid
|
||
|
if err == nil && mid > 0 {
|
||
|
setMid(ctx, mid)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
ec := ecode.Cause(err)
|
||
|
if ecode.Equal(ec, ecode.Unauthorized) {
|
||
|
ctx.JSON(nil, ec)
|
||
|
ctx.Abort()
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// set mid into context
|
||
|
// NOTE: This method is not thread safe.
|
||
|
func setMid(ctx *bm.Context, mid int64) {
|
||
|
ctx.Set(metadata.Mid, mid)
|
||
|
if md, ok := metadata.FromContext(ctx); ok {
|
||
|
md[metadata.Mid] = mid
|
||
|
return
|
||
|
}
|
||
|
}
|