You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
kratos/transport/http/binding/encode.go

250 lines
6.9 KiB

package binding
import (
"encoding/base64"
"fmt"
"net/url"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
// pathParamPattern matches a "/{name}" or "/{outer.inner}" placeholder segment
// of a path template. It is compiled once at package initialisation instead of
// on every EncodeURL call.
var pathParamPattern = regexp.MustCompile(`/{[.\w]+}`)

// EncodeURL expands the "{field}" placeholders of pathTemplate with values read
// from msg and, when needQuery is true, appends the remaining fields of msg as
// a query string.
//
// A nil msg (including a typed nil pointer) returns pathTemplate unchanged. A
// placeholder whose field path cannot be resolved or encoded is left in the
// path as-is. Substituted values are inserted verbatim; they are not
// URL-escaped here.
//
// Query generation is best-effort: if EncodeQuery fails or yields no values,
// the path is returned without a query string. Keys are removed from the query
// using the placeholder text exactly as written in the template.
func EncodeURL(pathTemplate string, msg proto.Message, needQuery bool) string {
	if msg == nil || (reflect.ValueOf(msg).Kind() == reflect.Ptr && reflect.ValueOf(msg).IsNil()) {
		return pathTemplate
	}

	pathParams := make(map[string]struct{})
	path := pathParamPattern.ReplaceAllStringFunc(pathTemplate, func(in string) string {
		// The pattern guarantees in has the form "/{key}" with a non-empty key,
		// so stripping the leading "/{" and trailing "}" is always in range.
		key := in[2 : len(in)-1]
		value, err := getValueByField(msg.ProtoReflect(), strings.Split(key, "."))
		if err != nil {
			return in
		}
		pathParams[key] = struct{}{}
		return "/" + value
	})

	if !needQuery {
		return path
	}
	u, err := EncodeQuery(msg)
	if err != nil || len(u) == 0 {
		return path
	}
	// Fields already bound into the path must not be repeated in the query.
	for key := range pathParams {
		delete(u, key)
	}
	if query := u.Encode(); query != "" {
		path += "?" + query
	}
	return path
}
// EncodeQuery converts the fields of msg into URL query values.
//
// A nil msg (including a typed nil pointer) produces an empty, non-nil
// url.Values and no error. If any field fails to encode, the error is returned
// together with a nil url.Values.
func EncodeQuery(msg proto.Message) (url.Values, error) {
	if msg == nil {
		return url.Values{}, nil
	}
	if rv := reflect.ValueOf(msg); rv.Kind() == reflect.Ptr && rv.IsNil() {
		return url.Values{}, nil
	}

	values := url.Values{}
	if err := encodeByField(values, "", msg.ProtoReflect()); err != nil {
		return nil, err
	}
	return values, nil
}
// getValueByField walks fieldPath from message v and returns the string
// encoding of the field at the end of the path.
//
// Each path element is matched first against the proto field name and then
// against the JSON name. Every element except the last must be a singular
// message field. The last element must not be a repeated or map field, because
// the scalar encoder cannot handle list or map values.
//
// An error is returned when fieldPath is empty, when an element is not found,
// when an intermediate element is not a singular message, when the final
// element is repeated or a map, or when the final value cannot be encoded.
func getValueByField(v protoreflect.Message, fieldPath []string) (string, error) {
	// Without this guard fd would stay nil and v.Get(fd) below would panic.
	if len(fieldPath) == 0 {
		return "", fmt.Errorf("field path is empty")
	}

	var fd protoreflect.FieldDescriptor
	for i, fieldName := range fieldPath {
		fields := v.Descriptor().Fields()
		if fd = fields.ByName(protoreflect.Name(fieldName)); fd == nil {
			if fd = fields.ByJSONName(fieldName); fd == nil {
				return "", fmt.Errorf("field path not found: %q", fieldName)
			}
		}
		if i == len(fieldPath)-1 {
			break
		}
		if fd.Message() == nil || fd.Cardinality() == protoreflect.Repeated {
			return "", fmt.Errorf("invalid path: %q is not a message", fieldName)
		}
		v = v.Get(fd).Message()
	}

	// A list or map value handed to encodeField would be read through scalar
	// accessors such as Bool() or Enum(), which panic on the wrong value type.
	if fd.IsList() || fd.IsMap() {
		return "", fmt.Errorf("invalid path: %q is a repeated or map field", string(fd.Name()))
	}
	return encodeField(fd, v.Get(fd))
}
// encodeByField writes every field of message v into u, recursing into nested
// messages. Keys are the field's JSON name when it has one, otherwise its text
// name, joined to path with "." for nested fields.
//
// Encoding rules per field:
//   - A oneof member is emitted only when it is the member currently set in
//     its oneof; unset oneofs contribute nothing.
//   - Repeated fields are emitted as multiple values under one key, and only
//     when non-empty.
//   - Map fields are emitted as "key[mapKey]" entries, and only when non-empty.
//   - Message fields are emitted only when populated. encodeMessage is tried
//     first; if it reports an error the message is treated as an ordinary
//     nested message and its fields are encoded recursively.
//   - All other fields are emitted with their current value, which is the
//     default value when the field is unset.
//
// The first encoding error aborts the walk and is returned; u may then hold
// the entries written before the failure.
func encodeByField(u url.Values, path string, v protoreflect.Message) error {
	fields := v.Descriptor().Fields()
	for i := 0; i < fields.Len(); i++ {
		fd := fields.Get(i)

		var key string
		if fd.HasJSONName() {
			key = fd.JSONName()
		} else {
			key = fd.TextName()
		}
		newPath := key
		if path != "" {
			newPath = path + "." + key
		}

		// Skip oneof members that are not the active one. Field numbers are
		// unique within a message, so comparing them identifies the set member.
		if of := fd.ContainingOneof(); of != nil {
			if set := v.WhichOneof(of); set == nil || set.Number() != fd.Number() {
				continue
			}
		}

		switch {
		case fd.IsList():
			if v.Get(fd).List().Len() > 0 {
				list, err := encodeRepeatedField(fd, v.Get(fd).List())
				if err != nil {
					return err
				}
				u[newPath] = list
			}
		case fd.IsMap():
			if v.Get(fd).Map().Len() > 0 {
				m, err := encodeMapField(fd, v.Get(fd).Map())
				if err != nil {
					return err
				}
				for k, value := range m {
					u[fmt.Sprintf("%s[%s]", newPath, k)] = []string{value}
				}
			}
		case (fd.Kind() == protoreflect.MessageKind) || (fd.Kind() == protoreflect.GroupKind):
			// Get on an unset message field returns an empty placeholder
			// message. Recursing into it would never terminate for
			// self-referential message types, so unset messages are skipped.
			if !v.Has(fd) {
				continue
			}
			if value, err := encodeMessage(fd.Message(), v.Get(fd)); err == nil {
				u[newPath] = []string{value}
				continue
			}
			// An encodeMessage error means there is no single-string form for
			// this type, so flatten its fields under newPath instead.
			if err := encodeByField(u, newPath, v.Get(fd).Message()); err != nil {
				return err
			}
		default:
			value, err := encodeField(fd, v.Get(fd))
			if err != nil {
				return err
			}
			u[newPath] = []string{value}
		}
	}
	return nil
}
// encodeMessageField reads the field described by fieldDescriptor from
// msgValue and returns its string encoding as produced by encodeField.
func encodeMessageField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message) (string, error) {
	fieldValue := msgValue.Get(fieldDescriptor)
	return encodeField(fieldDescriptor, fieldValue)
}
// encodeRepeatedField encodes each element of list with encodeField, using
// fieldDescriptor as the element descriptor, and returns the results in list
// order. The first element that fails to encode aborts the loop and its error
// is returned with a nil slice. An empty list yields a nil slice.
func encodeRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List) ([]string, error) {
	var encoded []string
	for idx, n := 0, list.Len(); idx < n; idx++ {
		item, err := encodeField(fieldDescriptor, list.Get(idx))
		if err != nil {
			return nil, err
		}
		encoded = append(encoded, item)
	}
	return encoded, nil
}
// encodeMapField encodes every entry of the protobuf map mp into a Go map of
// string keys to string values. Keys are encoded with the map's key
// descriptor and values with its value descriptor.
//
// If any key or value fails to encode, iteration stops and that error is
// returned with a nil map, so callers never see a silently truncated result.
func encodeMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map) (map[string]string, error) {
	m := make(map[string]string, mp.Len())
	var rangeErr error
	mp.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
		// The key must be encoded with the key descriptor. Using the value
		// descriptor would read the key through the wrong accessor, for
		// example Bool() on a string key of a map<string, bool>.
		key, err := encodeField(fieldDescriptor.MapKey(), k.Value())
		if err != nil {
			rangeErr = err
			return false
		}
		value, err := encodeField(fieldDescriptor.MapValue(), v)
		if err != nil {
			rangeErr = err
			return false
		}
		m[key] = value
		return true
	})
	if rangeErr != nil {
		return nil, rangeErr
	}
	return m, nil
}
// encodeField returns the string form of a single (non-list, non-map) field
// value according to the kind reported by fieldDescriptor:
//   - bool: "true" or "false"
//   - enum: the value's declared name; google.protobuf.NullValue is always
//     "null"; a number with no declared name is written as its decimal number
//   - string: the string itself
//   - bytes: URL-safe base64 with padding
//   - message/group: delegated to encodeMessage
//   - anything else: the value formatted with "%v"
//
// An error is returned only when encodeMessage returns one.
func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) {
	switch fieldDescriptor.Kind() {
	case protoreflect.BoolKind:
		return strconv.FormatBool(value.Bool()), nil
	case protoreflect.EnumKind:
		if fieldDescriptor.Enum().FullName() == "google.protobuf.NullValue" {
			return "null", nil
		}
		number := value.Enum()
		// ByNumber returns nil when the enum declares no value for number, so
		// fall back to the number rather than dereferencing a nil descriptor.
		if desc := fieldDescriptor.Enum().Values().ByNumber(number); desc != nil {
			return string(desc.Name()), nil
		}
		return strconv.FormatInt(int64(number), 10), nil
	case protoreflect.StringKind:
		return value.String(), nil
	case protoreflect.BytesKind:
		return base64.URLEncoding.EncodeToString(value.Bytes()), nil
	case protoreflect.MessageKind, protoreflect.GroupKind:
		return encodeMessage(fieldDescriptor.Message(), value)
	default:
		return fmt.Sprintf("%v", value.Interface()), nil
	}
}
// encodeMessage returns the single-string form of a well-known message type,
// selected by the full name of msgDescriptor:
//   - google.protobuf.Timestamp: RFC 3339 with nanoseconds
//   - google.protobuf.Duration: Go duration syntax (time.Duration.String)
//   - google.protobuf.BytesValue: standard base64 of the wrapped bytes
//   - the other scalar wrappers: the wrapped "value" field formatted with "%v"
//   - google.protobuf.FieldMask: the paths joined with ","
//
// Any other message type yields an "unsupported message type" error, and an
// error is also returned when value does not hold a message. An unset (invalid)
// message encodes as the empty string. The empty string is likewise returned
// when the message is not backed by the expected generated Go type — presumably
// the case for dynamically constructed messages; verify if those must be
// supported.
func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) {
	switch msgDescriptor.FullName() {
	case "google.protobuf.Timestamp", "google.protobuf.Duration", "google.protobuf.BytesValue",
		"google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value",
		"google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue",
		"google.protobuf.FieldMask":
	default:
		return "", fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
	}

	// Value.Interface on a message value yields the reflective
	// protoreflect.Message, not the generated struct; the concrete Go message
	// is obtained from it via Message.Interface below.
	m, isMessage := value.Interface().(protoreflect.Message)
	if !isMessage {
		return "", fmt.Errorf("value for message type %q is not a message", string(msgDescriptor.FullName()))
	}
	if !m.IsValid() {
		return "", nil
	}

	switch msgDescriptor.FullName() {
	case "google.protobuf.Timestamp":
		t, ok := m.Interface().(*timestamppb.Timestamp)
		if !ok {
			return "", nil
		}
		return t.AsTime().Format(time.RFC3339Nano), nil
	case "google.protobuf.Duration":
		d, ok := m.Interface().(*durationpb.Duration)
		if !ok {
			return "", nil
		}
		return d.AsDuration().String(), nil
	case "google.protobuf.BytesValue":
		b, ok := m.Interface().(*wrapperspb.BytesValue)
		if !ok {
			return "", nil
		}
		return base64.StdEncoding.EncodeToString(b.Value), nil
	case "google.protobuf.FieldMask":
		fm, ok := m.Interface().(*field_mask.FieldMask)
		if !ok {
			return "", nil
		}
		return strings.Join(fm.Paths, ","), nil
	default:
		// Remaining names are the scalar wrappers. Their single "value" field
		// is a scalar, so it is read as a plain Value, never as a message.
		wrapped := m.Get(msgDescriptor.Fields().ByName(protoreflect.Name("value")))
		return fmt.Sprintf("%v", wrapped.Interface()), nil
	}
}