通用包
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
utils/transport/v1/http/request.go

142 lines
3.9 KiB

package http
import (
	"bytes"
	"crypto/subtle"
	"encoding/json"
	"io"
	"net/url"
	"strings"

	"gitea.drugeyes.vip/pharnexbase/tools/request"
	"gitea.drugeyes.vip/pharnexbase/utils/enum"
	"github.com/go-kratos/kratos/v2/errors"
	"github.com/go-kratos/kratos/v2/transport/http"
	"github.com/go-kratos/kratos/v2/transport/http/binding"
	"github.com/tidwall/gjson"
)
// SignErr is returned by the request decoders in this file when the
// signature computed for an incoming request does not match the one the
// client supplied. It is a Kratos Forbidden error with reason "SignError"
// and the message "签名错误" ("signature error").
var (
	SignErr = errors.Forbidden("SignError", "签名错误")
)
// RequestQueryDecoder returns a decoder for requests whose parameters are
// carried only in the query string.
//
// It collects the raw query pairs and computes their signature with
// request.NewSignature(t.apiKey, request.NewSHA1HashAlg()). It then compares
// the result with the "sm5" request header. On a match, the query is bound
// into v.
//
// Errors: a "CODEC" BadRequest for a malformed query escape or a binding
// failure; SignErr when the signature does not match.
func (t *Transport) RequestQueryDecoder() http.DecodeRequestFunc {
	return func(r *http.Request, v interface{}) error {
		// Build the signed parameter set from the raw query. Keys are used as
		// sent, and values are decoded with url.PathUnescape, which leaves '+'
		// untouched.
		// NOTE(review): r.URL.Query() (used for binding below) decodes '+' as
		// a space, so the signed and bound values differ for '+'. Keep this
		// in sync with the client-side signer.
		params := make(map[string]string)
		for _, pair := range strings.Split(r.URL.RawQuery, "&") {
			// Split at the first '=' only, so values that themselves contain
			// '=' (e.g. base64 padding) are not truncated.
			kv := strings.SplitN(pair, "=", 2)
			if len(kv) < 2 {
				continue
			}
			value, err := url.PathUnescape(kv[1])
			if err != nil {
				return errors.BadRequest("CODEC", err.Error())
			}
			params[kv[0]] = value
		}

		signature := request.NewSignature(t.apiKey, request.NewSHA1HashAlg())
		expected := signature.GenSignature(params)
		// Compare in constant time so that response timing does not reveal
		// how much of a forged signature was correct.
		if subtle.ConstantTimeCompare([]byte(expected), []byte(r.Header.Get("sm5"))) != 1 {
			return SignErr
		}

		if err := binding.BindQuery(r.URL.Query(), v); err != nil {
			return errors.BadRequest("CODEC", err.Error())
		}
		return nil
	}
}
// RequestDecoder returns a body decoder that verifies the request signature
// before unmarshalling.
//
// The signed parameter set is built by merging two sources:
//   - the top-level fields of a JSON body, with nested objects and arrays in
//     compacted form;
//   - the raw query-string pairs, which win on key collisions.
//
// The set is signed with request.NewSignature(t.apiKey,
// request.NewSHA1HashAlg()) and compared with the "sm5" request header. On a
// match, the body is decoded into v with the codec selected by the
// Content-Type header. An empty body is accepted without decoding. The body
// is restored on r so that later handlers can read it again.
//
// Errors: a "CODEC" BadRequest for an unknown Content-Type, an unreadable
// body, a malformed query escape or an undecodable body; SignErr when the
// signature does not match.
func (t *Transport) RequestDecoder() http.DecodeRequestFunc {
	return func(r *http.Request, v interface{}) error {
		// Pick the codec matching the request's Content-Type; reject unknown types.
		codec, ok := http.CodecForRequest(r, "Content-Type")
		if !ok {
			return errors.BadRequest("CODEC", r.Header.Get("Content-Type"))
		}
		data, err := io.ReadAll(r.Body)
		if err != nil {
			return errors.BadRequest("CODEC", err.Error())
		}
		// Put the body back so later middleware/handlers can read it again.
		r.Body = io.NopCloser(bytes.NewReader(data))

		// Merge the body and query parameters into the signed set.
		// NOTE(review): the body is always parsed with gjson regardless of the
		// codec, so non-JSON bodies (form, protobuf) appear not to contribute
		// to the signature. Confirm this is intended.
		params := make(map[string]string)
		gjson.ParseBytes(data).ForEach(func(key, value gjson.Result) bool {
			switch value.Type {
			case gjson.JSON:
				// Nested objects/arrays are signed in compact form so that
				// insignificant whitespace does not affect the signature.
				var buf bytes.Buffer
				if err := json.Compact(&buf, []byte(value.Raw)); err != nil {
					// Not compactable: sign the raw text as received.
					params[key.String()] = value.Raw
					return true
				}
				params[key.String()] = buf.String()
			default:
				params[key.String()] = value.String()
			}
			return true
		})
		for _, pair := range strings.Split(r.URL.RawQuery, "&") {
			// Split at the first '=' only, so values containing '=' are kept
			// intact. Keys are used as sent.
			kv := strings.SplitN(pair, "=", 2)
			if len(kv) < 2 {
				continue
			}
			value, err := url.PathUnescape(kv[1])
			if err != nil {
				return errors.BadRequest("CODEC", err.Error())
			}
			params[kv[0]] = value
		}

		signature := request.NewSignature(t.apiKey, request.NewSHA1HashAlg())
		expected := signature.GenSignature(params)
		// Compare in constant time so that response timing does not reveal
		// how much of a forged signature was correct.
		if subtle.ConstantTimeCompare([]byte(expected), []byte(r.Header.Get("sm5"))) != 1 {
			return SignErr
		}

		// Nothing to decode. This matches Kratos' DefaultRequestDecoder and
		// avoids codec-dependent failures (e.g. JSON rejecting empty input).
		if len(data) == 0 {
			return nil
		}
		if err = codec.Unmarshal(data, v); err != nil {
			return errors.BadRequest("CODEC", err.Error())
		}
		return nil
	}
}
// RequestDecoderWithSignFilter returns a body decoder whose signature check
// applies only when t.env is enum.Env_production and r.URL.Path is not
// marked true in noEncrypt. A nil noEncrypt map exempts nothing.
//
// When the check applies, the signed parameter set is built by merging two
// sources:
//   - the top-level fields of a JSON body, with nested objects and arrays in
//     compacted form;
//   - the raw query-string pairs, which win on key collisions.
//
// The set is signed with request.NewSignature(t.apiKey,
// request.NewSHA1HashAlg()) and compared with the "sm5" request header.
// The body is then decoded into v with the codec selected by the
// Content-Type header. An empty body is accepted without decoding. The body
// is restored on r so that later handlers can read it again.
//
// Errors: a "CODEC" BadRequest for an unknown Content-Type, an unreadable
// body, a malformed query escape (only when verifying) or an undecodable
// body; SignErr when a required signature does not match.
func (t *Transport) RequestDecoderWithSignFilter(noEncrypt map[string]bool) http.DecodeRequestFunc {
	return func(r *http.Request, v interface{}) error {
		// Pick the codec matching the request's Content-Type; reject unknown types.
		codec, ok := http.CodecForRequest(r, "Content-Type")
		if !ok {
			return errors.BadRequest("CODEC", r.Header.Get("Content-Type"))
		}
		data, err := io.ReadAll(r.Body)
		if err != nil {
			return errors.BadRequest("CODEC", err.Error())
		}
		// Put the body back so later middleware/handlers can read it again.
		r.Body = io.NopCloser(bytes.NewReader(data))

		// Verify signatures only in production, and never for exempt paths.
		// The signed parameter set is built only here, so other environments
		// skip that work entirely.
		if t.env == enum.Env_production && !noEncrypt[r.URL.Path] {
			// NOTE(review): the body is always parsed with gjson regardless of
			// the codec, so non-JSON bodies (form, protobuf) appear not to
			// contribute to the signature. Confirm this is intended.
			params := make(map[string]string)
			gjson.ParseBytes(data).ForEach(func(key, value gjson.Result) bool {
				switch value.Type {
				case gjson.JSON:
					// Nested objects/arrays are signed in compact form so that
					// insignificant whitespace does not affect the signature.
					var buf bytes.Buffer
					if err := json.Compact(&buf, []byte(value.Raw)); err != nil {
						// Not compactable: sign the raw text as received.
						params[key.String()] = value.Raw
						return true
					}
					params[key.String()] = buf.String()
				default:
					params[key.String()] = value.String()
				}
				return true
			})
			for _, pair := range strings.Split(r.URL.RawQuery, "&") {
				// Split at the first '=' only, so values containing '=' are
				// kept intact. Keys are used as sent.
				kv := strings.SplitN(pair, "=", 2)
				if len(kv) < 2 {
					continue
				}
				value, err := url.PathUnescape(kv[1])
				if err != nil {
					return errors.BadRequest("CODEC", err.Error())
				}
				params[kv[0]] = value
			}

			signature := request.NewSignature(t.apiKey, request.NewSHA1HashAlg())
			expected := signature.GenSignature(params)
			// Compare in constant time so that response timing does not
			// reveal how much of a forged signature was correct.
			if subtle.ConstantTimeCompare([]byte(expected), []byte(r.Header.Get("sm5"))) != 1 {
				return SignErr
			}
		}

		// Nothing to decode. This matches Kratos' DefaultRequestDecoder and
		// avoids codec-dependent failures (e.g. JSON rejecting empty input).
		if len(data) == 0 {
			return nil
		}
		if err = codec.Unmarshal(data, v); err != nil {
			return errors.BadRequest("CODEC", err.Error())
		}
		return nil
	}
}