diff --git a/baked_in.go b/baked_in.go index 338cddd..a2660ef 100644 --- a/baked_in.go +++ b/baked_in.go @@ -1301,26 +1301,35 @@ func isDefault(fl FieldLevel) bool { // HasValue is the validation function for validating if the current field's value is not the default static value. func hasValue(fl FieldLevel) bool { - return requireCheckFieldKind(fl, "") + field := fl.Field() + switch field.Kind() { + case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func: + return !field.IsNil() + default: + if fl.(*validate).fldIsPointer && field.Interface() != nil { + return true + } + return field.IsValid() && field.Interface() != reflect.Zero(field.Type()).Interface() + } } // requireCheckField is a func for check field kind func requireCheckFieldKind(fl FieldLevel, param string) bool { field := fl.Field() + var ok bool + kind := field.Kind() if len(param) > 0 { - if fl.Parent().Kind() == reflect.Ptr { - field = fl.Parent().Elem().FieldByName(param) - } else { - field = fl.Parent().FieldByName(param) + field, kind, ok = fl.GetStructFieldOKAdvanced(fl.Parent(), param) + if !ok { + return true } } - switch field.Kind() { + switch kind { + case reflect.Invalid: + return true case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func: return !field.IsNil() default: - if fl.(*validate).fldIsPointer && field.Interface() != nil { - return true - } return field.IsValid() && field.Interface() != reflect.Zero(field.Type()).Interface() } } @@ -1328,76 +1337,55 @@ func requireCheckFieldKind(fl FieldLevel, param string) bool { // RequiredWith is the validation function // The field under validation must be present and not empty only if any of the other specified fields are present. func requiredWith(fl FieldLevel) bool { - params := parseOneOfParam2(fl.Param()) for _, param := range params { - if requireCheckFieldKind(fl, param) { return requireCheckFieldKind(fl, "") } } - return true } // RequiredWithAll is the validation function // The field under validation must be present and not empty only if all of the other specified fields are present. func requiredWithAll(fl FieldLevel) bool { - isValidateCurrentField := true params := parseOneOfParam2(fl.Param()) for _, param := range params { if !requireCheckFieldKind(fl, param) { isValidateCurrentField = false + break } } - if isValidateCurrentField { return requireCheckFieldKind(fl, "") } - return true } // RequiredWithout is the validation function // The field under validation must be present and not empty only when any of the other specified fields are not present. func requiredWithout(fl FieldLevel) bool { - - isValidateCurrentField := false params := parseOneOfParam2(fl.Param()) for _, param := range params { - - if requireCheckFieldKind(fl, param) { - isValidateCurrentField = true + if !requireCheckFieldKind(fl, param) { + return hasValue(fl) } } - - if !isValidateCurrentField { - return requireCheckFieldKind(fl, "") - } - return true } // RequiredWithoutAll is the validation function // The field under validation must be present and not empty only when all of the other specified fields are not present. func requiredWithoutAll(fl FieldLevel) bool { - - isValidateCurrentField := true params := parseOneOfParam2(fl.Param()) for _, param := range params { - if requireCheckFieldKind(fl, param) { - isValidateCurrentField = false + return true } } - - if isValidateCurrentField { - return requireCheckFieldKind(fl, "") - } - - return true + return hasValue(fl) } // IsGteField is the validation function for validating if the current field's value is greater than or equal to the field specified by the param's value. diff --git a/cache.go b/cache.go index a7a4202..8276504 100644 --- a/cache.go +++ b/cache.go @@ -87,18 +87,19 @@ type cField struct { } type cTag struct { - tag string - aliasTag string - actualAliasTag string - param string - keys *cTag // only populated when using tag's 'keys' and 'endkeys' for map key validation - next *cTag - fn FuncCtx - typeof tagType - hasTag bool - hasAlias bool - hasParam bool // true if parameter used eg. eq= where the equal sign has been set - isBlockEnd bool // indicates the current tag represents the last validation in the block + tag string + aliasTag string + actualAliasTag string + param string + keys *cTag // only populated when using tag's 'keys' and 'endkeys' for map key validation + next *cTag + fn FuncCtx + typeof tagType + hasTag bool + hasAlias bool + hasParam bool // true if parameter used eg. eq= where the equal sign has been set + isBlockEnd bool // indicates the current tag represents the last validation in the block + runValidationWhenNil bool } func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStruct { @@ -141,9 +142,7 @@ func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStr customName = fld.Name if v.hasTagNameFunc { - name := v.tagNameFunc(fld) - if len(name) > 0 { customName = name } @@ -168,16 +167,13 @@ func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStr namesEqual: fld.Name == customName, }) } - v.structCache.Set(typ, cs) - return cs } func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias string, hasAlias bool) (firstCtag *cTag, current *cTag) { var t string - var ok bool noAlias := len(alias) == 0 tags := strings.Split(tag, tagSeparator) @@ -270,11 +266,9 @@ func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias s continue default: - if t == isdefault { current.typeof = typeIsDefault } - // if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C" orVals := strings.Split(t, orSeparator) @@ -300,7 +294,10 @@ func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias s panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName))) } - if current.fn, ok = v.validations[current.tag]; !ok { + if wrapper, ok := v.validations[current.tag]; ok { + current.fn = wrapper.fn + current.runValidationWhenNil = wrapper.runValidatinOnNil + } else { panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName))) } diff --git a/field_level.go b/field_level.go index cbfbc15..24bc134 100644 --- a/field_level.go +++ b/field_level.go @@ -38,6 +38,10 @@ type FieldLevel interface { // NOTE: when not successful ok will be false, this can happen when a nested struct is nil and so the field // could not be retrieved because it didn't exist. GetStructFieldOK() (reflect.Value, reflect.Kind, bool) + + // GetStructFieldOKAdvanced is the same as GetStructFieldOK except that it accepts the parent struct to start looking for + // the field and namespace allowing more extensibility for validators. + GetStructFieldOKAdvanced(val reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool) } var _ FieldLevel = new(validate) @@ -67,3 +71,9 @@ func (v *validate) Param() string { func (v *validate) GetStructFieldOK() (reflect.Value, reflect.Kind, bool) { return v.getStructFieldOKInternal(v.slflParent, v.ct.param) } + +// GetStructFieldOKAdvanced is the same as GetStructFieldOK except that it accepts the parent struct to start looking for +// the field and namespace allowing more extensibility for validators. +func (v *validate) GetStructFieldOKAdvanced(val reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool) { + return v.getStructFieldOKInternal(val, namespace) +} diff --git a/validator.go b/validator.go index 67473f1..3abf5d3 100644 --- a/validator.go +++ b/validator.go @@ -94,7 +94,6 @@ func (v *validate) validateStruct(ctx context.Context, parent reflect.Value, cur // traverseField validates any field, be it a struct or single field, ensures it's validity and passes it along to be validated via it's tag options func (v *validate) traverseField(ctx context.Context, parent reflect.Value, current reflect.Value, ns []byte, structNs []byte, cf *cField, ct *cTag) { - var typ reflect.Type var kind reflect.Kind @@ -112,16 +111,36 @@ func (v *validate) traverseField(ctx context.Context, parent reflect.Value, curr } if ct.hasTag { + if kind == reflect.Invalid { + v.str1 = string(append(ns, cf.altName...)) + if v.v.hasTagNameFunc { + v.str2 = string(append(structNs, cf.name...)) + } else { + v.str2 = v.str1 + } + v.errs = append(v.errs, + &fieldError{ + v: v.v, + tag: ct.aliasTag, + actualTag: ct.tag, + ns: v.str1, + structNs: v.str2, + fieldLen: uint8(len(cf.altName)), + structfieldLen: uint8(len(cf.name)), + param: ct.param, + kind: kind, + }, + ) + return + } v.str1 = string(append(ns, cf.altName...)) - if v.v.hasTagNameFunc { v.str2 = string(append(structNs, cf.name...)) } else { v.str2 = v.str1 } - - if kind == reflect.Invalid { + if !ct.runValidationWhenNil { v.errs = append(v.errs, &fieldError{ v: v.v, @@ -131,31 +150,14 @@ func (v *validate) traverseField(ctx context.Context, parent reflect.Value, curr structNs: v.str2, fieldLen: uint8(len(cf.altName)), structfieldLen: uint8(len(cf.name)), + value: current.Interface(), param: ct.param, kind: kind, + typ: current.Type(), }, ) - return } - - v.errs = append(v.errs, - &fieldError{ - v: v.v, - tag: ct.aliasTag, - actualTag: ct.tag, - ns: v.str1, - structNs: v.str2, - fieldLen: uint8(len(cf.altName)), - structfieldLen: uint8(len(cf.name)), - value: current.Interface(), - param: ct.param, - kind: kind, - typ: current.Type(), - }, - ) - - return } case reflect.Struct: diff --git a/validator_instance.go b/validator_instance.go index fc9db5a..4a89d40 100644 --- a/validator_instance.go +++ b/validator_instance.go @@ -13,27 +13,31 @@ import ( ) const ( - defaultTagName = "validate" - utf8HexComma = "0x2C" - utf8Pipe = "0x7C" - tagSeparator = "," - orSeparator = "|" - tagKeySeparator = "=" - structOnlyTag = "structonly" - noStructLevelTag = "nostructlevel" - omitempty = "omitempty" - isdefault = "isdefault" - skipValidationTag = "-" - diveTag = "dive" - keysTag = "keys" - endKeysTag = "endkeys" - requiredTag = "required" - namespaceSeparator = "." - leftBracket = "[" - rightBracket = "]" - restrictedTagChars = ".[],|=+()`~!@#$%^&*\\\"/?<>{}" - restrictedAliasErr = "Alias '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation" - restrictedTagErr = "Tag '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation" + defaultTagName = "validate" + utf8HexComma = "0x2C" + utf8Pipe = "0x7C" + tagSeparator = "," + orSeparator = "|" + tagKeySeparator = "=" + structOnlyTag = "structonly" + noStructLevelTag = "nostructlevel" + omitempty = "omitempty" + isdefault = "isdefault" + requiredWithoutAllTag = "required_without_all" + requiredWithoutTag = "required_without" + requiredWithTag = "required_with" + requiredWithAllTag = "required_with_all" + skipValidationTag = "-" + diveTag = "dive" + keysTag = "keys" + endKeysTag = "endkeys" + requiredTag = "required" + namespaceSeparator = "." + leftBracket = "[" + rightBracket = "]" + restrictedTagChars = ".[],|=+()`~!@#$%^&*\\\"/?<>{}" + restrictedAliasErr = "Alias '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation" + restrictedTagErr = "Tag '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation" ) var ( @@ -55,6 +59,11 @@ type CustomTypeFunc func(field reflect.Value) interface{} // TagNameFunc allows for adding of a custom tag name parser type TagNameFunc func(field reflect.StructField) string +type internalValidationFuncWrapper struct { + fn FuncCtx + runValidatinOnNil bool +} + // Validate contains the validator settings and cache type Validate struct { tagName string @@ -65,7 +74,7 @@ type Validate struct { structLevelFuncs map[reflect.Type]StructLevelFuncCtx customFuncs map[reflect.Type]CustomTypeFunc aliases map[string]string - validations map[string]FuncCtx + validations map[string]internalValidationFuncWrapper transTagFunc map[ut.Translator]map[string]TranslationFunc // map[]map[]TranslationFunc tagCache *tagCache structCache *structCache @@ -83,7 +92,7 @@ func New() *Validate { v := &Validate{ tagName: defaultTagName, aliases: make(map[string]string, len(bakedInAliases)), - validations: make(map[string]FuncCtx, len(bakedInValidators)), + validations: make(map[string]internalValidationFuncWrapper, len(bakedInValidators)), tagCache: tc, structCache: sc, } @@ -96,8 +105,14 @@ func New() *Validate { // must copy validators for separate validations to be used in each instance for k, val := range bakedInValidators { - // no need to error check here, baked in will always be valid - _ = v.registerValidation(k, wrapFunc(val), true) + switch k { + // these require that even if the value is nil that the validation should run, omitempty still overrides this behaviour + case requiredWithTag, requiredWithAllTag, requiredWithoutTag, requiredWithoutAllTag: + _ = v.registerValidation(k, wrapFunc(val), true, true) + default: + // no need to error check here, baked in will always be valid + _ = v.registerValidation(k, wrapFunc(val), true, false) + } } v.pool = &sync.Pool{ @@ -140,18 +155,21 @@ func (v *Validate) RegisterTagNameFunc(fn TagNameFunc) { // NOTES: // - if the key already exists, the previous validation function will be replaced. // - this method is not thread-safe it is intended that these all be registered prior to any validation -func (v *Validate) RegisterValidation(tag string, fn Func) error { - return v.RegisterValidationCtx(tag, wrapFunc(fn)) +func (v *Validate) RegisterValidation(tag string, fn Func, callValidationEvenIfNull ...bool) error { + return v.RegisterValidationCtx(tag, wrapFunc(fn), callValidationEvenIfNull...) } // RegisterValidationCtx does the same as RegisterValidation on accepts a FuncCtx validation // allowing context.Context validation support. -func (v *Validate) RegisterValidationCtx(tag string, fn FuncCtx) error { - return v.registerValidation(tag, fn, false) +func (v *Validate) RegisterValidationCtx(tag string, fn FuncCtx, callValidationEvenIfNull ...bool) error { + var nilCheckable bool + if len(callValidationEvenIfNull) > 0 { + nilCheckable = callValidationEvenIfNull[0] + } + return v.registerValidation(tag, fn, false, nilCheckable) } -func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool) error { - +func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool, nilCheckable bool) error { if len(tag) == 0 { return errors.New("Function Key cannot be empty") } @@ -161,13 +179,10 @@ func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool) erro } _, ok := restrictedTags[tag] - if !bakedIn && (ok || strings.ContainsAny(tag, restrictedTagChars)) { panic(fmt.Sprintf(restrictedTagErr, tag)) } - - v.validations[tag] = fn - + v.validations[tag] = internalValidationFuncWrapper{fn: fn, runValidatinOnNil: nilCheckable} return nil } diff --git a/validator_test.go b/validator_test.go index a8ccfaa..4844537 100644 --- a/validator_test.go +++ b/validator_test.go @@ -8677,14 +8677,20 @@ func TestRequiredWithAll(t *testing.T) { func TestRequiredWithout(t *testing.T) { + type Inner struct { + Field *string + } + fieldVal := "test" test := struct { + Inner *Inner Field1 string `validate:"omitempty" json:"field_1"` Field2 *string `validate:"required_without=Field1" json:"field_2"` Field3 map[string]string `validate:"required_without=Field2" json:"field_3"` Field4 interface{} `validate:"required_without=Field3" json:"field_4"` Field5 string `validate:"required_without=Field3" json:"field_5"` }{ + Inner: &Inner{Field: &fieldVal}, Field2: &fieldVal, Field3: map[string]string{"key": "val"}, Field4: "test", @@ -8694,29 +8700,35 @@ func TestRequiredWithout(t *testing.T) { validate := New() errs := validate.Struct(test) - - if errs != nil { - t.Fatalf("failed Error: %s", errs) - } + Equal(t, errs, nil) test2 := struct { - Field1 string `validate:"omitempty" json:"field_1"` + Inner *Inner + Inner2 *Inner + Field1 string `json:"field_1"` Field2 *string `validate:"required_without=Field1" json:"field_2"` Field3 map[string]string `validate:"required_without=Field2" json:"field_3"` Field4 interface{} `validate:"required_without=Field3" json:"field_4"` Field5 string `validate:"required_without=Field3" json:"field_5"` Field6 string `validate:"required_without=Field1" json:"field_6"` + Field7 string `validate:"required_without=Inner.Field" json:"field_7"` + Field8 string `validate:"required_without=Inner.Field" json:"field_8"` }{ + Inner: &Inner{}, Field3: map[string]string{"key": "val"}, Field4: "test", Field5: "test", } errs = validate.Struct(&test2) + NotEqual(t, errs, nil) - if errs == nil { - t.Fatalf("failed Error: %s", errs) - } + ve := errs.(ValidationErrors) + Equal(t, len(ve), 4) + AssertError(t, errs, "Field2", "Field2", "Field2", "Field2", "required_without") + AssertError(t, errs, "Field6", "Field6", "Field6", "Field6", "required_without") + AssertError(t, errs, "Field7", "Field7", "Field7", "Field7", "required_without") + AssertError(t, errs, "Field8", "Field8", "Field8", "Field8", "required_without") } func TestRequiredWithoutAll(t *testing.T) { @@ -8739,10 +8751,7 @@ func TestRequiredWithoutAll(t *testing.T) { validate := New() errs := validate.Struct(test) - - if errs != nil { - t.Fatalf("failed Error: %s", errs) - } + Equal(t, errs, nil) test2 := struct { Field1 string `validate:"omitempty" json:"field_1"` @@ -8750,7 +8759,7 @@ func TestRequiredWithoutAll(t *testing.T) { Field3 map[string]string `validate:"required_without_all=Field2" json:"field_3"` Field4 interface{} `validate:"required_without_all=Field3" json:"field_4"` Field5 string `validate:"required_without_all=Field3" json:"field_5"` - Field6 string `validate:"required_without_all=Field1" json:"field_6"` + Field6 string `validate:"required_without_all=Field1 Field3" json:"field_6"` }{ Field3: map[string]string{"key": "val"}, Field4: "test", @@ -8758,8 +8767,23 @@ func TestRequiredWithoutAll(t *testing.T) { } errs = validate.Struct(test2) + NotEqual(t, errs, nil) - if errs == nil { - t.Fatalf("failed Error: %s", errs) + ve := errs.(ValidationErrors) + Equal(t, len(ve), 1) + AssertError(t, errs, "Field2", "Field2", "Field2", "Field2", "required_without_all") +} + +func TestLookup(t *testing.T) { + type Lookup struct { + FieldA *string `json:"fieldA,omitempty" validate:"required_without=FieldB"` + FieldB *string `json:"fieldB,omitempty" validate:"required_without=FieldA"` + } + + fieldAValue := "1232" + lookup := Lookup{ + FieldA: &fieldAValue, + FieldB: nil, } + Equal(t, New().Struct(lookup), nil) }