diff --git a/encoding/form/form.go b/encoding/form/form.go index 823aeb8f1..ba0447e47 100644 --- a/encoding/form/form.go +++ b/encoding/form/form.go @@ -4,10 +4,10 @@ import ( "net/url" "reflect" - "github.com/go-kratos/kratos/v2/encoding" - "github.com/go-playground/form/v4" "google.golang.org/protobuf/proto" + + "github.com/go-kratos/kratos/v2/encoding" ) const ( @@ -70,7 +70,8 @@ func (c codec) Unmarshal(data []byte, v interface{}) error { } if m, ok := v.(proto.Message); ok { return DecodeValues(m, vs) - } else if m, ok := reflect.Indirect(reflect.ValueOf(v)).Interface().(proto.Message); ok { + } + if m, ok := rv.Interface().(proto.Message); ok { return DecodeValues(m, vs) } diff --git a/encoding/form/form_test.go b/encoding/form/form_test.go index ddd053cda..647d1edc8 100644 --- a/encoding/form/form_test.go +++ b/encoding/form/form_test.go @@ -26,14 +26,12 @@ type TestModel struct { Name string `json:"name"` } -const contentType = "x-www-form-urlencoded" - func TestFormCodecMarshal(t *testing.T) { req := &LoginRequest{ Username: "kratos", Password: "kratos_pwd", } - content, err := encoding.GetCodec(contentType).Marshal(req) + content, err := encoding.GetCodec(Name).Marshal(req) if err != nil { t.Errorf("marshal error: %v", err) } @@ -45,7 +43,7 @@ func TestFormCodecMarshal(t *testing.T) { Username: "kratos", Password: "", } - content, err = encoding.GetCodec(contentType).Marshal(req) + content, err = encoding.GetCodec(Name).Marshal(req) if err != nil { t.Errorf("expect %v, got %v", nil, err) } @@ -57,7 +55,7 @@ func TestFormCodecMarshal(t *testing.T) { ID: 1, Name: "kratos", } - content, err = encoding.GetCodec(contentType).Marshal(m) + content, err = encoding.GetCodec(Name).Marshal(m) t.Log(string(content)) if err != nil { t.Errorf("expect %v, got %v", nil, err) @@ -72,13 +70,13 @@ func TestFormCodecUnmarshal(t *testing.T) { Username: "kratos", Password: "kratos_pwd", } - content, err := encoding.GetCodec(contentType).Marshal(req) + content, err := encoding.GetCodec(Name).Marshal(req) if err != nil { t.Errorf("expect %v, got %v", nil, err) } bindReq := new(LoginRequest) - err = encoding.GetCodec(contentType).Unmarshal(content, bindReq) + err = encoding.GetCodec(Name).Unmarshal(content, bindReq) if err != nil { t.Errorf("expect %v, got %v", nil, err) } @@ -119,7 +117,7 @@ func TestProtoEncodeDecode(t *testing.T) { String_: &wrapperspb.StringValue{Value: "go-kratos"}, Bytes: &wrapperspb.BytesValue{Value: []byte("123")}, } - content, err := encoding.GetCodec(contentType).Marshal(in) + content, err := encoding.GetCodec(Name).Marshal(in) if err != nil { t.Errorf("expect %v, got %v", nil, err) } @@ -130,7 +128,7 @@ func TestProtoEncodeDecode(t *testing.T) { t.Errorf("rawpath is not equal to %v", string(content)) } in2 := &complex.Complex{} - err = encoding.GetCodec(contentType).Unmarshal(content, in2) + err = encoding.GetCodec(Name).Unmarshal(content, in2) if err != nil { t.Errorf("expect %v, got %v", nil, err) } @@ -163,7 +161,7 @@ func TestProtoEncodeDecode(t *testing.T) { func TestDecodeStructPb(t *testing.T) { req := new(ectest.StructPb) query := `data={"name":"kratos"}&data_list={"name1": "kratos"}&data_list={"name2": "go-kratos"}` - if err := encoding.GetCodec(contentType).Unmarshal([]byte(query), req); err != nil { + if err := encoding.GetCodec(Name).Unmarshal([]byte(query), req); err != nil { t.Errorf("expect %v, got %v", nil, err) } if !reflect.DeepEqual("kratos", req.Data.GetFields()["name"].GetStringValue()) { @@ -186,7 +184,7 @@ func TestDecodeBytesValuePb(t *testing.T) { val := base64.URLEncoding.EncodeToString([]byte(url)) content := "bytes=" + val in2 := &complex.Complex{} - if err := encoding.GetCodec(contentType).Unmarshal([]byte(content), in2); err != nil { + if err := encoding.GetCodec(Name).Unmarshal([]byte(content), in2); err != nil { t.Errorf("expect %v, got %v", nil, err) } if !reflect.DeepEqual(url, string(in2.Bytes.Value)) { diff --git a/encoding/form/proto_decode.go b/encoding/form/proto_decode.go index fcf46d469..1620162b4 100644 --- a/encoding/form/proto_decode.go +++ b/encoding/form/proto_decode.go @@ -10,13 +10,12 @@ import ( "time" "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/types/known/structpb" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/fieldmaskpb" + "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -122,15 +121,13 @@ func populateRepeatedField(fd protoreflect.FieldDescriptor, list protoreflect.Li } func populateMapField(fd protoreflect.FieldDescriptor, mp protoreflect.Map, fieldPath []string, values []string) error { - flen := len(fieldPath) - vlen := len(values) // post sub key. - nkey := flen - 1 + nkey := len(fieldPath) - 1 key, err := parseField(fd.MapKey(), fieldPath[nkey]) if err != nil { return fmt.Errorf("parsing map key %q: %w", fd.FullName().Name(), err) } - vkey := vlen - 1 + vkey := len(values) - 1 value, err := parseField(fd.MapValue(), values[vkey]) if err != nil { return fmt.Errorf("parsing map value %q: %w", fd.FullName().Name(), err) diff --git a/encoding/form/proto_decode_test.go b/encoding/form/proto_decode_test.go new file mode 100644 index 000000000..fbe234334 --- /dev/null +++ b/encoding/form/proto_decode_test.go @@ -0,0 +1,217 @@ +package form + +import ( + "fmt" + "net/url" + "reflect" + "strconv" + "testing" + + "google.golang.org/protobuf/reflect/protoreflect" + + "github.com/go-kratos/kratos/v2/internal/testdata/complex" +) + +func TestDecodeValues(t *testing.T) { + form, err := url.ParseQuery("a=19&age=18&b=true&bool=false&byte=MTIz&bytes=MTIz&count=3&d=22.22&double=12.33&duration=" + + "2m0.000000022s&field=1%2C2&float=12.34&id=2233&int32=32&int64=64&numberOne=2233&price=11.23&sex=woman&simples=3344&" + + "simples=5566&string=go-kratos×tamp=1970-01-01T00%3A00%3A20.000000002Z&uint32=32&uint64=64&very_simple.component=5566") + if err != nil { + t.Fatal(err) + } + + comp := &complex.Complex{} + err = DecodeValues(comp, form) + if err != nil { + t.Fatal(err) + } + if comp.Id != int64(2233) { + t.Errorf("want %v, got %v", int64(2233), comp.Id) + } + if comp.NoOne != "2233" { + t.Errorf("want %v, got %v", "2233", comp.NoOne) + } + if comp.Simple == nil { + t.Fatalf("want %v, got %v", nil, comp.Simple) + } + if comp.Simple.Component != "5566" { + t.Errorf("want %v, got %v", "5566", comp.Simple.Component) + } + if len(comp.Simples) != 2 { + t.Fatalf("want %v, got %v", 2, len(comp.Simples)) + } + if comp.Simples[0] != "3344" { + t.Errorf("want %v, got %v", "3344", comp.Simples[0]) + } + if comp.Simples[1] != "5566" { + t.Errorf("want %v, got %v", "5566", comp.Simples[1]) + } +} + +func TestGetFieldDescriptor(t *testing.T) { + comp := &complex.Complex{} + + field := getFieldDescriptor(comp.ProtoReflect(), "id") + if field.Kind() != protoreflect.Int64Kind { + t.Errorf("want: %d, got: %d", protoreflect.Int64Kind, field.Kind()) + } + + field = getFieldDescriptor(comp.ProtoReflect(), "simples") + if field.Kind() != protoreflect.StringKind { + t.Errorf("want: %d, got: %d", protoreflect.StringKind, field.Kind()) + } +} + +func TestPopulateRepeatedField(t *testing.T) { + query, err := url.ParseQuery("simples=3344&simples=5566") + if err != nil { + t.Fatal(err) + } + comp := &complex.Complex{} + field := getFieldDescriptor(comp.ProtoReflect(), "simples") + + err = populateRepeatedField(field, comp.ProtoReflect().Mutable(field).List(), query["simples"]) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual([]string{"3344", "5566"}, comp.GetSimples()) { + t.Errorf("want: %v, got: %v", []string{"3344", "5566"}, comp.GetSimples()) + } +} + +func TestPopulateMapField(t *testing.T) { + query, err := url.ParseQuery("map%5Bkratos%5D=https://go-kratos.dev/") + if err != nil { + t.Fatal(err) + } + comp := &complex.Complex{} + field := getFieldDescriptor(comp.ProtoReflect(), "map") + // Fill the comp map field with the url query values + err = populateMapField(field, comp.ProtoReflect().Mutable(field).Map(), []string{"kratos"}, query["map[kratos]"]) + if err != nil { + t.Fatal(err) + } + // Get the comp map field value + if query["map[kratos]"][0] != comp.Map["kratos"] { + t.Errorf("want: %s, got: %s", query["map[kratos]"], comp.Map["kratos"]) + } +} + +func TestParseField(t *testing.T) { + tests := []struct { + name string + fieldName string + protoReflectKind protoreflect.Kind + value string + targetProtoReflectValue protoreflect.Value + targetErr error + }{ + { + name: "BoolKind", + fieldName: "b", + protoReflectKind: protoreflect.BoolKind, + value: "true", + targetProtoReflectValue: protoreflect.ValueOfBool(true), + targetErr: nil, + }, + { + name: "BoolKind", + fieldName: "b", + protoReflectKind: protoreflect.BoolKind, + value: "a", + targetProtoReflectValue: protoreflect.Value{}, + targetErr: &strconv.NumError{Func: "ParseBool", Num: "a", Err: strconv.ErrSyntax}, + }, + { + name: "EnumKind", + fieldName: "sex", + protoReflectKind: protoreflect.EnumKind, + value: "1", + targetProtoReflectValue: protoreflect.ValueOfEnum(1), + targetErr: nil, + }, + { + name: "EnumKind", + fieldName: "sex", + protoReflectKind: protoreflect.EnumKind, + value: "2", + targetProtoReflectValue: protoreflect.Value{}, + targetErr: fmt.Errorf("%q is not a valid value", "2"), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + comp := &complex.Complex{} + field := getFieldDescriptor(comp.ProtoReflect(), test.fieldName) + if test.protoReflectKind != field.Kind() { + t.Fatalf("want: %d, got: %d", test.protoReflectKind, field.Kind()) + } + val, err := parseField(field, test.value) + if !reflect.DeepEqual(test.targetErr, err) { + t.Fatalf("want: %s, got: %s", test.targetErr, err) + } + if !reflect.DeepEqual(test.targetProtoReflectValue, val) { + t.Errorf("want: %s, got: %s", test.targetProtoReflectValue, val) + } + }) + } +} + +func TestJsonSnakeCase(t *testing.T) { + tests := []struct { + camelCase string + snakeCase string + }{ + { + "userId", "user_id", + }, + { + "user", "user", + }, + { + "userIdAndUsername", "user_id_and_username", + }, + { + "", "", + }, + } + for _, test := range tests { + t.Run(test.camelCase, func(t *testing.T) { + snake := jsonSnakeCase(test.camelCase) + if snake != test.snakeCase { + t.Errorf("want: %s, got: %s", test.snakeCase, snake) + } + }) + } +} + +func TestIsASCIIUpper(t *testing.T) { + tests := []struct { + b byte + upper bool + }{ + { + 'A', true, + }, + { + 'a', false, + }, + { + ',', false, + }, + { + '1', false, + }, + { + ' ', false, + }, + } + for _, test := range tests { + t.Run(string(test.b), func(t *testing.T) { + upper := isASCIIUpper(test.b) + if test.upper != upper { + t.Errorf("'%s' is not ascii upper", string(test.b)) + } + }) + } +}