You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
581 lines
16 KiB
581 lines
16 KiB
/**
 * Package validator
 *
 * MISC:
 * - anonymous structs - they don't have names so expect the Struct name within StructErrors to be blank
 *
 */
|
|
|
|
package validator
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
"unicode"
|
|
)
|
|
|
|
const (
|
|
utf8HexComma = "0x2C"
|
|
utf8Pipe = "0x7C"
|
|
tagSeparator = ","
|
|
orSeparator = "|"
|
|
tagKeySeparator = "="
|
|
structOnlyTag = "structonly"
|
|
omitempty = "omitempty"
|
|
skipValidationTag = "-"
|
|
diveTag = "dive"
|
|
existsTag = "exists"
|
|
fieldErrMsg = "Key: \"%s\" Error:Field validation for \"%s\" failed on the \"%s\" tag"
|
|
arrayIndexFieldName = "%s" + leftBracket + "%d" + rightBracket
|
|
mapIndexFieldName = "%s" + leftBracket + "%v" + rightBracket
|
|
invalidValidation = "Invalid validation tag on field %s"
|
|
undefinedValidation = "Undefined validation function on field %s"
|
|
)
|
|
|
|
var (
|
|
timeType = reflect.TypeOf(time.Time{})
|
|
timePtrType = reflect.TypeOf(&time.Time{})
|
|
errsPool = &sync.Pool{New: newValidationErrors}
|
|
tagsCache = &tagCacheMap{m: map[string][]*tagCache{}}
|
|
emptyStructPtr = new(struct{})
|
|
)
|
|
|
|
// returns new ValidationErrors to the pool
|
|
func newValidationErrors() interface{} {
|
|
return ValidationErrors{}
|
|
}
|
|
|
|
// tagCache is the parsed form of one comma-separated segment of a
// validation tag.
type tagCache struct {
	// tagVals has one entry per "|"-separated alternative in the segment.
	// Each entry is a {key, param} pair, except the dive marker which is
	// stored with the key only.
	tagVals [][]string

	// isOrVal reports whether the segment contained more than one
	// alternative.
	isOrVal bool
}
|
|
|
|
type tagCacheMap struct {
|
|
lock sync.RWMutex
|
|
m map[string][]*tagCache
|
|
}
|
|
|
|
func (s *tagCacheMap) Get(key string) ([]*tagCache, bool) {
|
|
s.lock.RLock()
|
|
defer s.lock.RUnlock()
|
|
value, ok := s.m[key]
|
|
return value, ok
|
|
}
|
|
|
|
func (s *tagCacheMap) Set(key string, value []*tagCache) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
s.m[key] = value
|
|
}
|
|
|
|
// Validate contains the validator settings passed in using the Config struct
|
|
type Validate struct {
|
|
config Config
|
|
}
|
|
|
|
// Config contains the options that a Validator instance will use.
|
|
// It is passed to the New() function
|
|
type Config struct {
|
|
TagName string
|
|
ValidationFuncs map[string]Func
|
|
CustomTypeFuncs map[reflect.Type]CustomTypeFunc
|
|
hasCustomFuncs bool
|
|
}
|
|
|
|
// CustomTypeFunc is a handler for a custom field type: it receives the
// field value and returns the value that should be validated in its place.
// It can override an existing type handler or add a new one.
// A typical use is unwrapping a sql driver Valuer, see
// https://golang.org/src/database/sql/driver/types.go?s=1210:1293#L29
type CustomTypeFunc func(field reflect.Value) interface{}
|
|
|
|
// Func accepts all values needed for file and cross field validation
|
|
// topStruct = top level struct when validating by struct otherwise nil
|
|
// currentStruct = current level struct when validating by struct otherwise optional comparison value
|
|
// field = field value for validation
|
|
// param = parameter used in validation i.e. gt=0 param would be 0
|
|
type Func func(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, field reflect.Value, fieldtype reflect.Type, fieldKind reflect.Kind, param string) bool
|
|
|
|
// ValidationErrors is a type of map[string]*FieldError
|
|
// it exists to allow for multiple errors to be passed from this library
|
|
// and yet still subscribe to the error interface
|
|
type ValidationErrors map[string]*FieldError
|
|
|
|
// Error is intended for use in development + debugging and not intended to be a production error message.
|
|
// It allows ValidationErrors to subscribe to the Error interface.
|
|
// All information to create an error message specific to your application is contained within
|
|
// the FieldError found within the ValidationErrors map
|
|
func (ve ValidationErrors) Error() string {
|
|
|
|
buff := bytes.NewBufferString("")
|
|
|
|
for key, err := range ve {
|
|
buff.WriteString(fmt.Sprintf(fieldErrMsg, key, err.Field, err.Tag))
|
|
buff.WriteString("\n")
|
|
}
|
|
|
|
return strings.TrimSpace(buff.String())
|
|
}
|
|
|
|
// FieldError describes the validation failure of a single field, together
// with the details a caller may need to build its own error message.
type FieldError struct {
	// Field is the name of the field that failed.
	Field string

	// Tag is the key of the validation tag that failed.
	Tag string

	// Kind and Type are the reflect kind and type of the validated value.
	Kind reflect.Kind
	Type reflect.Type

	// Param is the parameter that was given to the failing tag, if any.
	Param string

	// Value is the value that failed validation.
	Value interface{}
}
|
|
|
|
// New creates a new Validate instance for use.
|
|
func New(config Config) *Validate {
|
|
|
|
if config.CustomTypeFuncs != nil && len(config.CustomTypeFuncs) > 0 {
|
|
config.hasCustomFuncs = true
|
|
}
|
|
|
|
return &Validate{config: config}
|
|
}
|
|
|
|
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
|
|
// NOTE: if the key already exists, the previous validation function will be replaced.
|
|
// NOTE: this method is not thread-safe
|
|
func (v *Validate) RegisterValidation(key string, f Func) error {
|
|
|
|
if len(key) == 0 {
|
|
return errors.New("Function Key cannot be empty")
|
|
}
|
|
|
|
if f == nil {
|
|
return errors.New("Function cannot be empty")
|
|
}
|
|
|
|
v.config.ValidationFuncs[key] = f
|
|
|
|
return nil
|
|
}
|
|
|
|
// RegisterCustomTypeFunc registers a CustomTypeFunc against a number of types
|
|
func (v *Validate) RegisterCustomTypeFunc(fn CustomTypeFunc, types ...interface{}) {
|
|
|
|
if v.config.CustomTypeFuncs == nil {
|
|
v.config.CustomTypeFuncs = map[reflect.Type]CustomTypeFunc{}
|
|
}
|
|
|
|
for _, t := range types {
|
|
v.config.CustomTypeFuncs[reflect.TypeOf(t)] = fn
|
|
}
|
|
|
|
v.config.hasCustomFuncs = true
|
|
}
|
|
|
|
// Field validates a single field using tag style validation and returns ValidationErrors
|
|
// NOTE: it returns ValidationErrors instead of a single FieldError because this can also
|
|
// validate Array, Slice and maps fields which may contain more than one error
|
|
func (v *Validate) Field(field interface{}, tag string) ValidationErrors {
|
|
|
|
errs := errsPool.Get().(ValidationErrors)
|
|
fieldVal := reflect.ValueOf(field)
|
|
|
|
v.traverseField(fieldVal, fieldVal, fieldVal, "", errs, false, tag, "", false, false, nil)
|
|
|
|
if len(errs) == 0 {
|
|
errsPool.Put(errs)
|
|
return nil
|
|
}
|
|
|
|
return errs
|
|
}
|
|
|
|
// FieldWithValue validates a single field, against another fields value using tag style validation and returns ValidationErrors
|
|
// NOTE: it returns ValidationErrors instead of a single FieldError because this can also
|
|
// validate Array, Slice and maps fields which may contain more than one error
|
|
func (v *Validate) FieldWithValue(val interface{}, field interface{}, tag string) ValidationErrors {
|
|
|
|
errs := errsPool.Get().(ValidationErrors)
|
|
topVal := reflect.ValueOf(val)
|
|
|
|
v.traverseField(topVal, topVal, reflect.ValueOf(field), "", errs, false, tag, "", false, false, nil)
|
|
|
|
if len(errs) == 0 {
|
|
errsPool.Put(errs)
|
|
return nil
|
|
}
|
|
|
|
return errs
|
|
}
|
|
|
|
// StructPartial validates the fields that are listed by name in the map including nested structs, unless otherwise specified. Items in the map that are NOT found in the struct will cause a panic.
|
|
func (v *Validate) StructPartial(current interface{}, fields ...string) ValidationErrors {
|
|
|
|
sv, _ := v.extractType(reflect.ValueOf(current))
|
|
name := sv.Type().Name()
|
|
m := map[string]*struct{}{}
|
|
|
|
var i int
|
|
|
|
if fields != nil {
|
|
for _, k := range fields {
|
|
|
|
flds := strings.Split(k, ".")
|
|
if len(flds) > 0 {
|
|
|
|
key := name
|
|
for _, s := range flds {
|
|
|
|
idx := strings.Index(s, "[")
|
|
|
|
if idx != -1 {
|
|
for idx != -1 {
|
|
i++
|
|
key += s[:idx]
|
|
m[key] = emptyStructPtr
|
|
|
|
idx2 := strings.Index(s, "]")
|
|
idx2++
|
|
key += s[idx:idx2]
|
|
m[key] = emptyStructPtr
|
|
s = s[idx2:]
|
|
idx = strings.Index(s, "[")
|
|
|
|
if i == 10 {
|
|
idx = -1
|
|
}
|
|
}
|
|
} else {
|
|
|
|
key += s
|
|
m[key] = emptyStructPtr
|
|
}
|
|
|
|
key += "."
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
errs := errsPool.Get().(ValidationErrors)
|
|
|
|
v.tranverseStruct(sv, sv, sv, "", errs, true, len(m) != 0, false, m)
|
|
|
|
if len(errs) == 0 {
|
|
errsPool.Put(errs)
|
|
return nil
|
|
}
|
|
|
|
return errs
|
|
}
|
|
|
|
// StructExcept validates the fields in the struct that are NOT listed by name in the map including nested structs, unless otherwise specified. Items in the map that are NOT found in the struct will cause a panic.
|
|
func (v *Validate) StructExcept(current interface{}, fields ...string) ValidationErrors {
|
|
|
|
sv, _ := v.extractType(reflect.ValueOf(current))
|
|
name := sv.Type().Name()
|
|
m := map[string]*struct{}{}
|
|
|
|
for _, key := range fields {
|
|
m[name+"."+key] = emptyStructPtr
|
|
}
|
|
|
|
errs := errsPool.Get().(ValidationErrors)
|
|
|
|
v.tranverseStruct(sv, sv, sv, "", errs, true, len(m) != 0, true, m)
|
|
|
|
if len(errs) == 0 {
|
|
errsPool.Put(errs)
|
|
return nil
|
|
}
|
|
|
|
return errs
|
|
}
|
|
|
|
// Struct validates a structs exposed fields, and automatically validates nested structs, unless otherwise specified.
|
|
func (v *Validate) Struct(current interface{}) ValidationErrors {
|
|
|
|
errs := errsPool.Get().(ValidationErrors)
|
|
sv := reflect.ValueOf(current)
|
|
|
|
v.tranverseStruct(sv, sv, sv, "", errs, true, false, false, nil)
|
|
|
|
if len(errs) == 0 {
|
|
errsPool.Put(errs)
|
|
return nil
|
|
}
|
|
|
|
return errs
|
|
}
|
|
|
|
// tranverseStruct traverses a structs fields and then passes them to be validated by traverseField
|
|
func (v *Validate) tranverseStruct(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, useStructName bool, partial bool, exclude bool, includeExclude map[string]*struct{}) {
|
|
|
|
if current.Kind() == reflect.Ptr && !current.IsNil() {
|
|
current = current.Elem()
|
|
}
|
|
|
|
if current.Kind() != reflect.Struct && current.Kind() != reflect.Interface {
|
|
panic("value passed for validation is not a struct")
|
|
}
|
|
|
|
var ok bool
|
|
typ := current.Type()
|
|
|
|
if useStructName {
|
|
errPrefix += typ.Name() + "."
|
|
}
|
|
|
|
numFields := current.NumField()
|
|
|
|
var fld reflect.StructField
|
|
|
|
for i := 0; i < numFields; i++ {
|
|
fld = typ.Field(i)
|
|
|
|
if !unicode.IsUpper(rune(fld.Name[0])) {
|
|
continue
|
|
}
|
|
|
|
if partial {
|
|
|
|
_, ok = includeExclude[errPrefix+fld.Name]
|
|
|
|
if (ok && exclude) || (!ok && !exclude) {
|
|
continue
|
|
}
|
|
}
|
|
|
|
v.traverseField(topStruct, currentStruct, current.Field(i), errPrefix, errs, true, fld.Tag.Get(v.config.TagName), fld.Name, partial, exclude, includeExclude)
|
|
}
|
|
}
|
|
|
|
// traverseField validates any field, be it a struct or single field, ensures it's validity and passes it along to be validated via it's tag options
|
|
func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, isStructField bool, tag string, name string, partial bool, exclude bool, includeExclude map[string]*struct{}) {
|
|
|
|
if tag == skipValidationTag {
|
|
return
|
|
}
|
|
|
|
current, kind := v.extractType(current)
|
|
var typ reflect.Type
|
|
|
|
switch kind {
|
|
case reflect.Ptr, reflect.Interface, reflect.Invalid:
|
|
if strings.Contains(tag, omitempty) {
|
|
return
|
|
}
|
|
|
|
if len(tag) > 0 {
|
|
|
|
tags := strings.Split(tag, tagSeparator)
|
|
var param string
|
|
vals := strings.SplitN(tags[0], tagKeySeparator, 2)
|
|
|
|
if len(vals) > 1 {
|
|
param = vals[1]
|
|
}
|
|
|
|
if kind == reflect.Invalid {
|
|
errs[errPrefix+name] = &FieldError{
|
|
Field: name,
|
|
Tag: vals[0],
|
|
Param: param,
|
|
Kind: kind,
|
|
}
|
|
return
|
|
}
|
|
|
|
errs[errPrefix+name] = &FieldError{
|
|
Field: name,
|
|
Tag: vals[0],
|
|
Param: param,
|
|
Value: current.Interface(),
|
|
Kind: kind,
|
|
Type: current.Type(),
|
|
}
|
|
|
|
return
|
|
}
|
|
// if we get here tag length is zero and we can leave
|
|
if kind == reflect.Invalid {
|
|
return
|
|
}
|
|
|
|
case reflect.Struct:
|
|
typ = current.Type()
|
|
|
|
if typ != timeType {
|
|
|
|
// required passed validation above so stop here
|
|
// if only validating the structs existance.
|
|
if strings.Contains(tag, structOnlyTag) {
|
|
return
|
|
}
|
|
|
|
v.tranverseStruct(topStruct, current, current, errPrefix+name+".", errs, false, partial, exclude, includeExclude)
|
|
return
|
|
}
|
|
}
|
|
|
|
if len(tag) == 0 {
|
|
return
|
|
}
|
|
|
|
typ = current.Type()
|
|
|
|
tags, isCached := tagsCache.Get(tag)
|
|
|
|
if !isCached {
|
|
|
|
tags = []*tagCache{}
|
|
|
|
for _, t := range strings.Split(tag, tagSeparator) {
|
|
|
|
if t == diveTag {
|
|
tags = append(tags, &tagCache{tagVals: [][]string{{t}}})
|
|
break
|
|
}
|
|
|
|
// if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C"
|
|
orVals := strings.Split(t, orSeparator)
|
|
cTag := &tagCache{isOrVal: len(orVals) > 1, tagVals: make([][]string, len(orVals))}
|
|
tags = append(tags, cTag)
|
|
|
|
var key string
|
|
var param string
|
|
|
|
for i, val := range orVals {
|
|
vals := strings.SplitN(val, tagKeySeparator, 2)
|
|
key = vals[0]
|
|
|
|
if len(key) == 0 {
|
|
panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, name)))
|
|
}
|
|
|
|
if len(vals) > 1 {
|
|
param = strings.Replace(strings.Replace(vals[1], utf8HexComma, ",", -1), utf8Pipe, "|", -1)
|
|
}
|
|
|
|
cTag.tagVals[i] = []string{key, param}
|
|
}
|
|
}
|
|
tagsCache.Set(tag, tags)
|
|
}
|
|
|
|
var dive bool
|
|
var diveSubTag string
|
|
|
|
for _, cTag := range tags {
|
|
|
|
if cTag.tagVals[0][0] == existsTag {
|
|
continue
|
|
}
|
|
|
|
if cTag.tagVals[0][0] == diveTag {
|
|
dive = true
|
|
diveSubTag = strings.TrimLeft(strings.SplitN(tag, diveTag, 2)[1], ",")
|
|
break
|
|
}
|
|
|
|
if cTag.tagVals[0][0] == omitempty {
|
|
|
|
if !hasValue(v, topStruct, currentStruct, current, typ, kind, "") {
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
|
|
if v.validateField(topStruct, currentStruct, current, typ, kind, errPrefix, errs, cTag, name) {
|
|
return
|
|
}
|
|
}
|
|
|
|
if dive {
|
|
// traverse slice or map here
|
|
// or panic ;)
|
|
switch kind {
|
|
case reflect.Slice, reflect.Array:
|
|
v.traverseSlice(topStruct, currentStruct, current, errPrefix, errs, diveSubTag, name, partial, exclude, includeExclude)
|
|
case reflect.Map:
|
|
v.traverseMap(topStruct, currentStruct, current, errPrefix, errs, diveSubTag, name, partial, exclude, includeExclude)
|
|
default:
|
|
// throw error, if not a slice or map then should not have gotten here
|
|
// bad dive tag
|
|
panic("dive error! can't dive on a non slice or map")
|
|
}
|
|
}
|
|
}
|
|
|
|
// traverseSlice traverses a Slice or Array's elements and passes them to traverseField for validation
|
|
func (v *Validate) traverseSlice(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, tag string, name string, partial bool, exclude bool, includeExclude map[string]*struct{}) {
|
|
|
|
for i := 0; i < current.Len(); i++ {
|
|
v.traverseField(topStruct, currentStruct, current.Index(i), errPrefix, errs, false, tag, fmt.Sprintf(arrayIndexFieldName, name, i), partial, exclude, includeExclude)
|
|
}
|
|
}
|
|
|
|
// traverseMap traverses a map's elements and passes them to traverseField for validation
|
|
func (v *Validate) traverseMap(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, tag string, name string, partial bool, exclude bool, includeExclude map[string]*struct{}) {
|
|
|
|
for _, key := range current.MapKeys() {
|
|
v.traverseField(topStruct, currentStruct, current.MapIndex(key), errPrefix, errs, false, tag, fmt.Sprintf(mapIndexFieldName, name, key.Interface()), partial, exclude, includeExclude)
|
|
}
|
|
}
|
|
|
|
// validateField validates a field based on the provided tag's key and param values and returns true if there is an error or false if all ok
|
|
func (v *Validate) validateField(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, currentType reflect.Type, currentKind reflect.Kind, errPrefix string, errs ValidationErrors, cTag *tagCache, name string) bool {
|
|
|
|
var valFunc Func
|
|
var ok bool
|
|
|
|
if cTag.isOrVal {
|
|
|
|
errTag := ""
|
|
|
|
for _, val := range cTag.tagVals {
|
|
|
|
valFunc, ok = v.config.ValidationFuncs[val[0]]
|
|
if !ok {
|
|
panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, name)))
|
|
}
|
|
|
|
if valFunc(v, topStruct, currentStruct, current, currentType, currentKind, val[1]) {
|
|
return false
|
|
}
|
|
|
|
errTag += orSeparator + val[0]
|
|
}
|
|
|
|
errs[errPrefix+name] = &FieldError{
|
|
Field: name,
|
|
Tag: errTag[1:],
|
|
Value: current.Interface(),
|
|
Type: currentType,
|
|
Kind: currentKind,
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
valFunc, ok = v.config.ValidationFuncs[cTag.tagVals[0][0]]
|
|
if !ok {
|
|
panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, name)))
|
|
}
|
|
|
|
if valFunc(v, topStruct, currentStruct, current, currentType, currentKind, cTag.tagVals[0][1]) {
|
|
return false
|
|
}
|
|
|
|
errs[errPrefix+name] = &FieldError{
|
|
Field: name,
|
|
Tag: cTag.tagVals[0][0],
|
|
Value: current.Interface(),
|
|
Param: cTag.tagVals[0][1],
|
|
Type: currentType,
|
|
Kind: currentKind,
|
|
}
|
|
|
|
return true
|
|
}
|
|
|