diff --git a/encoding/form/proto_encode.go b/encoding/form/proto_encode.go index 13fbcde4e..411810ad9 100644 --- a/encoding/form/proto_encode.go +++ b/encoding/form/proto_encode.go @@ -7,14 +7,10 @@ import ( "reflect" "strconv" "strings" - "time" "google.golang.org/genproto/protobuf/field_mask" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/known/durationpb" - "google.golang.org/protobuf/types/known/timestamppb" - "google.golang.org/protobuf/types/known/wrapperspb" ) // EncodeMap encode proto message to url query. @@ -84,7 +80,7 @@ func encodeByField(u url.Values, path string, v protoreflect.Message) error { return err } default: - value, err := encodeField(fd, v.Get(fd)) + value, err := EncodeField(fd, v.Get(fd)) if err != nil { return err } @@ -98,7 +94,7 @@ func encodeByField(u url.Values, path string, v protoreflect.Message) error { func encodeRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List) ([]string, error) { var values []string for i := 0; i < list.Len(); i++ { - value, err := encodeField(fieldDescriptor, list.Get(i)) + value, err := EncodeField(fieldDescriptor, list.Get(i)) if err != nil { return nil, err } @@ -111,11 +107,11 @@ func encodeRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list prot func encodeMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map) (map[string]string, error) { m := make(map[string]string) mp.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { - key, err := encodeField(fieldDescriptor.MapValue(), k.Value()) + key, err := EncodeField(fieldDescriptor.MapValue(), k.Value()) if err != nil { return false } - value, err := encodeField(fieldDescriptor.MapValue(), v) + value, err := EncodeField(fieldDescriptor.MapValue(), v) if err != nil { return false } @@ -126,7 +122,8 @@ func encodeMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflec return m, nil } -func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) { +// EncodeField encode proto message filed +func EncodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) { switch fieldDescriptor.Kind() { case protoreflect.BoolKind: return strconv.FormatBool(value.Bool()), nil @@ -147,29 +144,17 @@ func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflec } } -// marshalMessage marshals the fields in the given protoreflect.Message. +// encodeMessage marshals the fields in the given protoreflect.Message. // If the typeURL is non-empty, then a synthetic "@type" field is injected // containing the URL as the value. func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) { switch msgDescriptor.FullName() { - case "google.protobuf.Timestamp": - t, ok := value.Interface().(*timestamppb.Timestamp) - if !ok { - return "", nil - } - return t.AsTime().Format(time.RFC3339Nano), nil - case "google.protobuf.Duration": - d, ok := value.Interface().(*durationpb.Duration) - if !ok { - return "", nil - } - return d.AsDuration().String(), nil - case "google.protobuf.BytesValue": - b, ok := value.Interface().(*wrapperspb.BytesValue) - if !ok { - return "", nil - } - return base64.StdEncoding.EncodeToString(b.Value), nil + case timestampMessageFullname: + return marshalTimestamp(value.Message()) + case durationMessageFullname: + return marshalDuration(value.Message()) + case bytesMessageFullname: + return marshalBytes(value.Message()) case "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value", "google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue": fd := msgDescriptor.Fields() diff --git a/encoding/form/well_known_types.go b/encoding/form/well_known_types.go new file mode 100644 index 000000000..779594926 --- /dev/null +++ b/encoding/form/well_known_types.go @@ -0,0 +1,88 @@ +package form + +import ( + "encoding/base64" + "fmt" + "math" + "strings" + "time" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +const ( + // timestamp + timestampMessageFullname protoreflect.FullName = "google.protobuf.Timestamp" + maxTimestampSeconds = 253402300799 + minTimestampSeconds = -6213559680013 + timestampSecondsFieldNumber protoreflect.FieldNumber = 1 + timestampNanosFieldNumber protoreflect.FieldNumber = 2 + + // duration + durationMessageFullname protoreflect.FullName = "google.protobuf.Duration" + secondsInNanos = 999999999 + durationSecondsFieldNumber protoreflect.FieldNumber = 1 + durationNanosFieldNumber protoreflect.FieldNumber = 2 + + // bytes + bytesMessageFullname protoreflect.FullName = "google.protobuf.BytesValue" + bytesValueFieldNumber protoreflect.FieldNumber = 1 +) + +func marshalTimestamp(m protoreflect.Message) (string, error) { + fds := m.Descriptor().Fields() + fdSeconds := fds.ByNumber(timestampSecondsFieldNumber) + fdNanos := fds.ByNumber(timestampNanosFieldNumber) + + secsVal := m.Get(fdSeconds) + nanosVal := m.Get(fdNanos) + secs := secsVal.Int() + nanos := nanosVal.Int() + if secs < minTimestampSeconds || secs > maxTimestampSeconds { + return "", fmt.Errorf("%s: seconds out of range %v", timestampMessageFullname, secs) + } + if nanos < 0 || nanos > secondsInNanos { + return "", fmt.Errorf("%s: nanos out of range %v", timestampMessageFullname, nanos) + } + // Uses RFC 3339, where generated output will be Z-normalized and uses 0, 3, + // 6 or 9 fractional digits. + t := time.Unix(secs, nanos).UTC() + x := t.Format("2006-01-02T15:04:05.000000000") + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, ".000") + return x + "Z", nil +} + +func marshalDuration(m protoreflect.Message) (string, error) { + fds := m.Descriptor().Fields() + fdSeconds := fds.ByNumber(durationSecondsFieldNumber) + fdNanos := fds.ByNumber(durationNanosFieldNumber) + + secsVal := m.Get(fdSeconds) + nanosVal := m.Get(fdNanos) + secs := secsVal.Int() + nanos := nanosVal.Int() + d := time.Duration(secs) * time.Second + overflow := d/time.Second != time.Duration(secs) + d += time.Duration(nanos) * time.Nanosecond + overflow = overflow || (secs < 0 && nanos < 0 && d > 0) + overflow = overflow || (secs > 0 && nanos > 0 && d < 0) + if overflow { + switch { + case secs < 0: + return time.Duration(math.MinInt64).String(), nil + case secs > 0: + return time.Duration(math.MaxInt64).String(), nil + } + } + return d.String(), nil +} + +func marshalBytes(m protoreflect.Message) (string, error) { + fds := m.Descriptor().Fields() + fdBytes := fds.ByNumber(bytesValueFieldNumber) + bytesVal := m.Get(fdBytes) + val := bytesVal.Bytes() + return base64.StdEncoding.EncodeToString(val), nil +} diff --git a/transport/http/binding/encode.go b/transport/http/binding/encode.go index c5dbd0cc6..fd7ea1747 100644 --- a/transport/http/binding/encode.go +++ b/transport/http/binding/encode.go @@ -1,22 +1,15 @@ package binding import ( - "encoding/base64" "fmt" "reflect" "regexp" - "strconv" "strings" - "time" "github.com/go-kratos/kratos/v2/encoding/form" - "google.golang.org/genproto/protobuf/field_mask" "google.golang.org/protobuf/proto" protoreflect "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/known/durationpb" - "google.golang.org/protobuf/types/known/timestamppb" - "google.golang.org/protobuf/types/known/wrapperspb" ) // EncodeURL encode proto message to url path. @@ -74,65 +67,5 @@ func getValueByField(v protoreflect.Message, fieldPath []string) (string, error) } v = v.Get(fd).Message() } - return encodeField(fd, v.Get(fd)) -} - -func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) { - switch fieldDescriptor.Kind() { - case protoreflect.BoolKind: - return strconv.FormatBool(value.Bool()), nil - case protoreflect.EnumKind: - if fieldDescriptor.Enum().FullName() == "google.protobuf.NullValue" { - return "null", nil - } - desc := fieldDescriptor.Enum().Values().ByNumber(value.Enum()) - return string(desc.Name()), nil - case protoreflect.StringKind: - return value.String(), nil - case protoreflect.BytesKind: - return base64.URLEncoding.EncodeToString(value.Bytes()), nil - case protoreflect.MessageKind, protoreflect.GroupKind: - return encodeMessage(fieldDescriptor.Message(), value) - default: - return fmt.Sprintf("%v", value.Interface()), nil - } -} - -// encodeMessage marshals the fields in the given protoreflect.Message. -// If the typeURL is non-empty, then a synthetic "@type" field is injected -// containing the URL as the value. -func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) { - switch msgDescriptor.FullName() { - case "google.protobuf.Timestamp": - t, ok := value.Interface().(*timestamppb.Timestamp) - if !ok { - return "", nil - } - return t.AsTime().Format(time.RFC3339Nano), nil - case "google.protobuf.Duration": - d, ok := value.Interface().(*durationpb.Duration) - if !ok { - return "", nil - } - return d.AsDuration().String(), nil - case "google.protobuf.BytesValue": - b, ok := value.Interface().(*wrapperspb.BytesValue) - if !ok { - return "", nil - } - return base64.StdEncoding.EncodeToString(b.Value), nil - case "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value", - "google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue": - fd := msgDescriptor.Fields() - v := value.Message().Get(fd.ByName(protoreflect.Name("value"))).Message() - return fmt.Sprintf("%v", v.Interface()), nil - case "google.protobuf.FieldMask": - m, ok := value.Interface().(*field_mask.FieldMask) - if !ok { - return "", nil - } - return strings.Join(m.Paths, ","), nil - default: - return "", fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName())) - } + return form.EncodeField(fd, v.Get(fd)) }