diff --git a/baked_in.go b/baked_in.go index 702e78d..ab3d6df 100644 --- a/baked_in.go +++ b/baked_in.go @@ -63,6 +63,7 @@ var ( // or even disregard and use your own map if so desired. bakedInValidators = map[string]Func{ "required": hasValue, + "required_with": requiredWith, "isdefault": isDefault, "len": hasLengthOf, "min": hasMinOf, @@ -1313,6 +1314,39 @@ func hasValue(fl FieldLevel) bool { } } +// RequiredWith is the validation function for validating if the current field's if any of the other specified fields are present. +func requiredWith(fl FieldLevel) bool { + + field := fl.Field() + params := parseOneOfParam2(fl.Param()) + for _, param := range params { + isParamFieldPresent := false + + paramField := fl.Parent().FieldByName(param) + + switch paramField.Kind() { + case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func: + isParamFieldPresent = !paramField.IsNil() + default: + isParamFieldPresent = paramField.IsValid() && paramField.Interface() != reflect.Zero(field.Type()).Interface() + } + + if isParamFieldPresent { + switch field.Kind() { + case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func: + return !field.IsNil() + default: + if fl.(*validate).fldIsPointer && field.Interface() != nil { + return true + } + return field.IsValid() && field.Interface() != reflect.Zero(field.Type()).Interface() + } + } + } + + return true +} + // IsGteField is the validation function for validating if the current field's value is greater than or equal to the field specified by the param's value. func isGteField(fl FieldLevel) bool { diff --git a/validator_test.go b/validator_test.go index 41f12cb..825d515 100644 --- a/validator_test.go +++ b/validator_test.go @@ -8572,7 +8572,7 @@ func TestStartsWithValidation(t *testing.T) { ExpectedNil bool }{ {Value: "(/^ヮ^)/*:・゚✧ glitter", Tag: "startswith=(/^ヮ^)/*:・゚✧", ExpectedNil: true}, - {Value: "abcd", Tag: "startswith=(/^ヮ^)/*:・゚✧", ExpectedNil: false}, + {Value: "abcd", Tag: "startswith=(/^ヮ^)/*:・゚✧", ExpectedNil: false}, } validate := New() @@ -8619,4 +8619,23 @@ func TestEndsWithValidation(t *testing.T) { } } +func TestRequiredWith(t *testing.T) { + test := struct { + Field1 string `validate:"omitempty" json:"field_1"` + Field2 string `validate:"omitempty" json:"field_2"` + Field3 string `validate:"required_with=Field1 Field2" json:"field_3"` + }{ + Field1: "test_field1", + Field2: "test_field2", + Field3: "test_field3", + } + + validate := New() + + errs := validate.Struct(test) + + if errs != nil { + t.Fatalf("failed Error: %s", errs) + } +}