fix: fix encode proto well known types in form and url query (#1559)

* fix encode proto well known types
pull/1561/head
longxboy 3 years ago committed by GitHub
parent 014778b72a
commit 210e414e6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 41
      encoding/form/proto_encode.go
  2. 88
      encoding/form/well_known_types.go
  3. 69
      transport/http/binding/encode.go

@ -7,14 +7,10 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time"
"google.golang.org/genproto/protobuf/field_mask" "google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
) )
// EncodeMap encode proto message to url query. // EncodeMap encode proto message to url query.
@ -84,7 +80,7 @@ func encodeByField(u url.Values, path string, v protoreflect.Message) error {
return err return err
} }
default: default:
value, err := encodeField(fd, v.Get(fd)) value, err := EncodeField(fd, v.Get(fd))
if err != nil { if err != nil {
return err return err
} }
@ -98,7 +94,7 @@ func encodeByField(u url.Values, path string, v protoreflect.Message) error {
func encodeRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List) ([]string, error) { func encodeRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List) ([]string, error) {
var values []string var values []string
for i := 0; i < list.Len(); i++ { for i := 0; i < list.Len(); i++ {
value, err := encodeField(fieldDescriptor, list.Get(i)) value, err := EncodeField(fieldDescriptor, list.Get(i))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -111,11 +107,11 @@ func encodeRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list prot
func encodeMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map) (map[string]string, error) { func encodeMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map) (map[string]string, error) {
m := make(map[string]string) m := make(map[string]string)
mp.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { mp.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
key, err := encodeField(fieldDescriptor.MapValue(), k.Value()) key, err := EncodeField(fieldDescriptor.MapValue(), k.Value())
if err != nil { if err != nil {
return false return false
} }
value, err := encodeField(fieldDescriptor.MapValue(), v) value, err := EncodeField(fieldDescriptor.MapValue(), v)
if err != nil { if err != nil {
return false return false
} }
@ -126,7 +122,8 @@ func encodeMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflec
return m, nil return m, nil
} }
func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) { // EncodeField encode proto message filed
func EncodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) {
switch fieldDescriptor.Kind() { switch fieldDescriptor.Kind() {
case protoreflect.BoolKind: case protoreflect.BoolKind:
return strconv.FormatBool(value.Bool()), nil return strconv.FormatBool(value.Bool()), nil
@ -147,29 +144,17 @@ func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflec
} }
} }
// marshalMessage marshals the fields in the given protoreflect.Message. // encodeMessage marshals the fields in the given protoreflect.Message.
// If the typeURL is non-empty, then a synthetic "@type" field is injected // If the typeURL is non-empty, then a synthetic "@type" field is injected
// containing the URL as the value. // containing the URL as the value.
func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) { func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) {
switch msgDescriptor.FullName() { switch msgDescriptor.FullName() {
case "google.protobuf.Timestamp": case timestampMessageFullname:
t, ok := value.Interface().(*timestamppb.Timestamp) return marshalTimestamp(value.Message())
if !ok { case durationMessageFullname:
return "", nil return marshalDuration(value.Message())
} case bytesMessageFullname:
return t.AsTime().Format(time.RFC3339Nano), nil return marshalBytes(value.Message())
case "google.protobuf.Duration":
d, ok := value.Interface().(*durationpb.Duration)
if !ok {
return "", nil
}
return d.AsDuration().String(), nil
case "google.protobuf.BytesValue":
b, ok := value.Interface().(*wrapperspb.BytesValue)
if !ok {
return "", nil
}
return base64.StdEncoding.EncodeToString(b.Value), nil
case "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value", case "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value",
"google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue": "google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue":
fd := msgDescriptor.Fields() fd := msgDescriptor.Fields()

@ -0,0 +1,88 @@
package form
import (
"encoding/base64"
"fmt"
"math"
"strings"
"time"
"google.golang.org/protobuf/reflect/protoreflect"
)
const (
// timestamp
timestampMessageFullname protoreflect.FullName = "google.protobuf.Timestamp"
maxTimestampSeconds = 253402300799
minTimestampSeconds = -6213559680013
timestampSecondsFieldNumber protoreflect.FieldNumber = 1
timestampNanosFieldNumber protoreflect.FieldNumber = 2
// duration
durationMessageFullname protoreflect.FullName = "google.protobuf.Duration"
secondsInNanos = 999999999
durationSecondsFieldNumber protoreflect.FieldNumber = 1
durationNanosFieldNumber protoreflect.FieldNumber = 2
// bytes
bytesMessageFullname protoreflect.FullName = "google.protobuf.BytesValue"
bytesValueFieldNumber protoreflect.FieldNumber = 1
)
func marshalTimestamp(m protoreflect.Message) (string, error) {
fds := m.Descriptor().Fields()
fdSeconds := fds.ByNumber(timestampSecondsFieldNumber)
fdNanos := fds.ByNumber(timestampNanosFieldNumber)
secsVal := m.Get(fdSeconds)
nanosVal := m.Get(fdNanos)
secs := secsVal.Int()
nanos := nanosVal.Int()
if secs < minTimestampSeconds || secs > maxTimestampSeconds {
return "", fmt.Errorf("%s: seconds out of range %v", timestampMessageFullname, secs)
}
if nanos < 0 || nanos > secondsInNanos {
return "", fmt.Errorf("%s: nanos out of range %v", timestampMessageFullname, nanos)
}
// Uses RFC 3339, where generated output will be Z-normalized and uses 0, 3,
// 6 or 9 fractional digits.
t := time.Unix(secs, nanos).UTC()
x := t.Format("2006-01-02T15:04:05.000000000")
x = strings.TrimSuffix(x, "000")
x = strings.TrimSuffix(x, "000")
x = strings.TrimSuffix(x, ".000")
return x + "Z", nil
}
func marshalDuration(m protoreflect.Message) (string, error) {
fds := m.Descriptor().Fields()
fdSeconds := fds.ByNumber(durationSecondsFieldNumber)
fdNanos := fds.ByNumber(durationNanosFieldNumber)
secsVal := m.Get(fdSeconds)
nanosVal := m.Get(fdNanos)
secs := secsVal.Int()
nanos := nanosVal.Int()
d := time.Duration(secs) * time.Second
overflow := d/time.Second != time.Duration(secs)
d += time.Duration(nanos) * time.Nanosecond
overflow = overflow || (secs < 0 && nanos < 0 && d > 0)
overflow = overflow || (secs > 0 && nanos > 0 && d < 0)
if overflow {
switch {
case secs < 0:
return time.Duration(math.MinInt64).String(), nil
case secs > 0:
return time.Duration(math.MaxInt64).String(), nil
}
}
return d.String(), nil
}
func marshalBytes(m protoreflect.Message) (string, error) {
fds := m.Descriptor().Fields()
fdBytes := fds.ByNumber(bytesValueFieldNumber)
bytesVal := m.Get(fdBytes)
val := bytesVal.Bytes()
return base64.StdEncoding.EncodeToString(val), nil
}

@ -1,22 +1,15 @@
package binding package binding
import ( import (
"encoding/base64"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
"strconv"
"strings" "strings"
"time"
"github.com/go-kratos/kratos/v2/encoding/form" "github.com/go-kratos/kratos/v2/encoding/form"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
) )
// EncodeURL encode proto message to url path. // EncodeURL encode proto message to url path.
@ -74,65 +67,5 @@ func getValueByField(v protoreflect.Message, fieldPath []string) (string, error)
} }
v = v.Get(fd).Message() v = v.Get(fd).Message()
} }
return encodeField(fd, v.Get(fd)) return form.EncodeField(fd, v.Get(fd))
}
func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) {
switch fieldDescriptor.Kind() {
case protoreflect.BoolKind:
return strconv.FormatBool(value.Bool()), nil
case protoreflect.EnumKind:
if fieldDescriptor.Enum().FullName() == "google.protobuf.NullValue" {
return "null", nil
}
desc := fieldDescriptor.Enum().Values().ByNumber(value.Enum())
return string(desc.Name()), nil
case protoreflect.StringKind:
return value.String(), nil
case protoreflect.BytesKind:
return base64.URLEncoding.EncodeToString(value.Bytes()), nil
case protoreflect.MessageKind, protoreflect.GroupKind:
return encodeMessage(fieldDescriptor.Message(), value)
default:
return fmt.Sprintf("%v", value.Interface()), nil
}
}
// encodeMessage marshals the fields in the given protoreflect.Message.
// If the typeURL is non-empty, then a synthetic "@type" field is injected
// containing the URL as the value.
func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) {
switch msgDescriptor.FullName() {
case "google.protobuf.Timestamp":
t, ok := value.Interface().(*timestamppb.Timestamp)
if !ok {
return "", nil
}
return t.AsTime().Format(time.RFC3339Nano), nil
case "google.protobuf.Duration":
d, ok := value.Interface().(*durationpb.Duration)
if !ok {
return "", nil
}
return d.AsDuration().String(), nil
case "google.protobuf.BytesValue":
b, ok := value.Interface().(*wrapperspb.BytesValue)
if !ok {
return "", nil
}
return base64.StdEncoding.EncodeToString(b.Value), nil
case "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value",
"google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue":
fd := msgDescriptor.Fields()
v := value.Message().Get(fd.ByName(protoreflect.Name("value"))).Message()
return fmt.Sprintf("%v", v.Interface()), nil
case "google.protobuf.FieldMask":
m, ok := value.Interface().(*field_mask.FieldMask)
if !ok {
return "", nil
}
return strings.Join(m.Paths, ","), nil
default:
return "", fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
}
} }

Loading…
Cancel
Save