From 75162c2da6a438f8fa962eb3b7d5ed48e4564d17 Mon Sep 17 00:00:00 2001 From: Thao Nguyen Date: Sun, 23 Jul 2017 23:29:27 +0700 Subject: [PATCH] Go 1.8 context supports for struct validation --- baked_in.go | 9 +++++++++ cache.go | 2 +- validator.go | 17 ++++++++-------- validator_instance.go | 47 +++++++++++++++++++++++++++++-------------- validator_test.go | 8 ++++++++ 5 files changed, 59 insertions(+), 24 deletions(-) diff --git a/baked_in.go b/baked_in.go index 9e0b173..d8c13c9 100644 --- a/baked_in.go +++ b/baked_in.go @@ -1,6 +1,7 @@ package validator import ( + "context" "fmt" "net" "net/url" @@ -16,6 +17,14 @@ import ( // fieldType = fields // param = parameter used in validation i.e. gt=0 param would be 0 type Func func(fl FieldLevel) bool +type FuncCtx func(ctx context.Context, fl FieldLevel) bool + +// wrapFunc make Func compatible with FuncCtx +func wrapFunc(fn Func) FuncCtx { + return func(ctx context.Context, fl FieldLevel) bool { + return fn(fl) + } +} var ( restrictedTags = map[string]struct{}{ diff --git a/cache.go b/cache.go index d596bd5..a45120d 100644 --- a/cache.go +++ b/cache.go @@ -90,7 +90,7 @@ type cTag struct { hasAlias bool typeof tagType hasTag bool - fn Func + fn FuncCtx next *cTag } diff --git a/validator.go b/validator.go index 98bf265..da8a80e 100644 --- a/validator.go +++ b/validator.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "strconv" + "context" ) // per validate contruct @@ -34,7 +35,7 @@ type validate struct { } // parent and current will be the same the first run of validateStruct -func (v *validate) validateStruct(parent reflect.Value, current reflect.Value, typ reflect.Type, ns []byte, structNs []byte, ct *cTag) { +func (v *validate) validateStruct(ctx context.Context, parent reflect.Value, current reflect.Value, typ reflect.Type, ns []byte, structNs []byte, ct *cTag) { cs, ok := v.v.structCache.Get(typ) if !ok { @@ -78,7 +79,7 @@ func (v *validate) validateStruct(parent reflect.Value, current reflect.Value, t } } - v.traverseField(parent, current.Field(f.idx), ns, structNs, f, f.cTags) + v.traverseField(ctx, parent, current.Field(f.idx), ns, structNs, f, f.cTags) } } @@ -97,7 +98,7 @@ func (v *validate) validateStruct(parent reflect.Value, current reflect.Value, t } // traverseField validates any field, be it a struct or single field, ensures it's validity and passes it along to be validated via it's tag options -func (v *validate) traverseField(parent reflect.Value, current reflect.Value, ns []byte, structNs []byte, cf *cField, ct *cTag) { +func (v *validate) traverseField(ctx context.Context, parent reflect.Value, current reflect.Value, ns []byte, structNs []byte, cf *cField, ct *cTag) { var typ reflect.Type var kind reflect.Kind @@ -192,7 +193,7 @@ func (v *validate) traverseField(parent reflect.Value, current reflect.Value, ns structNs = append(append(structNs, cf.name...), '.') } - v.validateStruct(current, current, typ, ns, structNs, ct) + v.validateStruct(ctx, current, current, typ, ns, structNs, ct) return } } @@ -261,7 +262,7 @@ OUTER: reusableCF.altName = string(v.misc) } - v.traverseField(parent, current.Index(i), ns, structNs, reusableCF, ct) + v.traverseField(ctx, parent, current.Index(i), ns, structNs, reusableCF, ct) } case reflect.Map: @@ -291,7 +292,7 @@ OUTER: reusableCF.altName = string(v.misc) } - v.traverseField(parent, current.MapIndex(key), ns, structNs, reusableCF, ct) + v.traverseField(ctx, parent, current.MapIndex(key), ns, structNs, reusableCF, ct) } default: @@ -314,7 +315,7 @@ OUTER: v.cf = cf v.ct = ct - if ct.fn(v) { + if ct.fn(ctx, v) { // drain rest of the 'or' values, then continue or leave for { @@ -407,7 +408,7 @@ OUTER: // v.ns = ns // v.actualNs = structNs - if !ct.fn(v) { + if !ct.fn(ctx, v) { v.str1 = string(append(ns, cf.altName...)) diff --git a/validator_instance.go b/validator_instance.go index 5f2cd74..213c8cc 100644 --- a/validator_instance.go +++ b/validator_instance.go @@ -1,6 +1,7 @@ package validator import ( + "context" "errors" "fmt" "reflect" @@ -61,7 +62,7 @@ type Validate struct { structLevelFuncs map[reflect.Type]StructLevelFunc customFuncs map[reflect.Type]CustomTypeFunc aliases map[string]string - validations map[string]Func + validations map[string]FuncCtx transTagFunc map[ut.Translator]map[string]TranslationFunc // map[]map[]TranslationFunc tagCache *tagCache structCache *structCache @@ -79,7 +80,7 @@ func New() *Validate { v := &Validate{ tagName: defaultTagName, aliases: make(map[string]string, len(bakedInAliases)), - validations: make(map[string]Func, len(bakedInValidators)), + validations: make(map[string]FuncCtx, len(bakedInValidators)), tagCache: tc, structCache: sc, } @@ -93,7 +94,7 @@ func New() *Validate { for k, val := range bakedInValidators { // no need to error check here, baked in will alwaays be valid - v.registerValidation(k, val, true) + v.registerValidation(k, wrapFunc(val), true) } v.pool = &sync.Pool{ @@ -128,10 +129,15 @@ func (v *Validate) RegisterTagNameFunc(fn TagNameFunc) { // - if the key already exists, the previous validation function will be replaced. // - this method is not thread-safe it is intended that these all be registered prior to any validation func (v *Validate) RegisterValidation(tag string, fn Func) error { + return v.registerValidation(tag, wrapFunc(fn), false) +} + +// RegisterValidationCtx adds a validation which supports context.Context +func (v *Validate) RegisterValidationCtx(tag string, fn FuncCtx) error { return v.registerValidation(tag, fn, false) } -func (v *Validate) registerValidation(tag string, fn Func, bakedIn bool) error { +func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool) error { if len(tag) == 0 { return errors.New("Function Key cannot be empty") @@ -225,11 +231,8 @@ func (v *Validate) RegisterTranslation(tag string, trans ut.Translator, register return } -// Struct validates a structs exposed fields, and automatically validates nested structs, unless otherwise specified. -// -// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise. -// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors. -func (v *Validate) Struct(s interface{}) (err error) { +// StructCtx go1.8 context supports Struct validation +func (v *Validate) StructCtx(ctx context.Context, s interface{}) (err error) { val := reflect.ValueOf(s) top := val @@ -248,7 +251,7 @@ func (v *Validate) Struct(s interface{}) (err error) { vd.isPartial = false // vd.hasExcludes = false // only need to reset in StructPartial and StructExcept - vd.validateStruct(top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil) + vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil) if len(vd.errs) > 0 { err = vd.errs @@ -260,6 +263,15 @@ func (v *Validate) Struct(s interface{}) (err error) { return } +// Struct validates a structs exposed fields, and automatically validates nested structs, unless otherwise specified. +// +// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise. +// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors. +func (v *Validate) Struct(s interface{}) (err error) { + ctx := context.Background() + return v.StructCtx(ctx, s) +} + // StructFiltered validates a structs exposed fields, that pass the FilterFunc check and automatically validates // nested structs, unless otherwise specified. // @@ -267,6 +279,7 @@ func (v *Validate) Struct(s interface{}) (err error) { // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors. func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) (err error) { + ctx := context.Background() val := reflect.ValueOf(s) top := val @@ -285,7 +298,7 @@ func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) (err error) { vd.ffn = fn // vd.hasExcludes = false // only need to reset in StructPartial and StructExcept - vd.validateStruct(top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil) + vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil) if len(vd.errs) > 0 { err = vd.errs @@ -305,6 +318,7 @@ func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) (err error) { // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors. func (v *Validate) StructPartial(s interface{}, fields ...string) (err error) { + ctx := context.Background() val := reflect.ValueOf(s) top := val @@ -364,7 +378,7 @@ func (v *Validate) StructPartial(s interface{}, fields ...string) (err error) { } } - vd.validateStruct(top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil) + vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil) if len(vd.errs) > 0 { err = vd.errs @@ -384,6 +398,7 @@ func (v *Validate) StructPartial(s interface{}, fields ...string) (err error) { // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors. func (v *Validate) StructExcept(s interface{}, fields ...string) (err error) { + ctx := context.Background() val := reflect.ValueOf(s) top := val @@ -419,7 +434,7 @@ func (v *Validate) StructExcept(s interface{}, fields ...string) (err error) { vd.includeExclude[string(vd.misc)] = struct{}{} } - vd.validateStruct(top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil) + vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil) if len(vd.errs) > 0 { err = vd.errs @@ -445,6 +460,7 @@ func (v *Validate) StructExcept(s interface{}, fields ...string) (err error) { // validate Array, Slice and maps fields which may contain more than one error func (v *Validate) Var(field interface{}, tag string) (err error) { + ctx := context.Background() if len(tag) == 0 || tag == skipValidationTag { return nil } @@ -470,7 +486,7 @@ func (v *Validate) Var(field interface{}, tag string) (err error) { vd.top = val vd.isPartial = false - vd.traverseField(val, val, vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag) + vd.traverseField(ctx, val, val, vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag) if len(vd.errs) > 0 { err = vd.errs @@ -497,6 +513,7 @@ func (v *Validate) Var(field interface{}, tag string) (err error) { // validate Array, Slice and maps fields which may contain more than one error func (v *Validate) VarWithValue(field interface{}, other interface{}, tag string) (err error) { + ctx := context.Background() if len(tag) == 0 || tag == skipValidationTag { return nil } @@ -522,7 +539,7 @@ func (v *Validate) VarWithValue(field interface{}, other interface{}, tag string vd.top = otherVal vd.isPartial = false - vd.traverseField(otherVal, reflect.ValueOf(field), vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag) + vd.traverseField(ctx, otherVal, reflect.ValueOf(field), vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag) if len(vd.errs) > 0 { err = vd.errs diff --git a/validator_test.go b/validator_test.go index c1dcb63..e7936ce 100644 --- a/validator_test.go +++ b/validator_test.go @@ -2,6 +2,7 @@ package validator import ( "bytes" + "context" "database/sql" "database/sql/driver" "encoding/json" @@ -5127,6 +5128,10 @@ func TestAddFunctions(t *testing.T) { return true } + fnCtx := func(ctx context.Context, fl FieldLevel) bool { + return true + } + validate := New() errs := validate.RegisterValidation("new", fn) @@ -5141,6 +5146,9 @@ func TestAddFunctions(t *testing.T) { errs = validate.RegisterValidation("new", fn) Equal(t, errs, nil) + errs = validate.RegisterValidationCtx("new", fnCtx) + Equal(t, errs, nil) + PanicMatches(t, func() { validate.RegisterValidation("dive", fn) }, "Tag 'dive' either contains restricted characters or is the same as a restricted tag needed for normal operation") }