package binding import ( "encoding/base64" "fmt" "reflect" "regexp" "strconv" "strings" "time" "github.com/go-kratos/kratos/v2/encoding/form" "google.golang.org/genproto/protobuf/field_mask" "google.golang.org/protobuf/proto" protoreflect "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" ) // EncodeURL encode proto message to url path. func EncodeURL(pathTemplate string, msg proto.Message, needQuery bool) string { if msg == nil || (reflect.ValueOf(msg).Kind() == reflect.Ptr && reflect.ValueOf(msg).IsNil()) { return pathTemplate } reg := regexp.MustCompile(`/{[.\w]+}`) if reg == nil { return pathTemplate } pathParams := make(map[string]struct{}) path := reg.ReplaceAllStringFunc(pathTemplate, func(in string) string { if len(in) < 4 { return in } key := in[2 : len(in)-1] vars := strings.Split(key, ".") if value, err := getValueByField(msg.ProtoReflect(), vars); err == nil { pathParams[key] = struct{}{} return "/" + value } return in }) if needQuery { u, err := form.EncodeMap(msg) if err == nil && len(u) > 0 { for key := range pathParams { delete(u, key) } query := u.Encode() if query != "" { path += "?" + query } } } return path } func getValueByField(v protoreflect.Message, fieldPath []string) (string, error) { var fd protoreflect.FieldDescriptor for i, fieldName := range fieldPath { fields := v.Descriptor().Fields() if fd = fields.ByJSONName(fieldName); fd == nil { fd = fields.ByName(protoreflect.Name(fieldName)) if fd == nil { return "", fmt.Errorf("field path not found: %q", fieldName) } } if i == len(fieldPath)-1 { break } if fd.Message() == nil || fd.Cardinality() == protoreflect.Repeated { return "", fmt.Errorf("invalid path: %q is not a message", fieldName) } v = v.Get(fd).Message() } return encodeField(fd, v.Get(fd)) } func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) { switch fieldDescriptor.Kind() { case protoreflect.BoolKind: return strconv.FormatBool(value.Bool()), nil case protoreflect.EnumKind: if fieldDescriptor.Enum().FullName() == "google.protobuf.NullValue" { return "null", nil } desc := fieldDescriptor.Enum().Values().ByNumber(value.Enum()) return string(desc.Name()), nil case protoreflect.StringKind: return value.String(), nil case protoreflect.BytesKind: return base64.URLEncoding.EncodeToString(value.Bytes()), nil case protoreflect.MessageKind, protoreflect.GroupKind: return encodeMessage(fieldDescriptor.Message(), value) default: return fmt.Sprintf("%v", value.Interface()), nil } } // marshalMessage marshals the fields in the given protoreflect.Message. // If the typeURL is non-empty, then a synthetic "@type" field is injected // containing the URL as the value. func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) { switch msgDescriptor.FullName() { case "google.protobuf.Timestamp": t, ok := value.Interface().(*timestamppb.Timestamp) if !ok { return "", nil } return t.AsTime().Format(time.RFC3339Nano), nil case "google.protobuf.Duration": d, ok := value.Interface().(*durationpb.Duration) if !ok { return "", nil } return d.AsDuration().String(), nil case "google.protobuf.BytesValue": b, ok := value.Interface().(*wrapperspb.BytesValue) if !ok { return "", nil } return base64.StdEncoding.EncodeToString(b.Value), nil case "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value", "google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue": fd := msgDescriptor.Fields() v := value.Message().Get(fd.ByName(protoreflect.Name("value"))).Message() return fmt.Sprintf("%v", v.Interface()), nil case "google.protobuf.FieldMask": m, ok := value.Interface().(*field_mask.FieldMask) if !ok { return "", nil } return strings.Join(m.Paths, ","), nil default: return "", fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName())) } }