diff --git a/encoding/json/json_test.go b/encoding/json/json_test.go index 5eab1ba37..b43e4ef70 100644 --- a/encoding/json/json_test.go +++ b/encoding/json/json_test.go @@ -1,7 +1,8 @@ package json import ( - "bytes" + "encoding/json" + "reflect" "strings" "testing" @@ -21,6 +22,47 @@ type testMessage struct { Embed *testEmbed `json:"embed,omitempty"` } +type mock struct { + value int +} + +const ( + Unknown = iota + Gopher + Zebra +) + +func (a *mock) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + switch strings.ToLower(s) { + default: + a.value = Unknown + case "gopher": + a.value = Gopher + case "zebra": + a.value = Zebra + } + + return nil +} + +func (a *mock) MarshalJSON() ([]byte, error) { + var s string + switch a.value { + default: + s = "unknown" + case Gopher: + s = "gopher" + case Zebra: + s = "zebra" + } + + return json.Marshal(s) +} + func TestJSON_Marshal(t *testing.T) { tests := []struct { input interface{} @@ -38,6 +80,10 @@ func TestJSON_Marshal(t *testing.T) { input: &testData.TestModel{Id: 1, Name: "go-kratos", Hobby: []string{"1", "2"}}, expect: `{"id":"1","name":"go-kratos","hobby":["1","2"],"attrs":{}}`, }, + { + input: &mock{value: Gopher}, + expect: `"gopher"`, + }, } for _, v := range tests { data, err := (codec{}).Marshal(v.input) @@ -55,8 +101,10 @@ func TestJSON_Marshal(t *testing.T) { } func TestJSON_Unmarshal(t *testing.T) { - p := &testMessage{} - p2 := &testData.TestModel{} + p := testMessage{} + p2 := testData.TestModel{} + p3 := &testData.TestModel{} + p4 := &mock{} tests := []struct { input string expect interface{} @@ -70,13 +118,21 @@ func TestJSON_Unmarshal(t *testing.T) { expect: &p, }, { - input: `{"id":1,"name":"kratos"}`, + input: `{"id":"1","name":"go-kratos","hobby":["1","2"],"attrs":{}}`, expect: &p2, }, + { + input: `{"id":1,"name":"go-kratos","hobby":["1","2"]}`, + expect: &p3, + }, + { + input: `"zebra"`, + expect: p4, + }, } for _, v := range tests { want := []byte(v.input) - err := (codec{}).Unmarshal(want, &v.expect) + err := (codec{}).Unmarshal(want, v.expect) if err != nil { t.Errorf("marshal(%#v): %s", v.input, err) } @@ -84,7 +140,7 @@ func TestJSON_Unmarshal(t *testing.T) { if err != nil { t.Errorf("marshal(%#v): %s", v.input, err) } - if !bytes.Equal(got, want) { + if !reflect.DeepEqual(strings.ReplaceAll(string(got), " ", ""), strings.ReplaceAll(string(want), " ", "")) { t.Errorf("marshal(%#v):\nhave %#q\nwant %#q", v.input, got, want) } }