Go 1.8 context supports for struct validation

pull/293/head
Thao Nguyen 7 years ago
parent fb68f39656
commit 75162c2da6
  1. 9
      baked_in.go
  2. 2
      cache.go
  3. 17
      validator.go
  4. 47
      validator_instance.go
  5. 8
      validator_test.go

@ -1,6 +1,7 @@
package validator package validator
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
@ -16,6 +17,14 @@ import (
// fieldType = fields // fieldType = fields
// param = parameter used in validation i.e. gt=0 param would be 0 // param = parameter used in validation i.e. gt=0 param would be 0
type Func func(fl FieldLevel) bool type Func func(fl FieldLevel) bool
type FuncCtx func(ctx context.Context, fl FieldLevel) bool
// wrapFunc make Func compatible with FuncCtx
func wrapFunc(fn Func) FuncCtx {
return func(ctx context.Context, fl FieldLevel) bool {
return fn(fl)
}
}
var ( var (
restrictedTags = map[string]struct{}{ restrictedTags = map[string]struct{}{

@ -90,7 +90,7 @@ type cTag struct {
hasAlias bool hasAlias bool
typeof tagType typeof tagType
hasTag bool hasTag bool
fn Func fn FuncCtx
next *cTag next *cTag
} }

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"context"
) )
// per validate contruct // per validate contruct
@ -34,7 +35,7 @@ type validate struct {
} }
// parent and current will be the same the first run of validateStruct // parent and current will be the same the first run of validateStruct
func (v *validate) validateStruct(parent reflect.Value, current reflect.Value, typ reflect.Type, ns []byte, structNs []byte, ct *cTag) { func (v *validate) validateStruct(ctx context.Context, parent reflect.Value, current reflect.Value, typ reflect.Type, ns []byte, structNs []byte, ct *cTag) {
cs, ok := v.v.structCache.Get(typ) cs, ok := v.v.structCache.Get(typ)
if !ok { if !ok {
@ -78,7 +79,7 @@ func (v *validate) validateStruct(parent reflect.Value, current reflect.Value, t
} }
} }
v.traverseField(parent, current.Field(f.idx), ns, structNs, f, f.cTags) v.traverseField(ctx, parent, current.Field(f.idx), ns, structNs, f, f.cTags)
} }
} }
@ -97,7 +98,7 @@ func (v *validate) validateStruct(parent reflect.Value, current reflect.Value, t
} }
// traverseField validates any field, be it a struct or single field, ensures it's validity and passes it along to be validated via it's tag options // traverseField validates any field, be it a struct or single field, ensures it's validity and passes it along to be validated via it's tag options
func (v *validate) traverseField(parent reflect.Value, current reflect.Value, ns []byte, structNs []byte, cf *cField, ct *cTag) { func (v *validate) traverseField(ctx context.Context, parent reflect.Value, current reflect.Value, ns []byte, structNs []byte, cf *cField, ct *cTag) {
var typ reflect.Type var typ reflect.Type
var kind reflect.Kind var kind reflect.Kind
@ -192,7 +193,7 @@ func (v *validate) traverseField(parent reflect.Value, current reflect.Value, ns
structNs = append(append(structNs, cf.name...), '.') structNs = append(append(structNs, cf.name...), '.')
} }
v.validateStruct(current, current, typ, ns, structNs, ct) v.validateStruct(ctx, current, current, typ, ns, structNs, ct)
return return
} }
} }
@ -261,7 +262,7 @@ OUTER:
reusableCF.altName = string(v.misc) reusableCF.altName = string(v.misc)
} }
v.traverseField(parent, current.Index(i), ns, structNs, reusableCF, ct) v.traverseField(ctx, parent, current.Index(i), ns, structNs, reusableCF, ct)
} }
case reflect.Map: case reflect.Map:
@ -291,7 +292,7 @@ OUTER:
reusableCF.altName = string(v.misc) reusableCF.altName = string(v.misc)
} }
v.traverseField(parent, current.MapIndex(key), ns, structNs, reusableCF, ct) v.traverseField(ctx, parent, current.MapIndex(key), ns, structNs, reusableCF, ct)
} }
default: default:
@ -314,7 +315,7 @@ OUTER:
v.cf = cf v.cf = cf
v.ct = ct v.ct = ct
if ct.fn(v) { if ct.fn(ctx, v) {
// drain rest of the 'or' values, then continue or leave // drain rest of the 'or' values, then continue or leave
for { for {
@ -407,7 +408,7 @@ OUTER:
// v.ns = ns // v.ns = ns
// v.actualNs = structNs // v.actualNs = structNs
if !ct.fn(v) { if !ct.fn(ctx, v) {
v.str1 = string(append(ns, cf.altName...)) v.str1 = string(append(ns, cf.altName...))

@ -1,6 +1,7 @@
package validator package validator
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@ -61,7 +62,7 @@ type Validate struct {
structLevelFuncs map[reflect.Type]StructLevelFunc structLevelFuncs map[reflect.Type]StructLevelFunc
customFuncs map[reflect.Type]CustomTypeFunc customFuncs map[reflect.Type]CustomTypeFunc
aliases map[string]string aliases map[string]string
validations map[string]Func validations map[string]FuncCtx
transTagFunc map[ut.Translator]map[string]TranslationFunc // map[<locale>]map[<tag>]TranslationFunc transTagFunc map[ut.Translator]map[string]TranslationFunc // map[<locale>]map[<tag>]TranslationFunc
tagCache *tagCache tagCache *tagCache
structCache *structCache structCache *structCache
@ -79,7 +80,7 @@ func New() *Validate {
v := &Validate{ v := &Validate{
tagName: defaultTagName, tagName: defaultTagName,
aliases: make(map[string]string, len(bakedInAliases)), aliases: make(map[string]string, len(bakedInAliases)),
validations: make(map[string]Func, len(bakedInValidators)), validations: make(map[string]FuncCtx, len(bakedInValidators)),
tagCache: tc, tagCache: tc,
structCache: sc, structCache: sc,
} }
@ -93,7 +94,7 @@ func New() *Validate {
for k, val := range bakedInValidators { for k, val := range bakedInValidators {
// no need to error check here, baked in will alwaays be valid // no need to error check here, baked in will alwaays be valid
v.registerValidation(k, val, true) v.registerValidation(k, wrapFunc(val), true)
} }
v.pool = &sync.Pool{ v.pool = &sync.Pool{
@ -128,10 +129,15 @@ func (v *Validate) RegisterTagNameFunc(fn TagNameFunc) {
// - if the key already exists, the previous validation function will be replaced. // - if the key already exists, the previous validation function will be replaced.
// - this method is not thread-safe it is intended that these all be registered prior to any validation // - this method is not thread-safe it is intended that these all be registered prior to any validation
func (v *Validate) RegisterValidation(tag string, fn Func) error { func (v *Validate) RegisterValidation(tag string, fn Func) error {
return v.registerValidation(tag, wrapFunc(fn), false)
}
// RegisterValidationCtx adds a validation which supports context.Context
func (v *Validate) RegisterValidationCtx(tag string, fn FuncCtx) error {
return v.registerValidation(tag, fn, false) return v.registerValidation(tag, fn, false)
} }
func (v *Validate) registerValidation(tag string, fn Func, bakedIn bool) error { func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool) error {
if len(tag) == 0 { if len(tag) == 0 {
return errors.New("Function Key cannot be empty") return errors.New("Function Key cannot be empty")
@ -225,11 +231,8 @@ func (v *Validate) RegisterTranslation(tag string, trans ut.Translator, register
return return
} }
// Struct validates a structs exposed fields, and automatically validates nested structs, unless otherwise specified. // StructCtx go1.8 context supports Struct validation
// func (v *Validate) StructCtx(ctx context.Context, s interface{}) (err error) {
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) Struct(s interface{}) (err error) {
val := reflect.ValueOf(s) val := reflect.ValueOf(s)
top := val top := val
@ -248,7 +251,7 @@ func (v *Validate) Struct(s interface{}) (err error) {
vd.isPartial = false vd.isPartial = false
// vd.hasExcludes = false // only need to reset in StructPartial and StructExcept // vd.hasExcludes = false // only need to reset in StructPartial and StructExcept
vd.validateStruct(top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil) vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
if len(vd.errs) > 0 { if len(vd.errs) > 0 {
err = vd.errs err = vd.errs
@ -260,6 +263,15 @@ func (v *Validate) Struct(s interface{}) (err error) {
return return
} }
// Struct validates a structs exposed fields, and automatically validates nested structs, unless otherwise specified.
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) Struct(s interface{}) (err error) {
ctx := context.Background()
return v.StructCtx(ctx, s)
}
// StructFiltered validates a structs exposed fields, that pass the FilterFunc check and automatically validates // StructFiltered validates a structs exposed fields, that pass the FilterFunc check and automatically validates
// nested structs, unless otherwise specified. // nested structs, unless otherwise specified.
// //
@ -267,6 +279,7 @@ func (v *Validate) Struct(s interface{}) (err error) {
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) (err error) { func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) (err error) {
ctx := context.Background()
val := reflect.ValueOf(s) val := reflect.ValueOf(s)
top := val top := val
@ -285,7 +298,7 @@ func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) (err error) {
vd.ffn = fn vd.ffn = fn
// vd.hasExcludes = false // only need to reset in StructPartial and StructExcept // vd.hasExcludes = false // only need to reset in StructPartial and StructExcept
vd.validateStruct(top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil) vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
if len(vd.errs) > 0 { if len(vd.errs) > 0 {
err = vd.errs err = vd.errs
@ -305,6 +318,7 @@ func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) (err error) {
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructPartial(s interface{}, fields ...string) (err error) { func (v *Validate) StructPartial(s interface{}, fields ...string) (err error) {
ctx := context.Background()
val := reflect.ValueOf(s) val := reflect.ValueOf(s)
top := val top := val
@ -364,7 +378,7 @@ func (v *Validate) StructPartial(s interface{}, fields ...string) (err error) {
} }
} }
vd.validateStruct(top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil) vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)
if len(vd.errs) > 0 { if len(vd.errs) > 0 {
err = vd.errs err = vd.errs
@ -384,6 +398,7 @@ func (v *Validate) StructPartial(s interface{}, fields ...string) (err error) {
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructExcept(s interface{}, fields ...string) (err error) { func (v *Validate) StructExcept(s interface{}, fields ...string) (err error) {
ctx := context.Background()
val := reflect.ValueOf(s) val := reflect.ValueOf(s)
top := val top := val
@ -419,7 +434,7 @@ func (v *Validate) StructExcept(s interface{}, fields ...string) (err error) {
vd.includeExclude[string(vd.misc)] = struct{}{} vd.includeExclude[string(vd.misc)] = struct{}{}
} }
vd.validateStruct(top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil) vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)
if len(vd.errs) > 0 { if len(vd.errs) > 0 {
err = vd.errs err = vd.errs
@ -445,6 +460,7 @@ func (v *Validate) StructExcept(s interface{}, fields ...string) (err error) {
// validate Array, Slice and maps fields which may contain more than one error // validate Array, Slice and maps fields which may contain more than one error
func (v *Validate) Var(field interface{}, tag string) (err error) { func (v *Validate) Var(field interface{}, tag string) (err error) {
ctx := context.Background()
if len(tag) == 0 || tag == skipValidationTag { if len(tag) == 0 || tag == skipValidationTag {
return nil return nil
} }
@ -470,7 +486,7 @@ func (v *Validate) Var(field interface{}, tag string) (err error) {
vd.top = val vd.top = val
vd.isPartial = false vd.isPartial = false
vd.traverseField(val, val, vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag) vd.traverseField(ctx, val, val, vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag)
if len(vd.errs) > 0 { if len(vd.errs) > 0 {
err = vd.errs err = vd.errs
@ -497,6 +513,7 @@ func (v *Validate) Var(field interface{}, tag string) (err error) {
// validate Array, Slice and maps fields which may contain more than one error // validate Array, Slice and maps fields which may contain more than one error
func (v *Validate) VarWithValue(field interface{}, other interface{}, tag string) (err error) { func (v *Validate) VarWithValue(field interface{}, other interface{}, tag string) (err error) {
ctx := context.Background()
if len(tag) == 0 || tag == skipValidationTag { if len(tag) == 0 || tag == skipValidationTag {
return nil return nil
} }
@ -522,7 +539,7 @@ func (v *Validate) VarWithValue(field interface{}, other interface{}, tag string
vd.top = otherVal vd.top = otherVal
vd.isPartial = false vd.isPartial = false
vd.traverseField(otherVal, reflect.ValueOf(field), vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag) vd.traverseField(ctx, otherVal, reflect.ValueOf(field), vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag)
if len(vd.errs) > 0 { if len(vd.errs) > 0 {
err = vd.errs err = vd.errs

@ -2,6 +2,7 @@ package validator
import ( import (
"bytes" "bytes"
"context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
@ -5127,6 +5128,10 @@ func TestAddFunctions(t *testing.T) {
return true return true
} }
fnCtx := func(ctx context.Context, fl FieldLevel) bool {
return true
}
validate := New() validate := New()
errs := validate.RegisterValidation("new", fn) errs := validate.RegisterValidation("new", fn)
@ -5141,6 +5146,9 @@ func TestAddFunctions(t *testing.T) {
errs = validate.RegisterValidation("new", fn) errs = validate.RegisterValidation("new", fn)
Equal(t, errs, nil) Equal(t, errs, nil)
errs = validate.RegisterValidationCtx("new", fnCtx)
Equal(t, errs, nil)
PanicMatches(t, func() { validate.RegisterValidation("dive", fn) }, "Tag 'dive' either contains restricted characters or is the same as a restricted tag needed for normal operation") PanicMatches(t, func() { validate.RegisterValidation("dive", fn) }, "Tag 'dive' either contains restricted characters or is the same as a restricted tag needed for normal operation")
} }

Loading…
Cancel
Save