diff --git a/baked_in.go b/baked_in.go index 28cd1be..a482196 100644 --- a/baked_in.go +++ b/baked_in.go @@ -6,7 +6,9 @@ import ( "net" "net/url" "reflect" + "strconv" "strings" + "sync" "time" "unicode/utf8" ) @@ -135,9 +137,44 @@ var ( "hostname_rfc1123": isHostnameRFC1123, // RFC 1123 "fqdn": isFQDN, "unique": isUnique, + "oneof": isOneOf, } ) +var oneofValCache = map[string][]string{} +var oneofValCacheLock = sync.Mutex{} + +func isOneOf(fl FieldLevel) bool { + param := fl.Param() + oneofValCacheLock.Lock() + vals, ok := oneofValCache[param] + if !ok { + vals = strings.Fields(param) + oneofValCache[param] = vals + } + oneofValCacheLock.Unlock() + + field := fl.Field() + + var v string + switch field.Kind() { + case reflect.String: + v = field.String() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v = strconv.FormatInt(field.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v = strconv.FormatUint(field.Uint(), 10) + default: + panic(fmt.Sprintf("Bad field type %T", field.Interface())) + } + for i := 0; i < len(vals); i++ { + if vals[i] == v { + return true + } + } + return false +} + // isUnique is the validation function for validating if each array|slice element is unique func isUnique(fl FieldLevel) bool { diff --git a/doc.go b/doc.go index 5d26131..f7efe23 100644 --- a/doc.go +++ b/doc.go @@ -295,6 +295,16 @@ validates the number of items. Usage: ne=10 +One Of + +For strings, ints, and uints, oneof will ensure that the value +is one of the values in the parameter. The parameter should be +a list of values separated by whitespace. Values may be +strings or numbers. + + Usage: oneof=red green + oneof=5 7 9 + Greater Than For numbers, this will ensure that the value is greater than the diff --git a/translations/en/en.go b/translations/en/en.go index 551401f..3ac1c47 100644 --- a/translations/en/en.go +++ b/translations/en/en.go @@ -1294,6 +1294,19 @@ func RegisterDefaultTranslations(v *validator.Validate, trans ut.Translator) (er translation: "{0} must be a valid color", override: false, }, + { + tag: "oneof", + translation: "{0} must be one of [{1}]", + override: false, + customTransFunc: func(ut ut.Translator, fe validator.FieldError) string { + s, err := ut.T(fe.Tag(), fe.Field(), fe.Param()) + if err != nil { + log.Printf("warning: error translating FieldError: %#v", fe) + return fe.(error).Error() + } + return s + }, + }, } for _, t := range translations { diff --git a/translations/en/en_test.go b/translations/en/en_test.go index 93a4fef..1cb44d0 100644 --- a/translations/en/en_test.go +++ b/translations/en/en_test.go @@ -136,6 +136,8 @@ func TestTranslations(t *testing.T) { StrPtrLte *string `validate:"lte=1"` StrPtrGt *string `validate:"gt=10"` StrPtrGte *string `validate:"gte=10"` + OneOfString string `validate:"oneof=red green"` + OneOfInt int `validate:"oneof=5 63"` } var test Test @@ -604,6 +606,14 @@ func TestTranslations(t *testing.T) { ns: "Test.StrPtrGte", expected: "StrPtrGte must be at least 10 characters in length", }, + { + ns: "Test.OneOfString", + expected: "OneOfString must be one of [red green]", + }, + { + ns: "Test.OneOfInt", + expected: "OneOfInt must be one of [5 63]", + }, } for _, tt := range tests { diff --git a/validator_test.go b/validator_test.go index 02c1776..9600d0f 100644 --- a/validator_test.go +++ b/validator_test.go @@ -4314,6 +4314,66 @@ func TestIsEqValidation(t *testing.T) { PanicMatches(t, func() { validate.Var(now, "eq=now") }, "Bad field type time.Time") } +func TestOneOfValidation(t *testing.T) { + validate := New() + + passSpecs := []struct { + f interface{} + t string + }{ + {f: "red", t: "oneof=red green"}, + {f: "green", t: "oneof=red green"}, + {f: 5, t: "oneof=5 6"}, + {f: 6, t: "oneof=5 6"}, + {f: int8(6), t: "oneof=5 6"}, + {f: int16(6), t: "oneof=5 6"}, + {f: int32(6), t: "oneof=5 6"}, + {f: int64(6), t: "oneof=5 6"}, + {f: uint(6), t: "oneof=5 6"}, + {f: uint8(6), t: "oneof=5 6"}, + {f: uint16(6), t: "oneof=5 6"}, + {f: uint32(6), t: "oneof=5 6"}, + {f: uint64(6), t: "oneof=5 6"}, + } + + for _, spec := range passSpecs { + t.Logf("%#v", spec) + errs := validate.Var(spec.f, spec.t) + Equal(t, errs, nil) + } + + failSpecs := []struct { + f interface{} + t string + }{ + {f: "", t: "oneof=red green"}, + {f: "yellow", t: "oneof=red green"}, + {f: 5, t: "oneof=red green"}, + {f: 6, t: "oneof=red green"}, + {f: 6, t: "oneof=7"}, + {f: uint(6), t: "oneof=7"}, + {f: int8(5), t: "oneof=red green"}, + {f: int16(5), t: "oneof=red green"}, + {f: int32(5), t: "oneof=red green"}, + {f: int64(5), t: "oneof=red green"}, + {f: uint(5), t: "oneof=red green"}, + {f: uint8(5), t: "oneof=red green"}, + {f: uint16(5), t: "oneof=red green"}, + {f: uint32(5), t: "oneof=red green"}, + {f: uint64(5), t: "oneof=red green"}, + } + + for _, spec := range failSpecs { + t.Logf("%#v", spec) + errs := validate.Var(spec.f, spec.t) + AssertError(t, errs, "", "", "", "", "oneof") + } + + PanicMatches(t, func() { + validate.Var(3.14, "oneof=red green") + }, "Bad field type float64") +} + func TestBase64Validation(t *testing.T) { validate := New()