From 67c4fdf0dec848d8c8201971c03b624791a07c5e Mon Sep 17 00:00:00 2001 From: Long Bui Date: Wed, 29 Jul 2020 23:07:14 +0700 Subject: [PATCH] Make unique tag work with pointer fields. --- baked_in.go | 20 ++++++++++---- validator_test.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/baked_in.go b/baked_in.go index 36e8057..902eb4e 100644 --- a/baked_in.go +++ b/baked_in.go @@ -241,23 +241,33 @@ func isUnique(fl FieldLevel) bool { switch field.Kind() { case reflect.Slice, reflect.Array: + elem := field.Type().Elem() + if elem.Kind() == reflect.Ptr { + elem = elem.Elem() + } + if param == "" { - m := reflect.MakeMap(reflect.MapOf(field.Type().Elem(), v.Type())) + m := reflect.MakeMap(reflect.MapOf(elem, v.Type())) for i := 0; i < field.Len(); i++ { - m.SetMapIndex(field.Index(i), v) + m.SetMapIndex(reflect.Indirect(field.Index(i)), v) } return field.Len() == m.Len() } - sf, ok := field.Type().Elem().FieldByName(param) + sf, ok := elem.FieldByName(param) if !ok { panic(fmt.Sprintf("Bad field name %s", param)) } - m := reflect.MakeMap(reflect.MapOf(sf.Type, v.Type())) + sfTyp := sf.Type + if sfTyp.Kind() == reflect.Ptr { + sfTyp = sfTyp.Elem() + } + + m := reflect.MakeMap(reflect.MapOf(sfTyp, v.Type())) for i := 0; i < field.Len(); i++ { - m.SetMapIndex(field.Index(i).FieldByName(param), v) + m.SetMapIndex(reflect.Indirect(reflect.Indirect(field.Index(i)).FieldByName(param)), v) } return field.Len() == m.Len() case reflect.Map: diff --git a/validator_test.go b/validator_test.go index e76a3cd..00bd11a 100644 --- a/validator_test.go +++ b/validator_test.go @@ -346,6 +346,18 @@ func StructLevelInvalidError(sl StructLevel) { } } +func stringPtr(v string) *string { + return &v +} + +func intPtr(v int) *int { + return &v +} + +func float64Ptr(v float64) *float64 { + return &v +} + func TestStructLevelInvalidError(t *testing.T) { validate := New() @@ -8144,6 +8156,12 @@ func TestUniqueValidation(t *testing.T) { {[2]string{"a", "a"}, false}, {[2]interface{}{"a", "a"}, false}, {[4]interface{}{"a", 1, "b", 1}, false}, + {[2]*string{stringPtr("a"), stringPtr("b")}, true}, + {[2]*int{intPtr(1), intPtr(2)}, true}, + {[2]*float64{float64Ptr(1), float64Ptr(2)}, true}, + {[2]*string{stringPtr("a"), stringPtr("a")}, false}, + {[2]*float64{float64Ptr(1), float64Ptr(1)}, false}, + {[2]*int{intPtr(1), intPtr(1)}, false}, // Slices {[]string{"a", "b"}, true}, {[]int{1, 2}, true}, @@ -8155,6 +8173,12 @@ func TestUniqueValidation(t *testing.T) { {[]string{"a", "a"}, false}, {[]interface{}{"a", "a"}, false}, {[]interface{}{"a", 1, "b", 1}, false}, + {[]*string{stringPtr("a"), stringPtr("b")}, true}, + {[]*int{intPtr(1), intPtr(2)}, true}, + {[]*float64{float64Ptr(1), float64Ptr(2)}, true}, + {[]*string{stringPtr("a"), stringPtr("a")}, false}, + {[]*float64{float64Ptr(1), float64Ptr(1)}, false}, + {[]*int{intPtr(1), intPtr(1)}, false}, // Maps {map[string]string{"one": "a", "two": "b"}, true}, {map[string]int{"one": 1, "two": 2}, true}, @@ -8235,6 +8259,49 @@ func TestUniqueValidationStructSlice(t *testing.T) { PanicMatches(t, func() { validate.Var(testStructs, "unique=C") }, "Bad field name C") } +func TestUniqueValidationStructPtrSlice(t *testing.T) { + testStructs := []*struct { + A *string + B *string + }{ + {A: stringPtr("one"), B: stringPtr("two")}, + {A: stringPtr("one"), B: stringPtr("three")}, + } + + tests := []struct { + target interface{} + param string + expected bool + }{ + {testStructs, "unique", true}, + {testStructs, "unique=A", false}, + {testStructs, "unique=B", true}, + } + + validate := New() + + for i, test := range tests { + + errs := validate.Var(test.target, test.param) + + if test.expected { + if !IsEqual(errs, nil) { + t.Fatalf("Index: %d unique failed Error: %v", i, errs) + } + } else { + if IsEqual(errs, nil) { + t.Fatalf("Index: %d unique failed Error: %v", i, errs) + } else { + val := getError(errs, "", "") + if val.Tag() != "unique" { + t.Fatalf("Index: %d unique failed Error: %v", i, errs) + } + } + } + } + PanicMatches(t, func() { validate.Var(testStructs, "unique=C") }, "Bad field name C") +} + func TestHTMLValidation(t *testing.T) { tests := []struct { param string