fix required_without_*

pull/523/head
Dean Karn 5 years ago
parent 51fcc303b3
commit cc25246f01
  1. 58
      baked_in.go
  2. 13
      cache.go
  3. 10
      field_level.go
  4. 16
      validator.go
  5. 41
      validator_instance.go
  6. 54
      validator_test.go

@ -1301,26 +1301,35 @@ func isDefault(fl FieldLevel) bool {
// HasValue is the validation function for validating if the current field's value is not the default static value. // HasValue is the validation function for validating if the current field's value is not the default static value.
func hasValue(fl FieldLevel) bool { func hasValue(fl FieldLevel) bool {
return requireCheckFieldKind(fl, "") field := fl.Field()
switch field.Kind() {
case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func:
return !field.IsNil()
default:
if fl.(*validate).fldIsPointer && field.Interface() != nil {
return true
}
return field.IsValid() && field.Interface() != reflect.Zero(field.Type()).Interface()
}
} }
// requireCheckField is a func for check field kind // requireCheckField is a func for check field kind
func requireCheckFieldKind(fl FieldLevel, param string) bool { func requireCheckFieldKind(fl FieldLevel, param string) bool {
field := fl.Field() field := fl.Field()
var ok bool
kind := field.Kind()
if len(param) > 0 { if len(param) > 0 {
if fl.Parent().Kind() == reflect.Ptr { field, kind, ok = fl.GetStructFieldOKAdvanced(fl.Parent(), param)
field = fl.Parent().Elem().FieldByName(param) if !ok {
} else { return true
field = fl.Parent().FieldByName(param)
} }
} }
switch field.Kind() { switch kind {
case reflect.Invalid:
return true
case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func: case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func:
return !field.IsNil() return !field.IsNil()
default: default:
if fl.(*validate).fldIsPointer && field.Interface() != nil {
return true
}
return field.IsValid() && field.Interface() != reflect.Zero(field.Type()).Interface() return field.IsValid() && field.Interface() != reflect.Zero(field.Type()).Interface()
} }
} }
@ -1328,76 +1337,55 @@ func requireCheckFieldKind(fl FieldLevel, param string) bool {
// RequiredWith is the validation function // RequiredWith is the validation function
// The field under validation must be present and not empty only if any of the other specified fields are present. // The field under validation must be present and not empty only if any of the other specified fields are present.
func requiredWith(fl FieldLevel) bool { func requiredWith(fl FieldLevel) bool {
params := parseOneOfParam2(fl.Param()) params := parseOneOfParam2(fl.Param())
for _, param := range params { for _, param := range params {
if requireCheckFieldKind(fl, param) { if requireCheckFieldKind(fl, param) {
return requireCheckFieldKind(fl, "") return requireCheckFieldKind(fl, "")
} }
} }
return true return true
} }
// RequiredWithAll is the validation function // RequiredWithAll is the validation function
// The field under validation must be present and not empty only if all of the other specified fields are present. // The field under validation must be present and not empty only if all of the other specified fields are present.
func requiredWithAll(fl FieldLevel) bool { func requiredWithAll(fl FieldLevel) bool {
isValidateCurrentField := true isValidateCurrentField := true
params := parseOneOfParam2(fl.Param()) params := parseOneOfParam2(fl.Param())
for _, param := range params { for _, param := range params {
if !requireCheckFieldKind(fl, param) { if !requireCheckFieldKind(fl, param) {
isValidateCurrentField = false isValidateCurrentField = false
break
} }
} }
if isValidateCurrentField { if isValidateCurrentField {
return requireCheckFieldKind(fl, "") return requireCheckFieldKind(fl, "")
} }
return true return true
} }
// RequiredWithout is the validation function // RequiredWithout is the validation function
// The field under validation must be present and not empty only when any of the other specified fields are not present. // The field under validation must be present and not empty only when any of the other specified fields are not present.
func requiredWithout(fl FieldLevel) bool { func requiredWithout(fl FieldLevel) bool {
isValidateCurrentField := false
params := parseOneOfParam2(fl.Param()) params := parseOneOfParam2(fl.Param())
for _, param := range params { for _, param := range params {
if !requireCheckFieldKind(fl, param) {
if requireCheckFieldKind(fl, param) { return hasValue(fl)
isValidateCurrentField = true
}
} }
if !isValidateCurrentField {
return requireCheckFieldKind(fl, "")
} }
return true return true
} }
// RequiredWithoutAll is the validation function // RequiredWithoutAll is the validation function
// The field under validation must be present and not empty only when all of the other specified fields are not present. // The field under validation must be present and not empty only when all of the other specified fields are not present.
func requiredWithoutAll(fl FieldLevel) bool { func requiredWithoutAll(fl FieldLevel) bool {
isValidateCurrentField := true
params := parseOneOfParam2(fl.Param()) params := parseOneOfParam2(fl.Param())
for _, param := range params { for _, param := range params {
if requireCheckFieldKind(fl, param) { if requireCheckFieldKind(fl, param) {
isValidateCurrentField = false return true
} }
} }
return hasValue(fl)
if isValidateCurrentField {
return requireCheckFieldKind(fl, "")
}
return true
} }
// IsGteField is the validation function for validating if the current field's value is greater than or equal to the field specified by the param's value. // IsGteField is the validation function for validating if the current field's value is greater than or equal to the field specified by the param's value.

@ -99,6 +99,7 @@ type cTag struct {
hasAlias bool hasAlias bool
hasParam bool // true if parameter used eg. eq= where the equal sign has been set hasParam bool // true if parameter used eg. eq= where the equal sign has been set
isBlockEnd bool // indicates the current tag represents the last validation in the block isBlockEnd bool // indicates the current tag represents the last validation in the block
runValidationWhenNil bool
} }
func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStruct { func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStruct {
@ -141,9 +142,7 @@ func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStr
customName = fld.Name customName = fld.Name
if v.hasTagNameFunc { if v.hasTagNameFunc {
name := v.tagNameFunc(fld) name := v.tagNameFunc(fld)
if len(name) > 0 { if len(name) > 0 {
customName = name customName = name
} }
@ -168,16 +167,13 @@ func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStr
namesEqual: fld.Name == customName, namesEqual: fld.Name == customName,
}) })
} }
v.structCache.Set(typ, cs) v.structCache.Set(typ, cs)
return cs return cs
} }
func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias string, hasAlias bool) (firstCtag *cTag, current *cTag) { func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias string, hasAlias bool) (firstCtag *cTag, current *cTag) {
var t string var t string
var ok bool
noAlias := len(alias) == 0 noAlias := len(alias) == 0
tags := strings.Split(tag, tagSeparator) tags := strings.Split(tag, tagSeparator)
@ -270,11 +266,9 @@ func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias s
continue continue
default: default:
if t == isdefault { if t == isdefault {
current.typeof = typeIsDefault current.typeof = typeIsDefault
} }
// if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C" // if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C"
orVals := strings.Split(t, orSeparator) orVals := strings.Split(t, orSeparator)
@ -300,7 +294,10 @@ func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias s
panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName))) panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName)))
} }
if current.fn, ok = v.validations[current.tag]; !ok { if wrapper, ok := v.validations[current.tag]; ok {
current.fn = wrapper.fn
current.runValidationWhenNil = wrapper.runValidatinOnNil
} else {
panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName))) panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName)))
} }

