package http import ( "bytes" "encoding/json" "gitea.drugeyes.vip/pharnexbase/tools/request" "gitea.drugeyes.vip/pharnexbase/utils/enum" "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/transport/http" "github.com/go-kratos/kratos/v2/transport/http/binding" "github.com/tidwall/gjson" "io" "net/url" "strings" ) var SignErr = errors.Forbidden("SignError", "签名错误") func (t *Transport) RequestQueryDecoder() http.DecodeRequestFunc { return func(r *http.Request, v interface{}) error { // 将post和query参数合并 params := make(map[string]string) for _, val := range strings.Split(r.URL.RawQuery, "&") { vals := strings.Split(val, "=") if len(vals) < 2 { continue } params[vals[0]], _ = url.PathUnescape(vals[1]) } signature := request.NewSignature(t.apiKey, request.NewSHA1HashAlg()) signStr := signature.GenSignature(params) if signStr != r.Header.Get("sm5") { return SignErr } if err := binding.BindQuery(r.URL.Query(), v); err != nil { return errors.BadRequest("CODEC", err.Error()) } return nil } } // RequestDecoder 请求拦截 func (t *Transport) RequestDecoder() http.DecodeRequestFunc { return func(r *http.Request, v interface{}) error { // 从Request Header的Content-Type中提取出对应的解码器 codec, ok := http.CodecForRequest(r, "Content-Type") // 如果找不到对应的解码器此时会报错 if !ok { return errors.BadRequest("CODEC", r.Header.Get("Content-Type")) } data, err := io.ReadAll(r.Body) if err != nil { return errors.BadRequest("CODEC", err.Error()) } // 将post和query参数合并 params := make(map[string]string) gjson.ParseBytes(data).ForEach(func(key, value gjson.Result) bool { switch value.Type { case gjson.JSON: var buf bytes.Buffer json.Compact(&buf, []byte(value.String())) params[key.String()] = buf.String() default: params[key.String()] = value.String() } return true }) for _, val := range strings.Split(r.URL.RawQuery, "&") { vals := strings.Split(val, "=") if len(vals) < 2 { continue } params[vals[0]], _ = url.PathUnescape(vals[1]) } signature := request.NewSignature(t.apiKey, request.NewSHA1HashAlg()) signStr := signature.GenSignature(params) if signStr != r.Header.Get("sm5") { return SignErr } if err = codec.Unmarshal(data, v); err != nil { return errors.BadRequest("CODEC", err.Error()) } return nil } } func (t *Transport) RequestDecoderWithSignFilter(noEncrypt map[string]bool) http.DecodeRequestFunc { return func(r *http.Request, v interface{}) error { // 从Request Header的Content-Type中提取出对应的解码器 codec, ok := http.CodecForRequest(r, "Content-Type") // 如果找不到对应的解码器此时会报错 if !ok { return errors.BadRequest("CODEC", r.Header.Get("Content-Type")) } data, err := io.ReadAll(r.Body) if err != nil { return errors.BadRequest("CODEC", err.Error()) } // 将post和query参数合并 params := make(map[string]string) gjson.ParseBytes(data).ForEach(func(key, value gjson.Result) bool { switch value.Type { case gjson.JSON: var buf bytes.Buffer json.Compact(&buf, []byte(value.String())) params[key.String()] = buf.String() default: params[key.String()] = value.String() } return true }) for _, val := range strings.Split(r.URL.RawQuery, "&") { vals := strings.Split(val, "=") if len(vals) < 2 { continue } params[vals[0]], _ = url.PathUnescape(vals[1]) } // 正式服验签 if t.env == enum.Env_production && !noEncrypt[r.URL.Path] { signature := request.NewSignature(t.apiKey, request.NewSHA1HashAlg()) signStr := signature.GenSignature(params) if signStr != r.Header.Get("sm5") { return SignErr } } if err = codec.Unmarshal(data, v); err != nil { return errors.BadRequest("CODEC", err.Error()) } return nil } }