Go 1.8 context supports for struct validation

pull/293/head
Thao Nguyen 7 years ago
parent fb68f39656
commit 75162c2da6
  1. 9
      baked_in.go
  2. 2
      cache.go
  3. 17
      validator.go
  4. 47
      validator_instance.go
  5. 8
      validator_test.go

@ -1,6 +1,7 @@
package validator
import (
"context"
"fmt"
"net"
"net/url"
@ -16,6 +17,14 @@ import (
// fieldType = fields
// param = parameter used in validation i.e. gt=0 param would be 0
type Func func(fl FieldLevel) bool
type FuncCtx func(ctx context.Context, fl FieldLevel) bool
// wrapFunc make Func compatible with FuncCtx
func wrapFunc(fn Func) FuncCtx {
return func(ctx context.Context, fl FieldLevel) bool {
return fn(fl)
}
}
var (
restrictedTags = map[string]struct{}{

@ -90,7 +90,7 @@ type cTag struct {
hasAlias bool
typeof tagType
hasTag bool
fn Func
fn FuncCtx
next *cTag
}

@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
"strconv"
"context"
)
// per validate contruct
@ -34,7 +35,7 @@ type validate struct {
}
// parent and current will be the same the first run of validateStruct
func (v *validate) validateStruct(parent reflect.Value, current reflect.Value, typ reflect.Type, ns []byte, structNs []byte, ct *cTag) {
func (v *validate) validateStruct(ctx context.Context, parent reflect.Value, current reflect.Value, typ reflect.Type, ns []byte, structNs []byte, ct *cTag) {
cs, ok := v.v.structCache.Get(typ)
if !ok {
@ -78,7 +79,7 @@ func (v *validate) validateStruct(parent reflect.Value, current reflect.Value, t
}
}
v.traverseField(parent, current.Field(f.idx), ns, structNs, f, f.cTags)
v.traverseField(ctx, parent, current.Field(f.idx), ns, structNs, f, f.cTags)
}
}
@ -97,7 +98,7 @@ func (v *validate) validateStruct(parent reflect.Value, current reflect.Value, t
}
// traverseField validates any field, be it a struct or single field, ensures it's validity and passes it along to be validated via it's tag options
func (v *validate) traverseField(parent reflect.Value, current reflect.Value, ns []byte, structNs []byte, cf *cField, ct *cTag) {
func (v *validate) traverseField(ctx context.Context, parent reflect.Value, current reflect.Value, ns []byte, structNs []byte, cf *cField, ct *cTag) {
var typ reflect.Type
var kind reflect.Kind
@ -192,7 +193,7 @@ func (v *validate) traverseField(parent reflect.Value, current reflect.Value, ns
structNs = append(append(structNs, cf.name...), '.')
}
v.validateStruct(current, current, typ, ns, structNs, ct)
v.validateStruct(ctx, current, current, typ, ns, structNs, ct)
return
}
}
@ -261,7 +262,7 @@ OUTER:
reusableCF.altName = string(v.misc)
}
v.traverseField(parent, current.Index(i), ns, structNs, reusableCF, ct)
v.traverseField(ctx, parent, current.Index(i), ns, structNs, reusableCF, ct)
}
case reflect.Map:
@ -291,7 +292,7 @@ OUTER:
reusableCF.altName = string(v.misc)
}
v.traverseField(parent, current.MapIndex(key), ns, structNs, reusableCF, ct)
v.traverseField(ctx, parent, current.MapIndex(key), ns, structNs, reusableCF, ct)
}
default:
@ -314,7 +315,7 @@ OUTER:
v.cf = cf
v.ct = ct
if ct.fn(v) {
if ct.fn(ctx, v) {
// drain rest of the 'or' values, then continue or leave
for {
@ -407,7 +408,7 @@ OUTER:
// v.ns = ns
// v.actualNs = structNs
if !ct.fn(v) {
if !ct.fn(ctx, v) {
v.str1 = string(append(ns, cf.altName...))

@ -1,6 +1,7 @@
package validator
import (
"context"
"errors"
"fmt"
"reflect"
@ -61,7 +62,7 @@ type Validate struct {
structLevelFuncs map[reflect.Type]StructLevelFunc
customFuncs map[reflect.Type]CustomTypeFunc
aliases map[string]string
validations map[string]Func
validations map[string]FuncCtx
transTagFunc map[ut.Translator]map[string]TranslationFunc // map[<locale>]map[<tag>]TranslationFunc
tagCache *tagCache
structCache *structCache
@ -79,7 +80,7 @@ func New() *Validate {
v := &Validate{
tagName: defaultTagName,
aliases: make(map[string]string, len(bakedInAliases)),
validations: make(map[string]Func, len(bakedInValidators)),
validations: make(map[string]FuncCtx, len(bakedInValidators)),
tagCache: tc,
structCache: sc,
}
@ -93,7 +94,7 @@ func New() *Validate {
for k, val := range bakedInValidators {
// no need to error check here, baked in will alwaays be valid
v.registerValidation(k, val, true)
v.registerValidation(k, wrapFunc(val), true)
}
v.pool = &sync.Pool{
@ -128,10 +129,15 @@ func (v *Validate) RegisterTagNameFunc(fn TagNameFunc) {
// - if the key already exists, the previous validation function will be replaced.
// - this method is not thread-safe it is intended that these all be registered prior to any validation
func (v *Validate) RegisterValidation(tag string, fn Func) error {
return v.registerValidation(tag, wrapFunc(fn), false)
}
// RegisterValidationCtx adds a validation which supports context.Context
func (v *Validate) RegisterValidationCtx(tag string, fn FuncCtx) error {
return v.registerValidation(tag, fn, false)
}
func (v *Validate) registerValidation(tag string, fn Func, bakedIn bool) error {
func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool) error {
if len(tag) == 0 {
return errors.New("Function Key cannot be empty")
@ -225,11 +231,8 @@ func (v *Validate) RegisterTranslation(tag string, trans ut.Translator, register
return
}
// Struct validates a structs exposed fields, and automatically validates nested structs, unless otherwise specified.
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) Struct(s interface{}) (err error) {
// StructCtx go1.8 context supports Struct validation
func (v *Validate) StructCtx(ctx context.Context, s interface{}) (err error) {
val := reflect.ValueOf(s)
top := val
@ -248,7 +251,7 @@ func (v *Validate) Struct(s interface{}) (err error) {
vd.isPartial = false
// vd.hasExcludes = false // only need to reset in StructPartial and StructExcept
vd.validateStruct(top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
if len(vd.errs) > 0 {
err = vd.errs
@ -260,6 +263,15 @@ func (v *Validate) Struct(s interface{}) (err error) {
return
}
// Struct validates a structs exposed fields, and automatically validates nested structs, unless otherwise specified.
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) Struct(s interface{}) (err error) {
ctx := context.Background()
return v.StructCtx(ctx, s)
}
// StructFiltered validates a structs exposed fields, that pass the FilterFunc check and automatically validates
// nested structs, unless otherwise specified.
//
@ -267,6 +279,7 @@ func (v *Validate) Struct(s interface{}) (err error) {
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) (err error) {
ctx := context.Background()
val := reflect.ValueOf(s)
top := val
@ -285,7 +298,7 @@ func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) (err error) {
vd.ffn = fn
// vd.hasExcludes = false // only need to reset in StructPartial and StructExcept
vd.validateStruct(top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
if len(vd.errs) > 0 {
err = vd.errs
@ -305,6 +318,7 @@ func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) (err error) {
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructPartial(s interface{}, fields ...string) (err error) {
ctx := context.Background()
val := reflect.ValueOf(s)
top := val
@ -364,7 +378,7 @@ func (v *Validate) StructPartial(s interface{}, fields ...string) (err error) {
}
}
vd.validateStruct(top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)
vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)
if len(vd.errs) > 0 {
err = vd.errs
@ -384,6 +398,7 @@ func (v *Validate) StructPartial(s interface{}, fields ...string) (err error) {
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructExcept(s interface{}, fields ...string) (err error) {
ctx := context.Background()
val := reflect.ValueOf(s)
top := val
@ -419,7 +434,7 @@ func (v *Validate) StructExcept(s interface{}, fields ...string) (err error) {
vd.includeExclude[string(vd.misc)] = struct{}{}
}
vd.validateStruct(top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)
vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)
if len(vd.errs) > 0 {
err = vd.errs
@ -445,6 +460,7 @@ func (v *Validate) StructExcept(s interface{}, fields ...string) (err error) {
// validate Array, Slice and maps fields which may contain more than one error
func (v *Validate) Var(field interface{}, tag string) (err error) {
ctx := context.Background()
if len(tag) == 0 || tag == skipValidationTag {
return nil
}
@ -470,7 +486,7 @@ func (v *Validate) Var(field interface{}, tag string) (err error) {
vd.top = val
vd.isPartial = false
vd.traverseField(val, val, vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag)
vd.traverseField(ctx, val, val, vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag)
if len(vd.errs) > 0 {
err = vd.errs
@ -497,6 +513,7 @@ func (v *Validate) Var(field interface{}, tag string) (err error) {
// validate Array, Slice and maps fields which may contain more than one error
func (v *Validate) VarWithValue(field interface{}, other interface{}, tag string) (err error) {
ctx := context.Background()
if len(tag) == 0 || tag == skipValidationTag {
return nil
}
@ -522,7 +539,7 @@ func (v *Validate) VarWithValue(field interface{}, other interface{}, tag string
vd.top = otherVal
vd.isPartial = false
vd.traverseField(otherVal, reflect.ValueOf(field), vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag)
vd.traverseField(ctx, otherVal, reflect.ValueOf(field), vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag)
if len(vd.errs) > 0 {
err = vd.errs

@ -2,6 +2,7 @@ package validator
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
@ -5127,6 +5128,10 @@ func TestAddFunctions(t *testing.T) {
return true
}
fnCtx := func(ctx context.Context, fl FieldLevel) bool {
return true
}
validate := New()
errs := validate.RegisterValidation("new", fn)
@ -5141,6 +5146,9 @@ func TestAddFunctions(t *testing.T) {
errs = validate.RegisterValidation("new", fn)
Equal(t, errs, nil)
errs = validate.RegisterValidationCtx("new", fnCtx)
Equal(t, errs, nil)
PanicMatches(t, func() { validate.RegisterValidation("dive", fn) }, "Tag 'dive' either contains restricted characters or is the same as a restricted tag needed for normal operation")
}

Loading…
Cancel
Save