From ebfddc5784664ef7bc94e3d9672158fb38ff6cc8 Mon Sep 17 00:00:00 2001 From: A Mashmooli Date: Mon, 29 Apr 2019 15:46:48 +0430 Subject: [PATCH] add required_with_all validator --- baked_in.go | 247 +++++++++++++++++++++++++++------------------- validator_test.go | 21 ++++ 2 files changed, 167 insertions(+), 101 deletions(-) diff --git a/baked_in.go b/baked_in.go index 2f95150..8ad8773 100644 --- a/baked_in.go +++ b/baked_in.go @@ -62,106 +62,107 @@ var ( // you can add, remove or even replace items to suite your needs, // or even disregard and use your own map if so desired. bakedInValidators = map[string]Func{ - "required": hasValue, - "required_with": requiredWith, - "isdefault": isDefault, - "len": hasLengthOf, - "min": hasMinOf, - "max": hasMaxOf, - "eq": isEq, - "ne": isNe, - "lt": isLt, - "lte": isLte, - "gt": isGt, - "gte": isGte, - "eqfield": isEqField, - "eqcsfield": isEqCrossStructField, - "necsfield": isNeCrossStructField, - "gtcsfield": isGtCrossStructField, - "gtecsfield": isGteCrossStructField, - "ltcsfield": isLtCrossStructField, - "ltecsfield": isLteCrossStructField, - "nefield": isNeField, - "gtefield": isGteField, - "gtfield": isGtField, - "ltefield": isLteField, - "ltfield": isLtField, - "fieldcontains": fieldContains, - "fieldexcludes": fieldExcludes, - "alpha": isAlpha, - "alphanum": isAlphanum, - "alphaunicode": isAlphaUnicode, - "alphanumunicode": isAlphanumUnicode, - "numeric": isNumeric, - "number": isNumber, - "hexadecimal": isHexadecimal, - "hexcolor": isHEXColor, - "rgb": isRGB, - "rgba": isRGBA, - "hsl": isHSL, - "hsla": isHSLA, - "email": isEmail, - "url": isURL, - "uri": isURI, - "urn_rfc2141": isUrnRFC2141, // RFC 2141 - "file": isFile, - "base64": isBase64, - "base64url": isBase64URL, - "contains": contains, - "containsany": containsAny, - "containsrune": containsRune, - "excludes": excludes, - "excludesall": excludesAll, - "excludesrune": excludesRune, - "startswith": startsWith, - "endswith": endsWith, - "isbn": isISBN, - "isbn10": isISBN10, - "isbn13": isISBN13, - "eth_addr": isEthereumAddress, - "btc_addr": isBitcoinAddress, - "btc_addr_bech32": isBitcoinBech32Address, - "uuid": isUUID, - "uuid3": isUUID3, - "uuid4": isUUID4, - "uuid5": isUUID5, - "uuid_rfc4122": isUUIDRFC4122, - "uuid3_rfc4122": isUUID3RFC4122, - "uuid4_rfc4122": isUUID4RFC4122, - "uuid5_rfc4122": isUUID5RFC4122, - "ascii": isASCII, - "printascii": isPrintableASCII, - "multibyte": hasMultiByteCharacter, - "datauri": isDataURI, - "latitude": isLatitude, - "longitude": isLongitude, - "ssn": isSSN, - "ipv4": isIPv4, - "ipv6": isIPv6, - "ip": isIP, - "cidrv4": isCIDRv4, - "cidrv6": isCIDRv6, - "cidr": isCIDR, - "tcp4_addr": isTCP4AddrResolvable, - "tcp6_addr": isTCP6AddrResolvable, - "tcp_addr": isTCPAddrResolvable, - "udp4_addr": isUDP4AddrResolvable, - "udp6_addr": isUDP6AddrResolvable, - "udp_addr": isUDPAddrResolvable, - "ip4_addr": isIP4AddrResolvable, - "ip6_addr": isIP6AddrResolvable, - "ip_addr": isIPAddrResolvable, - "unix_addr": isUnixAddrResolvable, - "mac": isMAC, - "hostname": isHostnameRFC952, // RFC 952 - "hostname_rfc1123": isHostnameRFC1123, // RFC 1123 - "fqdn": isFQDN, - "unique": isUnique, - "oneof": isOneOf, - "html": isHTML, - "html_encoded": isHTMLEncoded, - "url_encoded": isURLEncoded, - "dir": isDir, + "required": hasValue, + "required_with": requiredWith, + "required_with_all": requiredWithAll, + "isdefault": isDefault, + "len": hasLengthOf, + "min": hasMinOf, + "max": hasMaxOf, + "eq": isEq, + "ne": isNe, + "lt": isLt, + "lte": isLte, + "gt": isGt, + "gte": isGte, + "eqfield": isEqField, + "eqcsfield": isEqCrossStructField, + "necsfield": isNeCrossStructField, + "gtcsfield": isGtCrossStructField, + "gtecsfield": isGteCrossStructField, + "ltcsfield": isLtCrossStructField, + "ltecsfield": isLteCrossStructField, + "nefield": isNeField, + "gtefield": isGteField, + "gtfield": isGtField, + "ltefield": isLteField, + "ltfield": isLtField, + "fieldcontains": fieldContains, + "fieldexcludes": fieldExcludes, + "alpha": isAlpha, + "alphanum": isAlphanum, + "alphaunicode": isAlphaUnicode, + "alphanumunicode": isAlphanumUnicode, + "numeric": isNumeric, + "number": isNumber, + "hexadecimal": isHexadecimal, + "hexcolor": isHEXColor, + "rgb": isRGB, + "rgba": isRGBA, + "hsl": isHSL, + "hsla": isHSLA, + "email": isEmail, + "url": isURL, + "uri": isURI, + "urn_rfc2141": isUrnRFC2141, // RFC 2141 + "file": isFile, + "base64": isBase64, + "base64url": isBase64URL, + "contains": contains, + "containsany": containsAny, + "containsrune": containsRune, + "excludes": excludes, + "excludesall": excludesAll, + "excludesrune": excludesRune, + "startswith": startsWith, + "endswith": endsWith, + "isbn": isISBN, + "isbn10": isISBN10, + "isbn13": isISBN13, + "eth_addr": isEthereumAddress, + "btc_addr": isBitcoinAddress, + "btc_addr_bech32": isBitcoinBech32Address, + "uuid": isUUID, + "uuid3": isUUID3, + "uuid4": isUUID4, + "uuid5": isUUID5, + "uuid_rfc4122": isUUIDRFC4122, + "uuid3_rfc4122": isUUID3RFC4122, + "uuid4_rfc4122": isUUID4RFC4122, + "uuid5_rfc4122": isUUID5RFC4122, + "ascii": isASCII, + "printascii": isPrintableASCII, + "multibyte": hasMultiByteCharacter, + "datauri": isDataURI, + "latitude": isLatitude, + "longitude": isLongitude, + "ssn": isSSN, + "ipv4": isIPv4, + "ipv6": isIPv6, + "ip": isIP, + "cidrv4": isCIDRv4, + "cidrv6": isCIDRv6, + "cidr": isCIDR, + "tcp4_addr": isTCP4AddrResolvable, + "tcp6_addr": isTCP6AddrResolvable, + "tcp_addr": isTCPAddrResolvable, + "udp4_addr": isUDP4AddrResolvable, + "udp6_addr": isUDP6AddrResolvable, + "udp_addr": isUDPAddrResolvable, + "ip4_addr": isIP4AddrResolvable, + "ip6_addr": isIP6AddrResolvable, + "ip_addr": isIPAddrResolvable, + "unix_addr": isUnixAddrResolvable, + "mac": isMAC, + "hostname": isHostnameRFC952, // RFC 952 + "hostname_rfc1123": isHostnameRFC1123, // RFC 1123 + "fqdn": isFQDN, + "unique": isUnique, + "oneof": isOneOf, + "html": isHTML, + "html_encoded": isHTMLEncoded, + "url_encoded": isURLEncoded, + "dir": isDir, } ) @@ -1315,7 +1316,7 @@ func hasValue(fl FieldLevel) bool { } // RequiredWith is the validation function -// the field under validation must be present and not empty only if any of the other specified fields are present. +// The field under validation must be present and not empty only if any of the other specified fields are present. func requiredWith(fl FieldLevel) bool { field := fl.Field() @@ -1329,6 +1330,9 @@ func requiredWith(fl FieldLevel) bool { case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func: isParamFieldPresent = !paramField.IsNil() default: + if fl.(*validate).fldIsPointer && paramField.Interface() != nil { + isParamFieldPresent = true + } isParamFieldPresent = paramField.IsValid() && paramField.Interface() != reflect.Zero(field.Type()).Interface() } @@ -1348,6 +1352,47 @@ func requiredWith(fl FieldLevel) bool { return true } +// RequiredWithAll is the validation function +// The field under validation must be present and not empty only if all of the other specified fields are present. +func requiredWithAll(fl FieldLevel) bool { + + field := fl.Field() + isValidateCurrentField := true + params := parseOneOfParam2(fl.Param()) + for _, param := range params { + isParamFieldPresent := false + + paramField := fl.Parent().FieldByName(param) + + switch paramField.Kind() { + case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func: + isParamFieldPresent = !paramField.IsNil() + default: + if fl.(*validate).fldIsPointer && paramField.Interface() != nil { + isParamFieldPresent = true + } + isParamFieldPresent = paramField.IsValid() && paramField.Interface() != reflect.Zero(field.Type()).Interface() + } + if !isParamFieldPresent { + isValidateCurrentField = isParamFieldPresent + } + } + + if isValidateCurrentField { + switch field.Kind() { + case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func: + return !field.IsNil() + default: + if fl.(*validate).fldIsPointer && field.Interface() != nil { + return true + } + return field.IsValid() && field.Interface() != reflect.Zero(field.Type()).Interface() + } + } + + return true +} + // IsGteField is the validation function for validating if the current field's value is greater than or equal to the field specified by the param's value. func isGteField(fl FieldLevel) bool { diff --git a/validator_test.go b/validator_test.go index 825d515..244d0c9 100644 --- a/validator_test.go +++ b/validator_test.go @@ -8639,3 +8639,24 @@ func TestRequiredWith(t *testing.T) { t.Fatalf("failed Error: %s", errs) } } + +func TestRequiredWithAll(t *testing.T) { + + test := struct { + Field1 string `validate:"omitempty" json:"field_1"` + Field2 string `validate:"omitempty" json:"field_2"` + Field3 string `validate:"required_with_all=Field1 Field2" json:"field_3"` + }{ + Field1: "test_field1", + Field2: "test_field2", + Field3: "test_field3", + } + + validate := New() + + errs := validate.Struct(test) + + if errs != nil { + t.Fatalf("failed Error: %s", errs) + } +}