You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
kratos/encoding/form/proto_encode.go

235 lines
6.5 KiB

package form
import (
"encoding/base64"
"fmt"
"net/url"
"reflect"
"strconv"
"strings"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/fieldmaskpb"
)
// encodeValues converts msg into url.Values.
//
// A nil msg, or an interface holding a nil pointer, yields an empty (non-nil)
// url.Values and no error. proto.Message values are flattened field by field
// via encodeByField, with forceTextName selecting proto text names over JSON
// names. Any other value is handed to the package-level encoder, which is
// defined elsewhere in the package and does not see forceTextName.
func encodeValues(msg interface{}, forceTextName bool) (url.Values, error) {
	if msg == nil {
		return url.Values{}, nil
	}
	if rv := reflect.ValueOf(msg); rv.Kind() == reflect.Ptr && rv.IsNil() {
		return url.Values{}, nil
	}

	pm, isProto := msg.(proto.Message)
	if !isProto {
		return encoder.Encode(msg)
	}

	values := make(url.Values)
	if err := encodeByField(values, "", pm.ProtoReflect(), forceTextName); err != nil {
		return nil, err
	}
	return values, nil
}
// EncodeValues encode a message into url values.
func EncodeValues(msg interface{}) (url.Values, error) {
return encodeValues(msg, false)
}
// EncodeTextNameValues encode a message into url values.
func EncodeTextNameValues(msg interface{}) (url.Values, error) {
return encodeValues(msg, true)
}
func encodeByField(u url.Values, path string, m protoreflect.Message, forceTextName bool) (finalErr error) {
m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
var (
key string
newPath string
)
if forceTextName {
key = fd.TextName()
} else if fd.HasJSONName() {
key = fd.JSONName()
} else {
key = fd.TextName()
}
if path == "" {
newPath = key
} else {
newPath = path + "." + key
}
if of := fd.ContainingOneof(); of != nil {
if f := m.WhichOneof(of); f != nil && f != fd {
return true
}
}
switch {
case fd.IsList():
if v.List().Len() > 0 {
err := encodeRepeatedField(fd, v.List(), u, newPath, forceTextName)
if err != nil {
finalErr = err
return false
}
}
case fd.IsMap():
if v.Map().Len() > 0 {
m, err := encodeMapField(fd, v.Map())
if err != nil {
finalErr = err
return false
}
for k, value := range m {
u.Set(fmt.Sprintf("%s[%s]", newPath, k), value)
}
}
case (fd.Kind() == protoreflect.MessageKind) || (fd.Kind() == protoreflect.GroupKind):
value, err := encodeMessage(fd.Message(), v)
if err == nil {
u.Set(newPath, value)
return true
}
if err = encodeByField(u, newPath, v.Message(), forceTextName); err != nil {
finalErr = err
return false
}
default:
value, err := EncodeField(fd, v)
if err != nil {
finalErr = err
return false
}
u.Set(newPath, value)
}
return true
})
return
}
// encodeRepeatedField appends every element of list to u under newPath.
//
// Elements that EncodeField can render as a single string are added as
// repeated values of the same key. When EncodeField fails (only message/group
// kinds that encodeMessage does not support), the element is flattened
// recursively under an indexed key such as "newPath[2].field".
func encodeRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, u url.Values, newPath string, forceTextName bool) error {
	for idx, n := 0, list.Len(); idx < n; idx++ {
		item := list.Get(idx)
		encoded, err := EncodeField(fieldDescriptor, item)
		if err != nil {
			elemPath := fmt.Sprintf("%s[%d]", newPath, idx)
			if err = encodeByField(u, elemPath, item.Message(), forceTextName); err != nil {
				return err
			}
			continue
		}
		u.Add(newPath, encoded)
	}
	return nil
}
func encodeMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map) (map[string]string, error) {
m := make(map[string]string)
mp.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
key, err := EncodeField(fieldDescriptor.MapValue(), k.Value())
if err != nil {
return false
}
value, err := EncodeField(fieldDescriptor.MapValue(), v)
if err != nil {
return false
}
m[key] = value
return true
})
return m, nil
}
// EncodeField encodes a single protobuf value (a scalar, a list element, or a
// map key/value) described by fieldDescriptor into its query-string text.
//
//   - bool: "true"/"false";
//   - enum: the value's declared name. If the number has no declared name
//     (possible with open enums), the decimal number is used.
//     google.protobuf.NullValue is written as the package's nullStr constant;
//   - string: as is;
//   - bytes: base64 with the URL-safe alphabet (base64.URLEncoding);
//   - message/group: delegated to encodeMessage, which returns an error for
//     message types it does not support;
//   - any other (numeric) kind: fmt.Sprint of the underlying Go value.
func EncodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) {
	switch fieldDescriptor.Kind() {
	case protoreflect.BoolKind:
		return strconv.FormatBool(value.Bool()), nil
	case protoreflect.EnumKind:
		enum := fieldDescriptor.Enum()
		if enum.FullName() == "google.protobuf.NullValue" {
			return nullStr, nil
		}
		num := value.Enum()
		// ByNumber returns nil for numbers without a declared value. Fall back
		// to the number rather than dereferencing a nil descriptor.
		if desc := enum.Values().ByNumber(num); desc != nil {
			return string(desc.Name()), nil
		}
		return strconv.FormatInt(int64(num), 10), nil
	case protoreflect.StringKind:
		return value.String(), nil
	case protoreflect.BytesKind:
		return base64.URLEncoding.EncodeToString(value.Bytes()), nil
	case protoreflect.MessageKind, protoreflect.GroupKind:
		return encodeMessage(fieldDescriptor.Message(), value)
	default:
		return fmt.Sprint(value.Interface()), nil
	}
}
// encodeMessage encodes a supported well-known-type message value as one string:
//   - the types named by timestampMessageFullname, durationMessageFullname and
//     bytesMessageFullname (presumably Timestamp, Duration and BytesValue;
//     the constants live elsewhere in the package) go through
//     marshalTimestamp / marshalDuration / marshalBytes;
//   - scalar wrapper types (DoubleValue, FloatValue, Int64Value, Int32Value,
//     UInt64Value, UInt32Value, BoolValue, StringValue) are written as the
//     wrapped "value" field formatted with fmt.Sprint;
//   - the FieldMask type (fieldMaskFullName) is written as its paths converted
//     with jsonCamelCase and joined by ",". The message's own Paths slice is
//     left unmodified.
//
// Any other message type returns an "unsupported message type" error.
// encodeByField and encodeRepeatedField treat that error as the signal to
// flatten the message field by field instead.
func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) {
	switch msgDescriptor.FullName() {
	case timestampMessageFullname:
		return marshalTimestamp(value.Message())
	case durationMessageFullname:
		return marshalDuration(value.Message())
	case bytesMessageFullname:
		return marshalBytes(value.Message())
	case "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value",
		"google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue":
		fd := msgDescriptor.Fields()
		v := value.Message().Get(fd.ByName("value"))
		return fmt.Sprint(v.Interface()), nil
	case fieldMaskFullName:
		m, ok := value.Message().Interface().(*fieldmaskpb.FieldMask)
		if !ok || m == nil {
			return "", nil
		}
		// Build a fresh slice: m belongs to the caller, and rewriting m.Paths in
		// place would change the message being encoded as a side effect.
		paths := make([]string, len(m.Paths))
		for i, p := range m.Paths {
			paths[i] = jsonCamelCase(p)
		}
		return strings.Join(paths, ","), nil
	default:
		return "", fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
	}
}
// EncodeFieldMask returns a "name=paths" query fragment for the first populated
// FieldMask-typed (fieldMaskFullName) field that m.Range visits. name is the
// field's JSON name when declared, otherwise its text name. paths are the
// mask's paths in lowerCamelCase, joined by ",", and are not URL-escaped here.
// Returns "" when m has no populated FieldMask field.
func EncodeFieldMask(m protoreflect.Message) (query string) {
	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
		if fd.Kind() != protoreflect.MessageKind {
			return true
		}
		msg := fd.Message()
		if msg.FullName() != fieldMaskFullName {
			return true
		}
		value, err := encodeMessage(msg, v)
		if err != nil {
			return false
		}
		name := fd.TextName()
		if fd.HasJSONName() {
			name = fd.JSONName()
		}
		query = name + "=" + value
		// Only the first FieldMask field is reported.
		return false
	})
	return query
}
// jsonCamelCase converts a snake_case identifier to lowerCamelCase following
// the protobuf JSON mapping. Every underscore is dropped, and a lowercase ASCII
// letter directly after an underscore is upper-cased. Other bytes pass through
// unchanged. Proto identifiers are ASCII, so the input is processed byte-wise.
// references: https://github.com/protocolbuffers/protobuf-go/blob/master/encoding/protojson/well_known_types.go#L842
func jsonCamelCase(s string) string {
	out := make([]byte, 0, len(s))
	capitalizeNext := false
	for _, ch := range []byte(s) {
		if ch == '_' {
			capitalizeNext = true
			continue
		}
		if capitalizeNext && isASCIILower(ch) {
			ch -= 'a' - 'A'
		}
		out = append(out, ch)
		capitalizeNext = false
	}
	return string(out)
}

// isASCIILower reports whether c is a lowercase ASCII letter ('a' through 'z').
func isASCIILower(c byte) bool {
	return c >= 'a' && c <= 'z'
}