@ -38,6 +38,10 @@ type FieldLevel interface {
// NOTE: when not successful ok will be false, this can happen when a nested struct is nil and so the field // NOTE: when not successful ok will be false, this can happen when a nested struct is nil and so the field
// could not be retrieved because it didn't exist. // could not be retrieved because it didn't exist.
GetStructFieldOK() (reflect.Value, reflect.Kind, bool) GetStructFieldOK() (reflect.Value, reflect.Kind, bool)
// GetStructFieldOKAdvanced is the same as GetStructFieldOK except that it accepts the parent struct to start looking for
// the field and namespace allowing more extensibility for validators.
GetStructFieldOKAdvanced(val reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool)
} }
var _ FieldLevel = new(validate) var _ FieldLevel = new(validate)
@ -67,3 +71,9 @@ func (v *validate) Param() string {
func (v *validate) GetStructFieldOK() (reflect.Value, reflect.Kind, bool) { func (v *validate) GetStructFieldOK() (reflect.Value, reflect.Kind, bool) {
return v.getStructFieldOKInternal(v.slflParent, v.ct.param) return v.getStructFieldOKInternal(v.slflParent, v.ct.param)
} }
// GetStructFieldOKAdvanced is the same as GetStructFieldOK except that it accepts the parent struct to start looking for
// the field and namespace allowing more extensibility for validators.
func (v *validate) GetStructFieldOKAdvanced(val reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool) {
return v.getStructFieldOKInternal(val, namespace)
}

@ -94,7 +94,6 @@ func (v *validate) validateStruct(ctx context.Context, parent reflect.Value, cur
// traverseField validates any field, be it a struct or single field, ensures it's validity and passes it along to be validated via it's tag options // traverseField validates any field, be it a struct or single field, ensures it's validity and passes it along to be validated via it's tag options
func (v *validate) traverseField(ctx context.Context, parent reflect.Value, current reflect.Value, ns []byte, structNs []byte, cf *cField, ct *cTag) { func (v *validate) traverseField(ctx context.Context, parent reflect.Value, current reflect.Value, ns []byte, structNs []byte, cf *cField, ct *cTag) {
var typ reflect.Type var typ reflect.Type
var kind reflect.Kind var kind reflect.Kind
@ -112,16 +111,13 @@ func (v *validate) traverseField(ctx context.Context, parent reflect.Value, curr
} }
if ct.hasTag { if ct.hasTag {
if kind == reflect.Invalid {
v.str1 = string(append(ns, cf.altName...)) v.str1 = string(append(ns, cf.altName...))
if v.v.hasTagNameFunc { if v.v.hasTagNameFunc {
v.str2 = string(append(structNs, cf.name...)) v.str2 = string(append(structNs, cf.name...))
} else { } else {
v.str2 = v.str1 v.str2 = v.str1
} }
if kind == reflect.Invalid {
v.errs = append(v.errs, v.errs = append(v.errs,
&fieldError{ &fieldError{
v: v.v, v: v.v,
@ -135,10 +131,16 @@ func (v *validate) traverseField(ctx context.Context, parent reflect.Value, curr
kind: kind, kind: kind,
}, },
) )
return return
} }
v.str1 = string(append(ns, cf.altName...))
if v.v.hasTagNameFunc {
v.str2 = string(append(structNs, cf.name...))
} else {
v.str2 = v.str1
}
if !ct.runValidationWhenNil {
v.errs = append(v.errs, v.errs = append(v.errs,
&fieldError{ &fieldError{
v: v.v, v: v.v,
@ -154,9 +156,9 @@ func (v *validate) traverseField(ctx context.Context, parent reflect.Value, curr
typ: current.Type(), typ: current.Type(),
}, },
) )
return return
} }
}
case reflect.Struct: case reflect.Struct:

@ -23,6 +23,10 @@ const (
noStructLevelTag = "nostructlevel" noStructLevelTag = "nostructlevel"
omitempty = "omitempty" omitempty = "omitempty"
isdefault = "isdefault" isdefault = "isdefault"
requiredWithoutAllTag = "required_without_all"
requiredWithoutTag = "required_without"
requiredWithTag = "required_with"
requiredWithAllTag = "required_with_all"
skipValidationTag = "-" skipValidationTag = "-"
diveTag = "dive" diveTag = "dive"
keysTag = "keys" keysTag = "keys"
@ -55,6 +59,11 @@ type CustomTypeFunc func(field reflect.Value) interface{}
// TagNameFunc allows for adding of a custom tag name parser // TagNameFunc allows for adding of a custom tag name parser
type TagNameFunc func(field reflect.StructField) string type TagNameFunc func(field reflect.StructField) string
type internalValidationFuncWrapper struct {
fn FuncCtx
runValidatinOnNil bool
}
// Validate contains the validator settings and cache // Validate contains the validator settings and cache
type Validate struct { type Validate struct {
tagName string tagName string
@ -65,7 +74,7 @@ type Validate struct {
structLevelFuncs map[reflect.Type]StructLevelFuncCtx structLevelFuncs map[reflect.Type]StructLevelFuncCtx
customFuncs map[reflect.Type]CustomTypeFunc customFuncs map[reflect.Type]CustomTypeFunc
aliases map[string]string aliases map[string]string
validations map[string]FuncCtx validations map[string]internalValidationFuncWrapper
transTagFunc map[ut.Translator]map[string]TranslationFunc // map[<locale>]map[<tag>]TranslationFunc transTagFunc map[ut.Translator]map[string]TranslationFunc // map[<locale>]map[<tag>]TranslationFunc
tagCache *tagCache tagCache *tagCache
structCache *structCache structCache *structCache
@ -83,7 +92,7 @@ func New() *Validate {
v := &Validate{ v := &Validate{
tagName: defaultTagName, tagName: defaultTagName,
aliases: make(map[string]string, len(bakedInAliases)), aliases: make(map[string]string, len(bakedInAliases)),
validations: make(map[string]FuncCtx, len(bakedInValidators)), validations: make(map[string]internalValidationFuncWrapper, len(bakedInValidators)),
tagCache: tc, tagCache: tc,
structCache: sc, structCache: sc,
} }
@ -96,8 +105,14 @@ func New() *Validate {
// must copy validators for separate validations to be used in each instance // must copy validators for separate validations to be used in each instance
for k, val := range bakedInValidators { for k, val := range bakedInValidators {
switch k {
// these require that even if the value is nil that the validation should run, omitempty still overrides this behaviour
case requiredWithTag, requiredWithAllTag, requiredWithoutTag, requiredWithoutAllTag:
_ = v.registerValidation(k, wrapFunc(val), true, true)
default:
// no need to error check here, baked in will always be valid // no need to error check here, baked in will always be valid
_ = v.registerValidation(k, wrapFunc(val), true) _ = v.registerValidation(k, wrapFunc(val), true, false)
}
} }
v.pool = &sync.Pool{ v.pool = &sync.Pool{
@ -140,18 +155,21 @@ func (v *Validate) RegisterTagNameFunc(fn TagNameFunc) {
// NOTES: // NOTES:
// - if the key already exists, the previous validation function will be replaced. // - if the key already exists, the previous validation function will be replaced.
// - this method is not thread-safe it is intended that these all be registered prior to any validation // - this method is not thread-safe it is intended that these all be registered prior to any validation
func (v *Validate) RegisterValidation(tag string, fn Func) error { func (v *Validate) RegisterValidation(tag string, fn Func, callValidationEvenIfNull ...bool) error {
return v.RegisterValidationCtx(tag, wrapFunc(fn)) return v.RegisterValidationCtx(tag, wrapFunc(fn), callValidationEvenIfNull...)
} }
// RegisterValidationCtx does the same as RegisterValidation on accepts a FuncCtx validation // RegisterValidationCtx does the same as RegisterValidation on accepts a FuncCtx validation
// allowing context.Context validation support. // allowing context.Context validation support.
func (v *Validate) RegisterValidationCtx(tag string, fn FuncCtx) error { func (v *Validate) RegisterValidationCtx(tag string, fn FuncCtx, callValidationEvenIfNull ...bool) error {
return v.registerValidation(tag, fn, false) var nilCheckable bool
if len(callValidationEvenIfNull) > 0 {
nilCheckable = callValidationEvenIfNull[0]
}
return v.registerValidation(tag, fn, false, nilCheckable)
} }
func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool) error { func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool, nilCheckable bool) error {
if len(tag) == 0 { if len(tag) == 0 {
return errors.New("Function Key cannot be empty") return errors.New("Function Key cannot be empty")
} }
@ -161,13 +179,10 @@ func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool) erro
} }
_, ok := restrictedTags[tag] _, ok := restrictedTags[tag]
if !bakedIn && (ok || strings.ContainsAny(tag, restrictedTagChars)) { if !bakedIn && (ok || strings.ContainsAny(tag, restrictedTagChars)) {
panic(fmt.Sprintf(restrictedTagErr, tag)) panic(fmt.Sprintf(restrictedTagErr, tag))
} }
v.validations[tag] = internalValidationFuncWrapper{fn: fn, runValidatinOnNil: nilCheckable}
v.validations[tag] = fn
return nil return nil
} }

@ -8677,14 +8677,20 @@ func TestRequiredWithAll(t *testing.T) {
func TestRequiredWithout(t *testing.T) { func TestRequiredWithout(t *testing.T) {
type Inner struct {
Field *string
}
fieldVal := "test" fieldVal := "test"
test := struct { test := struct {
Inner *Inner
Field1 string `validate:"omitempty" json:"field_1"` Field1 string `validate:"omitempty" json:"field_1"`
Field2 *string `validate:"required_without=Field1" json:"field_2"` Field2 *string `validate:"required_without=Field1" json:"field_2"`
Field3 map[string]string `validate:"required_without=Field2" json:"field_3"` Field3 map[string]string `validate:"required_without=Field2" json:"field_3"`
Field4 interface{} `validate:"required_without=Field3" json:"field_4"` Field4 interface{} `validate:"required_without=Field3" json:"field_4"`
Field5 string `validate:"required_without=Field3" json:"field_5"` Field5 string `validate:"required_without=Field3" json:"field_5"`
}{ }{
Inner: &Inner{Field: &fieldVal},
Field2: &fieldVal, Field2: &fieldVal,
Field3: map[string]string{"key": "val"}, Field3: map[string]string{"key": "val"},
Field4: "test", Field4: "test",
@ -8694,29 +8700,35 @@ func TestRequiredWithout(t *testing.T) {
validate := New() validate := New()
errs := validate.Struct(test) errs := validate.Struct(test)
Equal(t, errs, nil)
if errs != nil {
t.Fatalf("failed Error: %s", errs)
}
test2 := struct { test2 := struct {
Field1 string `validate:"omitempty" json:"field_1"` Inner *Inner
Inner2 *Inner
Field1 string `json:"field_1"`
Field2 *string `validate:"required_without=Field1" json:"field_2"` Field2 *string `validate:"required_without=Field1" json:"field_2"`
Field3 map[string]string `validate:"required_without=Field2" json:"field_3"` Field3 map[string]string `validate:"required_without=Field2" json:"field_3"`
Field4 interface{} `validate:"required_without=Field3" json:"field_4"` Field4 interface{} `validate:"required_without=Field3" json:"field_4"`
Field5 string `validate:"required_without=Field3" json:"field_5"` Field5 string `validate:"required_without=Field3" json:"field_5"`
Field6 string `validate:"required_without=Field1" json:"field_6"` Field6 string `validate:"required_without=Field1" json:"field_6"`
Field7 string `validate:"required_without=Inner.Field" json:"field_7"`
Field8 string `validate:"required_without=Inner.Field" json:"field_8"`
}{ }{
Inner: &Inner{},
Field3: map[string]string{"key": "val"}, Field3: map[string]string{"key": "val"},
Field4: "test", Field4: "test",
Field5: "test", Field5: "test",
} }
errs = validate.Struct(&test2) errs = validate.Struct(&test2)
NotEqual(t, errs, nil)
if errs == nil { ve := errs.(ValidationErrors)
t.Fatalf("failed Error: %s", errs) Equal(t, len(ve), 4)
} AssertError(t, errs, "Field2", "Field2", "Field2", "Field2", "required_without")
AssertError(t, errs, "Field6", "Field6", "Field6", "Field6", "required_without")
AssertError(t, errs, "Field7", "Field7", "Field7", "Field7", "required_without")
AssertError(t, errs, "Field8", "Field8", "Field8", "Field8", "required_without")
} }
func TestRequiredWithoutAll(t *testing.T) { func TestRequiredWithoutAll(t *testing.T) {
@ -8739,10 +8751,7 @@ func TestRequiredWithoutAll(t *testing.T) {
validate := New() validate := New()
errs := validate.Struct(test) errs := validate.Struct(test)
Equal(t, errs, nil)
if errs != nil {
t.Fatalf("failed Error: %s", errs)
}
test2 := struct { test2 := struct {
Field1 string `validate:"omitempty" json:"field_1"` Field1 string `validate:"omitempty" json:"field_1"`
@ -8750,7 +8759,7 @@ func TestRequiredWithoutAll(t *testing.T) {
Field3 map[string]string `validate:"required_without_all=Field2" json:"field_3"` Field3 map[string]string `validate:"required_without_all=Field2" json:"field_3"`
Field4 interface{} `validate:"required_without_all=Field3" json:"field_4"` Field4 interface{} `validate:"required_without_all=Field3" json:"field_4"`
Field5 string `validate:"required_without_all=Field3" json:"field_5"` Field5 string `validate:"required_without_all=Field3" json:"field_5"`
Field6 string `validate:"required_without_all=Field1" json:"field_6"` Field6 string `validate:"required_without_all=Field1 Field3" json:"field_6"`
}{ }{
Field3: map[string]string{"key": "val"}, Field3: map[string]string{"key": "val"},
Field4: "test", Field4: "test",
@ -8758,8 +8767,23 @@ func TestRequiredWithoutAll(t *testing.T) {
} }
errs = validate.Struct(test2) errs = validate.Struct(test2)
NotEqual(t, errs, nil)
if errs == nil { ve := errs.(ValidationErrors)
t.Fatalf("failed Error: %s", errs) Equal(t, len(ve), 1)
AssertError(t, errs, "Field2", "Field2", "Field2", "Field2", "required_without_all")
}
func TestLookup(t *testing.T) {
type Lookup struct {
FieldA *string `json:"fieldA,omitempty" validate:"required_without=FieldB"`
FieldB *string `json:"fieldB,omitempty" validate:"required_without=FieldA"`
}
fieldAValue := "1232"
lookup := Lookup{
FieldA: &fieldAValue,
FieldB: nil,
} }
Equal(t, New().Struct(lookup), nil)
} }

Loading…
Cancel
Save