From fcd3b18e8344f17c189e9bae13d79ee420387c8f Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Fri, 16 Jun 2023 19:54:56 +0800 Subject: [PATCH] fix(encoding/form): failed to decode the map type (#2468) --- encoding/form/form_test.go | 15 ++++++-- encoding/form/proto_decode.go | 51 +++++++++++++++++++++----- encoding/form/proto_decode_test.go | 59 +++++++++++++++++++++++++++++- 3 files changed, 111 insertions(+), 14 deletions(-) diff --git a/encoding/form/form_test.go b/encoding/form/form_test.go index 013edbd1d..27a52c4b1 100644 --- a/encoding/form/form_test.go +++ b/encoding/form/form_test.go @@ -99,7 +99,7 @@ func TestProtoEncodeDecode(t *testing.T) { Price: 11.23, D: 22.22, Byte: []byte("123"), - Map: map[string]string{"kratos": "https://go-kratos.dev/"}, + Map: map[string]string{"kratos": "https://go-kratos.dev/", "kratos_start": "https://go-kratos.dev/en/docs/getting-started/start/"}, Timestamp: ×tamppb.Timestamp{Seconds: 20, Nanos: 2}, Duration: &durationpb.Duration{Seconds: 120, Nanos: 22}, @@ -119,7 +119,8 @@ func TestProtoEncodeDecode(t *testing.T) { t.Fatal(err) } if "a=19&age=18&b=true&bool=false&byte=MTIz&bytes=MTIz&count=3&d=22.22&double=12.33&duration="+ - "2m0.000000022s&field=1%2C2&float=12.34&id=2233&int32=32&int64=64&map%5Bkratos%5D=https%3A%2F%2Fgo-kratos.dev%2F&"+ + "2m0.000000022s&field=1%2C2&float=12.34&id=2233&int32=32&int64=64&"+ + "map%5Bkratos%5D=https%3A%2F%2Fgo-kratos.dev%2F&map%5Bkratos_start%5D=https%3A%2F%2Fgo-kratos.dev%2Fen%2Fdocs%2Fgetting-started%2Fstart%2F&"+ "numberOne=2233&price=11.23&sex=woman&simples=3344&simples=5566&string=go-kratos"+ "×tamp=1970-01-01T00%3A00%3A20.000000002Z&uint32=32&uint64=64&very_simple.component=5566" != string(content) { t.Errorf("rawpath is not equal to %s", content) @@ -153,6 +154,14 @@ func TestProtoEncodeDecode(t *testing.T) { if "5566" != in2.Simples[1] { t.Errorf("expect %v, got %v", "5566", in2.Simples[1]) } + if l := len(in2.GetMap()); l != 2 { + t.Fatalf("in2.Map length want: %d, got: %d", 2, l) + } + for key, val := range in.GetMap() { + if in2Val := in2.GetMap()[key]; in2Val != val { + t.Errorf("%s want: %q, got: %q", "map["+key+"]", val, in2Val) + } + } } func TestDecodeStructPb(t *testing.T) { @@ -181,7 +190,7 @@ func TestDecodeBytesValuePb(t *testing.T) { content := "bytes=" + val in2 := &complex.Complex{} if err := encoding.GetCodec(Name).Unmarshal([]byte(content), in2); err != nil { - t.Error(err) + t.Fatal(err) } if url != string(in2.Bytes.Value) { t.Errorf("except %s, got %s", val, in2.Bytes.Value) diff --git a/encoding/form/proto_decode.go b/encoding/form/proto_decode.go index 1620162b4..28421fefb 100644 --- a/encoding/form/proto_decode.go +++ b/encoding/form/proto_decode.go @@ -20,6 +20,8 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" ) +var ErrInvalidFormatMapKey = errors.New("invalid formatting for map key") + // DecodeValues decode url value into proto message. func DecodeValues(msg proto.Message, values url.Values) error { for key, values := range values { @@ -77,13 +79,23 @@ func populateFieldValues(v protoreflect.Message, fieldPath []string, values []st } func getFieldDescriptor(v protoreflect.Message, fieldName string) protoreflect.FieldDescriptor { - fields := v.Descriptor().Fields() - var fd protoreflect.FieldDescriptor - if fd = getDescriptorByFieldAndName(fields, fieldName); fd == nil { - if v.Descriptor().FullName() == structMessageFullname { + var ( + fields = v.Descriptor().Fields() + fd = getDescriptorByFieldAndName(fields, fieldName) + ) + if fd == nil { + switch { + case v.Descriptor().FullName() == structMessageFullname: fd = fields.ByNumber(structFieldsFieldNumber) - } else if len(fieldName) > 2 && strings.HasSuffix(fieldName, "[]") { + case len(fieldName) > 2 && strings.HasSuffix(fieldName, "[]"): fd = getDescriptorByFieldAndName(fields, strings.TrimSuffix(fieldName, "[]")) + default: + // If the type is map, you get the string "map[kratos]", where "map" is a field of proto and "kratos" is a key of map + field, _, err := parseURLQueryMapKey(fieldName) + if err != nil { + break + } + fd = getDescriptorByFieldAndName(fields, field) } } return fd @@ -121,14 +133,20 @@ func populateRepeatedField(fd protoreflect.FieldDescriptor, list protoreflect.Li } func populateMapField(fd protoreflect.FieldDescriptor, mp protoreflect.Map, fieldPath []string, values []string) error { - // post sub key. - nkey := len(fieldPath) - 1 - key, err := parseField(fd.MapKey(), fieldPath[nkey]) + var ( + nKey = len(fieldPath) - 1 // post sub key + vKey = len(values) - 1 + fieldName = fieldPath[nKey] + ) + _, keyName, err := parseURLQueryMapKey(fieldName) + if err != nil { + return err + } + key, err := parseField(fd.MapKey(), keyName) if err != nil { return fmt.Errorf("parsing map key %q: %w", fd.FullName().Name(), err) } - vkey := len(values) - 1 - value, err := parseField(fd.MapValue(), values[vkey]) + value, err := parseField(fd.MapValue(), values[vKey]) if err != nil { return fmt.Errorf("parsing map value %q: %w", fd.FullName().Name(), err) } @@ -331,3 +349,16 @@ func jsonSnakeCase(s string) string { func isASCIIUpper(c byte) bool { return 'A' <= c && c <= 'Z' } + +// parseURLQueryMapKey parse the url.Values the field name and key name of the value map type key +// for example: convert "map[key]" to "map" and "key" +func parseURLQueryMapKey(key string) (string, string, error) { + var ( + startIndex = strings.IndexByte(key, '[') + endIndex = strings.IndexByte(key, ']') + ) + if startIndex <= 0 || startIndex >= endIndex || len(key) != endIndex+1 { + return "", "", ErrInvalidFormatMapKey + } + return key[:startIndex], key[startIndex+1 : endIndex], nil +} diff --git a/encoding/form/proto_decode_test.go b/encoding/form/proto_decode_test.go index fbe234334..986d7c333 100644 --- a/encoding/form/proto_decode_test.go +++ b/encoding/form/proto_decode_test.go @@ -87,7 +87,7 @@ func TestPopulateMapField(t *testing.T) { comp := &complex.Complex{} field := getFieldDescriptor(comp.ProtoReflect(), "map") // Fill the comp map field with the url query values - err = populateMapField(field, comp.ProtoReflect().Mutable(field).Map(), []string{"kratos"}, query["map[kratos]"]) + err = populateMapField(field, comp.ProtoReflect().Mutable(field).Map(), []string{"map[kratos]"}, query["map[kratos]"]) if err != nil { t.Fatal(err) } @@ -215,3 +215,60 @@ func TestIsASCIIUpper(t *testing.T) { }) } } + +func TestParseURLQueryMapKey(t *testing.T) { + tests := []struct { + fieldName string + field string + fieldKey string + err error + }{ + { + fieldName: "map[kratos]", field: "map", fieldKey: "kratos", err: nil, + }, + { + fieldName: "map[]", field: "map", fieldKey: "", err: nil, + }, + { + fieldName: "", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "[[]", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "map[kratos]=", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "[kratos]", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "map", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "map[", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "]kratos[", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "[kratos", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "kratos]", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + } + for _, test := range tests { + t.Run(test.fieldName, func(t *testing.T) { + fieldName, fieldKey, err := parseURLQueryMapKey(test.fieldName) + if test.err != err { + t.Fatalf("want: %s, got: %s", test.err, err) + } + if test.field != fieldName { + t.Errorf("want: %s, got: %s", test.field, fieldName) + } + if test.fieldKey != fieldKey { + t.Errorf("want: %s, got: %s", test.fieldKey, fieldKey) + } + }) + } +}