package auth

import (
	"github.com/bilibili/kratos/pkg/ecode"
	bm "github.com/bilibili/kratos/pkg/net/http/blademaster"
	"github.com/bilibili/kratos/pkg/net/metadata"
)

// Config is the identify config model.
type Config struct {
	// csrf switch.
	DisableCSRF bool
}

// Auth is the authorization middleware
type Auth struct {
	conf *Config
}

// authFunc will return mid and error by given context
type authFunc func(*bm.Context) (int64, error)

var _defaultConf = &Config{
	DisableCSRF: false,
}

// New is used to create an authorization middleware
func New(conf *Config) *Auth {
	if conf == nil {
		conf = _defaultConf
	}
	auth := &Auth{
		conf: conf,
	}
	return auth
}

// User is used to mark path as access required.
// If `access_token` is exist in request form, it will using mobile access policy.
// Otherwise to web access policy.
func (a *Auth) User(ctx *bm.Context) {
	req := ctx.Request
	if req.Form.Get("access_token") == "" {
		a.UserWeb(ctx)
		return
	}
	a.UserMobile(ctx)
}

// UserWeb is used to mark path as web access required.
func (a *Auth) UserWeb(ctx *bm.Context) {
	a.midAuth(ctx, a.authCookie)
}

// UserMobile is used to mark path as mobile access required.
func (a *Auth) UserMobile(ctx *bm.Context) {
	a.midAuth(ctx, a.authToken)
}

// Guest is used to mark path as guest policy.
// If `access_token` is exist in request form, it will using mobile access policy.
// Otherwise to web access policy.
func (a *Auth) Guest(ctx *bm.Context) {
	req := ctx.Request
	if req.Form.Get("access_token") == "" {
		a.GuestWeb(ctx)
		return
	}
	a.GuestMobile(ctx)
}

// GuestWeb is used to mark path as web guest policy.
func (a *Auth) GuestWeb(ctx *bm.Context) {
	a.guestAuth(ctx, a.authCookie)
}

// GuestMobile is used to mark path as mobile guest policy.
func (a *Auth) GuestMobile(ctx *bm.Context) {
	a.guestAuth(ctx, a.authToken)
}

// authToken is used to authorize request by token
func (a *Auth) authToken(ctx *bm.Context) (int64, error) {
	req := ctx.Request
	key := req.Form.Get("access_token")
	if key == "" {
		return 0, ecode.Unauthorized
	}
	// NOTE: 请求登录鉴权服务接口,拿到对应的用户id
	var mid int64
	// TODO: get mid from some code
	return mid, nil
}

// authCookie is used to authorize request by cookie
func (a *Auth) authCookie(ctx *bm.Context) (int64, error) {
	req := ctx.Request
	session, _ := req.Cookie("SESSION")
	if session == nil {
		return 0, ecode.Unauthorized
	}
	// NOTE: 请求登录鉴权服务接口,拿到对应的用户id
	var mid int64
	// TODO: get mid from some code

	// check csrf
	clientCsrf := req.FormValue("csrf")
	if a.conf != nil && !a.conf.DisableCSRF && req.Method == "POST" {
		// NOTE: 如果开启了CSRF认证,请从CSRF服务获取该用户关联的csrf
		var csrf string // TODO: get csrf from some code
		if clientCsrf != csrf {
			return 0, ecode.Unauthorized
		}
	}

	return mid, nil
}

func (a *Auth) midAuth(ctx *bm.Context, auth authFunc) {
	mid, err := auth(ctx)
	if err != nil {
		ctx.JSON(nil, err)
		ctx.Abort()
		return
	}
	setMid(ctx, mid)
}

func (a *Auth) guestAuth(ctx *bm.Context, auth authFunc) {
	mid, err := auth(ctx)
	// no error happened and mid is valid
	if err == nil && mid > 0 {
		setMid(ctx, mid)
		return
	}

	ec := ecode.Cause(err)
	if ecode.Equal(ec, ecode.Unauthorized) {
		ctx.JSON(nil, ec)
		ctx.Abort()
		return
	}
}

// set mid into context
// NOTE: This method is not thread safe.
func setMid(ctx *bm.Context, mid int64) {
	ctx.Set(metadata.Mid, mid)
	if md, ok := metadata.FromContext(ctx); ok {
		md[metadata.Mid] = mid
		return
	}
}