From 38bfb46b5aca07e8b7bdd45f0be03887e4c78287 Mon Sep 17 00:00:00 2001 From: Dean Karn Date: Sun, 17 Nov 2019 13:02:10 -0800 Subject: [PATCH] fix required_* --- baked_in.go | 26 ++++++------ field_level.go | 37 ++++++++++++++++- util.go | 7 ++-- validator.go | 2 +- validator_test.go | 104 ++++++++++++++++++++++++++++++++++------------ 5 files changed, 129 insertions(+), 47 deletions(-) diff --git a/baked_in.go b/baked_in.go index 95d613c..23f76b3 100644 --- a/baked_in.go +++ b/baked_in.go @@ -1316,11 +1316,11 @@ func hasValue(fl FieldLevel) bool { // requireCheckField is a func for check field kind func requireCheckFieldKind(fl FieldLevel, param string, defaultNotFoundValue bool) bool { field := fl.Field() - var ok bool kind := field.Kind() + var nullable, found bool if len(param) > 0 { - field, kind, ok = fl.GetStructFieldOKAdvanced(fl.Parent(), param) - if !ok { + field, kind, nullable, found = fl.GetStructFieldOKAdvanced2(fl.Parent(), param) + if !found { return defaultNotFoundValue } } @@ -1328,9 +1328,12 @@ func requireCheckFieldKind(fl FieldLevel, param string, defaultNotFoundValue boo case reflect.Invalid: return defaultNotFoundValue case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func: - return !field.IsNil() + return field.IsNil() default: - return field.IsValid() && field.Interface() != reflect.Zero(field.Type()).Interface() + if nullable && field.Interface() != nil { + return false + } + return field.IsValid() && field.Interface() == reflect.Zero(field.Type()).Interface() } } @@ -1339,7 +1342,7 @@ func requireCheckFieldKind(fl FieldLevel, param string, defaultNotFoundValue boo func requiredWith(fl FieldLevel) bool { params := parseOneOfParam2(fl.Param()) for _, param := range params { - if requireCheckFieldKind(fl, param, false) { + if !requireCheckFieldKind(fl, param, true) { return hasValue(fl) } } @@ -1351,7 +1354,7 @@ func requiredWith(fl FieldLevel) bool { func requiredWithAll(fl FieldLevel) bool { params := parseOneOfParam2(fl.Param()) for _, param := range params { - if !requireCheckFieldKind(fl, param, false) { + if requireCheckFieldKind(fl, param, true) { return true } } @@ -1361,11 +1364,8 @@ func requiredWithAll(fl FieldLevel) bool { // RequiredWithout is the validation function // The field under validation must be present and not empty only when any of the other specified fields are not present. func requiredWithout(fl FieldLevel) bool { - params := parseOneOfParam2(fl.Param()) - for _, param := range params { - if !requireCheckFieldKind(fl, param, true) { - return hasValue(fl) - } + if requireCheckFieldKind(fl, strings.TrimSpace(fl.Param()), true) { + return hasValue(fl) } return true } @@ -1375,7 +1375,7 @@ func requiredWithout(fl FieldLevel) bool { func requiredWithoutAll(fl FieldLevel) bool { params := parseOneOfParam2(fl.Param()) for _, param := range params { - if requireCheckFieldKind(fl, param, true) { + if !requireCheckFieldKind(fl, param, true) { return true } } diff --git a/field_level.go b/field_level.go index 24bc134..7a13f33 100644 --- a/field_level.go +++ b/field_level.go @@ -37,11 +37,27 @@ type FieldLevel interface { // // NOTE: when not successful ok will be false, this can happen when a nested struct is nil and so the field // could not be retrieved because it didn't exist. + // + // Deprecated: Use GetStructFieldOK2() instead which also return if the value is nullable. GetStructFieldOK() (reflect.Value, reflect.Kind, bool) // GetStructFieldOKAdvanced is the same as GetStructFieldOK except that it accepts the parent struct to start looking for // the field and namespace allowing more extensibility for validators. + // + // Deprecated: Use GetStructFieldOKAdvanced2() instead which also return if the value is nullable. GetStructFieldOKAdvanced(val reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool) + + // traverses the parent struct to retrieve a specific field denoted by the provided namespace + // in the param and returns the field, field kind, if it's a nullable type and whether is was successful in retrieving + // the field at all. + // + // NOTE: when not successful ok will be false, this can happen when a nested struct is nil and so the field + // could not be retrieved because it didn't exist. + GetStructFieldOK2() (reflect.Value, reflect.Kind, bool, bool) + + // GetStructFieldOKAdvanced is the same as GetStructFieldOK except that it accepts the parent struct to start looking for + // the field and namespace allowing more extensibility for validators. + GetStructFieldOKAdvanced2(val reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool, bool) } var _ FieldLevel = new(validate) @@ -52,7 +68,7 @@ func (v *validate) Field() reflect.Value { } // FieldName returns the field's name with the tag -// name takeing precedence over the fields actual name. +// name taking precedence over the fields actual name. func (v *validate) FieldName() string { return v.cf.altName } @@ -68,12 +84,29 @@ func (v *validate) Param() string { } // GetStructFieldOK returns Param returns param for validation against current field +// +// Deprecated: Use GetStructFieldOK2() instead which also return if the value is nullable. func (v *validate) GetStructFieldOK() (reflect.Value, reflect.Kind, bool) { - return v.getStructFieldOKInternal(v.slflParent, v.ct.param) + current, kind, _, found := v.getStructFieldOKInternal(v.slflParent, v.ct.param) + return current, kind, found } // GetStructFieldOKAdvanced is the same as GetStructFieldOK except that it accepts the parent struct to start looking for // the field and namespace allowing more extensibility for validators. +// +// Deprecated: Use GetStructFieldOKAdvanced2() instead which also return if the value is nullable. func (v *validate) GetStructFieldOKAdvanced(val reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool) { + current, kind, _, found := v.GetStructFieldOKAdvanced2(val, namespace) + return current, kind, found +} + +// GetStructFieldOK returns Param returns param for validation against current field +func (v *validate) GetStructFieldOK2() (reflect.Value, reflect.Kind, bool, bool) { + return v.getStructFieldOKInternal(v.slflParent, v.ct.param) +} + +// GetStructFieldOKAdvanced is the same as GetStructFieldOK except that it accepts the parent struct to start looking for +// the field and namespace allowing more extensibility for validators. +func (v *validate) GetStructFieldOKAdvanced2(val reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool, bool) { return v.getStructFieldOKInternal(val, namespace) } diff --git a/util.go b/util.go index 16a5517..71acbdc 100644 --- a/util.go +++ b/util.go @@ -57,11 +57,10 @@ BEGIN: // // NOTE: when not successful ok will be false, this can happen when a nested struct is nil and so the field // could not be retrieved because it didn't exist. -func (v *validate) getStructFieldOKInternal(val reflect.Value, namespace string) (current reflect.Value, kind reflect.Kind, found bool) { +func (v *validate) getStructFieldOKInternal(val reflect.Value, namespace string) (current reflect.Value, kind reflect.Kind, nullable bool, found bool) { BEGIN: - current, kind, _ = v.ExtractType(val) - + current, kind, nullable = v.ExtractType(val) if kind == reflect.Invalid { return } @@ -112,7 +111,7 @@ BEGIN: arrIdx, _ := strconv.Atoi(namespace[idx+1 : idx2]) if arrIdx >= current.Len() { - return current, kind, false + return } startIdx := idx2 + 1 diff --git a/validator.go b/validator.go index 3abf5d3..342e72e 100644 --- a/validator.go +++ b/validator.go @@ -7,7 +7,7 @@ import ( "strconv" ) -// per validate contruct +// per validate construct type validate struct { v *Validate top reflect.Value diff --git a/validator_test.go b/validator_test.go index 6f42a91..17d6edc 100644 --- a/validator_test.go +++ b/validator_test.go @@ -1640,125 +1640,125 @@ func TestCrossNamespaceFieldValidation(t *testing.T) { v: vd, } - current, kind, ok := v.getStructFieldOKInternal(val, "Inner.CreatedAt") + current, kind, _, ok := v.getStructFieldOKInternal(val, "Inner.CreatedAt") Equal(t, ok, true) Equal(t, kind, reflect.Struct) tm, ok := current.Interface().(time.Time) Equal(t, ok, true) Equal(t, tm, now) - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.Slice[1]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.Slice[1]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, _, ok = v.getStructFieldOKInternal(val, "Inner.CrazyNonExistantField") + current, _, _, ok = v.getStructFieldOKInternal(val, "Inner.CrazyNonExistantField") Equal(t, ok, false) - current, _, ok = v.getStructFieldOKInternal(val, "Inner.Slice[101]") + current, _, _, ok = v.getStructFieldOKInternal(val, "Inner.Slice[101]") Equal(t, ok, false) - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.Map[key3]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.Map[key3]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val3") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapMap[key2][key2-1]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapMap[key2][key2-1]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapStructs[key2].Name") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapStructs[key2].Name") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "name2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapMapStruct[key3][key3-1].Name") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapMapStruct[key3][key3-1].Name") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "name3") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.SliceSlice[2][0]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.SliceSlice[2][0]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "7") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.SliceSliceStruct[2][1].Name") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.SliceSliceStruct[2][1].Name") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "name8") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.SliceMap[1][key5]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.SliceMap[1][key5]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val5") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapSlice[key3][2]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapSlice[key3][2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "9") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapInt[2]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapInt[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapInt8[2]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapInt8[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapInt16[2]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapInt16[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapInt32[2]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapInt32[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapInt64[2]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapInt64[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapUint[2]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapUint[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapUint8[2]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapUint8[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapUint16[2]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapUint16[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapUint32[2]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapUint32[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapUint64[2]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapUint64[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapFloat32[3.03]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapFloat32[3.03]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val3") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapFloat64[2.02]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapFloat64[2.02]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.MapBool[true]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.MapBool[true]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val1") @@ -1784,13 +1784,13 @@ func TestCrossNamespaceFieldValidation(t *testing.T) { val = reflect.ValueOf(test) - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.SliceStructs[2]") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.SliceStructs[2]") Equal(t, ok, true) Equal(t, kind, reflect.Ptr) Equal(t, current.String(), "<*validator.SliceStruct Value>") Equal(t, current.IsNil(), true) - current, kind, ok = v.getStructFieldOKInternal(val, "Inner.SliceStructs[2].Name") + current, kind, _, ok = v.getStructFieldOKInternal(val, "Inner.SliceStructs[2].Name") Equal(t, ok, false) Equal(t, kind, reflect.Ptr) Equal(t, current.String(), "<*validator.SliceStruct Value>") @@ -8864,3 +8864,53 @@ func TestAbilityToValidateNils(t *testing.T) { errs = val.Struct(ts) NotEqual(t, errs, nil) } + +func TestRequiredWithoutPointers(t *testing.T) { + type Lookup struct { + FieldA *bool `json:"fieldA,omitempty" validate:"required_without=FieldB"` + FieldB *bool `json:"fieldB,omitempty" validate:"required_without=FieldA"` + } + + b := true + lookup := Lookup{ + FieldA: &b, + FieldB: nil, + } + + val := New() + errs := val.Struct(lookup) + Equal(t, errs, nil) + + b = false + lookup = Lookup{ + FieldA: &b, + FieldB: nil, + } + errs = val.Struct(lookup) + Equal(t, errs, nil) +} + +func TestRequiredWithoutAllPointers(t *testing.T) { + type Lookup struct { + FieldA *bool `json:"fieldA,omitempty" validate:"required_without_all=FieldB"` + FieldB *bool `json:"fieldB,omitempty" validate:"required_without_all=FieldA"` + } + + b := true + lookup := Lookup{ + FieldA: &b, + FieldB: nil, + } + + val := New() + errs := val.Struct(lookup) + Equal(t, errs, nil) + + b = false + lookup = Lookup{ + FieldA: &b, + FieldB: nil, + } + errs = val.Struct(lookup) + Equal(t, errs, nil) +}