From 886822a5f7783d5088df6f3f63639582c0138b89 Mon Sep 17 00:00:00 2001 From: september Date: Fri, 17 Dec 2021 09:16:26 +0800 Subject: [PATCH] feat: custom type func support interface --- util.go | 39 +++++++++++++++++++++++++++++++++++++-- validator_instance.go | 19 ++++++++++++------- validator_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 9 deletions(-) diff --git a/util.go b/util.go index 36da855..0fefcaf 100644 --- a/util.go +++ b/util.go @@ -42,8 +42,7 @@ BEGIN: default: if v.v.hasCustomFuncs { - - if fn, ok := v.v.customFuncs[current.Type()]; ok { + if fn, ok := v.findCustomTypeFunc(current.Type()); ok { current = reflect.ValueOf(fn(current)) goto BEGIN } @@ -53,6 +52,42 @@ BEGIN: } } +func (v *validate) findCustomTypeFunc(typ reflect.Type) (CustomTypeFunc, bool) { + // fast path + if fn, ok := v.v.customFuncs.Load(typ); ok { + if fn == nil { + return nil, false + } + return fn.(CustomTypeFunc), true + } + + // slow path + var fn CustomTypeFunc + var found bool + + // iterate the customFuncs to find if the typ implement any interface type registered + v.v.customFuncs.Range(func(key, value interface{}) bool { + keyTyp := key.(reflect.Type) + if keyTyp.Kind() != reflect.Interface { + return true + } + + if typ.Implements(keyTyp) { + fn = value.(CustomTypeFunc) + found = true + return false + } + return true + }) + + if found { + v.v.customFuncs.Store(typ, fn) + } else { + v.v.customFuncs.Store(typ, nil) + } + return fn, found +} + // getStructFieldOKInternal traverses a struct to retrieve a specific field denoted by the provided namespace and // returns the field, field kind and whether is was successful in retrieving the field at all. // diff --git a/validator_instance.go b/validator_instance.go index 9493da4..01ba777 100644 --- a/validator_instance.go +++ b/validator_instance.go @@ -82,7 +82,7 @@ type Validate struct { hasTagNameFunc bool tagNameFunc TagNameFunc structLevelFuncs map[reflect.Type]StructLevelFuncCtx - customFuncs map[reflect.Type]CustomTypeFunc + customFuncs sync.Map aliases map[string]string validations map[string]internalValidationFuncWrapper transTagFunc map[ut.Translator]map[string]TranslationFunc // map[]map[]TranslationFunc @@ -152,7 +152,7 @@ func (v *Validate) SetTagName(name string) { // ValidateMapCtx validates a map using a map of validation rules and allows passing of contextual // validation validation information via context.Context. -func (v Validate) ValidateMapCtx(ctx context.Context, data map[string]interface{}, rules map[string]interface{}) map[string]interface{} { +func (v *Validate) ValidateMapCtx(ctx context.Context, data map[string]interface{}, rules map[string]interface{}) map[string]interface{} { errs := make(map[string]interface{}) for field, rule := range rules { if ruleObj, ok := rule.(map[string]interface{}); ok { @@ -317,12 +317,17 @@ func (v *Validate) RegisterStructValidationMapRules(rules map[string]string, typ // NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation func (v *Validate) RegisterCustomTypeFunc(fn CustomTypeFunc, types ...interface{}) { - if v.customFuncs == nil { - v.customFuncs = make(map[reflect.Type]CustomTypeFunc) - } - for _, t := range types { - v.customFuncs[reflect.TypeOf(t)] = fn + var rt reflect.Type + switch tt := t.(type) { + case reflect.Type: + rt = tt + case *reflect.Type: + rt = *tt + default: + rt = reflect.TypeOf(t) + } + v.customFuncs.Store(rt, fn) } v.hasCustomFuncs = true diff --git a/validator_test.go b/validator_test.go index 7e314d6..88bf534 100644 --- a/validator_test.go +++ b/validator_test.go @@ -2215,6 +2215,45 @@ func TestExistsValidation(t *testing.T) { Equal(t, errs, nil) } +func TestSQLValue3Validation(t *testing.T) { + validate := New() + validate.RegisterCustomTypeFunc(ValidateValuerType, reflect.TypeOf((*driver.Valuer)(nil)).Elem()) + + val := valuer{ + Name: "", + } + + errs := validate.Var(val, "required") + NotEqual(t, errs, nil) + AssertError(t, errs, "", "", "", "", "required") + + val.Name = "Valid Name" + errs = validate.VarCtx(context.Background(), val, "required") + Equal(t, errs, nil) + + val.Name = "errorme" + + PanicMatches(t, func() { _ = validate.Var(val, "required") }, "SQL Driver Valuer error: some kind of error") + + myVal := valuer{ + Name: "", + } + + errs = validate.Var(myVal, "required") + NotEqual(t, errs, nil) + AssertError(t, errs, "", "", "", "", "required") + + intVal := sql.NullInt64{} + errs = validate.Var(intVal, "required") + NotEqual(t, errs, nil) + AssertError(t, errs, "", "", "", "", "required") + + intVal.Int64 = 10 + intVal.Valid = true + errs = validate.Var(intVal, "required") + Equal(t, errs, nil) +} + func TestSQLValue2Validation(t *testing.T) { validate := New() validate.RegisterCustomTypeFunc(ValidateValuerType, valuer{}, (*driver.Valuer)(nil), sql.NullString{}, sql.NullInt64{}, sql.NullBool{}, sql.NullFloat64{})