feat: custom type func support interface

pull/868/head
september 3 years ago committed by september
parent c7e0172e0f
commit 886822a5f7
  1. 39
      util.go
  2. 19
      validator_instance.go
  3. 39
      validator_test.go

@ -42,8 +42,7 @@ BEGIN:
default: default:
if v.v.hasCustomFuncs { if v.v.hasCustomFuncs {
if fn, ok := v.findCustomTypeFunc(current.Type()); ok {
if fn, ok := v.v.customFuncs[current.Type()]; ok {
current = reflect.ValueOf(fn(current)) current = reflect.ValueOf(fn(current))
goto BEGIN goto BEGIN
} }
@ -53,6 +52,42 @@ BEGIN:
} }
} }
func (v *validate) findCustomTypeFunc(typ reflect.Type) (CustomTypeFunc, bool) {
// fast path
if fn, ok := v.v.customFuncs.Load(typ); ok {
if fn == nil {
return nil, false
}
return fn.(CustomTypeFunc), true
}
// slow path
var fn CustomTypeFunc
var found bool
// iterate the customFuncs to find if the typ implement any interface type registered
v.v.customFuncs.Range(func(key, value interface{}) bool {
keyTyp := key.(reflect.Type)
if keyTyp.Kind() != reflect.Interface {
return true
}
if typ.Implements(keyTyp) {
fn = value.(CustomTypeFunc)
found = true
return false
}
return true
})
if found {
v.v.customFuncs.Store(typ, fn)
} else {
v.v.customFuncs.Store(typ, nil)
}
return fn, found
}
// getStructFieldOKInternal traverses a struct to retrieve a specific field denoted by the provided namespace and // getStructFieldOKInternal traverses a struct to retrieve a specific field denoted by the provided namespace and
// returns the field, field kind and whether is was successful in retrieving the field at all. // returns the field, field kind and whether is was successful in retrieving the field at all.
// //

@ -82,7 +82,7 @@ type Validate struct {
hasTagNameFunc bool hasTagNameFunc bool
tagNameFunc TagNameFunc tagNameFunc TagNameFunc
structLevelFuncs map[reflect.Type]StructLevelFuncCtx structLevelFuncs map[reflect.Type]StructLevelFuncCtx
customFuncs map[reflect.Type]CustomTypeFunc customFuncs sync.Map
aliases map[string]string aliases map[string]string
validations map[string]internalValidationFuncWrapper validations map[string]internalValidationFuncWrapper
transTagFunc map[ut.Translator]map[string]TranslationFunc // map[<locale>]map[<tag>]TranslationFunc transTagFunc map[ut.Translator]map[string]TranslationFunc // map[<locale>]map[<tag>]TranslationFunc
@ -152,7 +152,7 @@ func (v *Validate) SetTagName(name string) {
// ValidateMapCtx validates a map using a map of validation rules and allows passing of contextual // ValidateMapCtx validates a map using a map of validation rules and allows passing of contextual
// validation validation information via context.Context. // validation validation information via context.Context.
func (v Validate) ValidateMapCtx(ctx context.Context, data map[string]interface{}, rules map[string]interface{}) map[string]interface{} { func (v *Validate) ValidateMapCtx(ctx context.Context, data map[string]interface{}, rules map[string]interface{}) map[string]interface{} {
errs := make(map[string]interface{}) errs := make(map[string]interface{})
for field, rule := range rules { for field, rule := range rules {
if ruleObj, ok := rule.(map[string]interface{}); ok { if ruleObj, ok := rule.(map[string]interface{}); ok {
@ -317,12 +317,17 @@ func (v *Validate) RegisterStructValidationMapRules(rules map[string]string, typ
// NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation // NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
func (v *Validate) RegisterCustomTypeFunc(fn CustomTypeFunc, types ...interface{}) { func (v *Validate) RegisterCustomTypeFunc(fn CustomTypeFunc, types ...interface{}) {
if v.customFuncs == nil {
v.customFuncs = make(map[reflect.Type]CustomTypeFunc)
}
for _, t := range types { for _, t := range types {
v.customFuncs[reflect.TypeOf(t)] = fn var rt reflect.Type
switch tt := t.(type) {
case reflect.Type:
rt = tt
case *reflect.Type:
rt = *tt
default:
rt = reflect.TypeOf(t)
}
v.customFuncs.Store(rt, fn)
} }
v.hasCustomFuncs = true v.hasCustomFuncs = true

@ -2215,6 +2215,45 @@ func TestExistsValidation(t *testing.T) {
Equal(t, errs, nil) Equal(t, errs, nil)
} }
func TestSQLValue3Validation(t *testing.T) {
validate := New()
validate.RegisterCustomTypeFunc(ValidateValuerType, reflect.TypeOf((*driver.Valuer)(nil)).Elem())
val := valuer{
Name: "",
}
errs := validate.Var(val, "required")
NotEqual(t, errs, nil)
AssertError(t, errs, "", "", "", "", "required")
val.Name = "Valid Name"
errs = validate.VarCtx(context.Background(), val, "required")
Equal(t, errs, nil)
val.Name = "errorme"
PanicMatches(t, func() { _ = validate.Var(val, "required") }, "SQL Driver Valuer error: some kind of error")
myVal := valuer{
Name: "",
}
errs = validate.Var(myVal, "required")
NotEqual(t, errs, nil)
AssertError(t, errs, "", "", "", "", "required")
intVal := sql.NullInt64{}
errs = validate.Var(intVal, "required")
NotEqual(t, errs, nil)
AssertError(t, errs, "", "", "", "", "required")
intVal.Int64 = 10
intVal.Valid = true
errs = validate.Var(intVal, "required")
Equal(t, errs, nil)
}
func TestSQLValue2Validation(t *testing.T) { func TestSQLValue2Validation(t *testing.T) {
validate := New() validate := New()
validate.RegisterCustomTypeFunc(ValidateValuerType, valuer{}, (*driver.Valuer)(nil), sql.NullString{}, sql.NullInt64{}, sql.NullBool{}, sql.NullFloat64{}) validate.RegisterCustomTypeFunc(ValidateValuerType, valuer{}, (*driver.Valuer)(nil), sql.NullString{}, sql.NullInt64{}, sql.NullBool{}, sql.NullFloat64{})

Loading…
Cancel
Save