💯Go Struct and Field validation, including Cross Field, Cross Struct, Map, Slice and Array diving
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
validator/validator.go

575 lines
16 KiB

/**
* Package validator
*
* MISC:
* - anonymous structs - they don't have names so expect the Struct name within StructErrors to be blank
*
*/
package validator
import (
"bytes"
"errors"
"fmt"
"reflect"
"strings"
"sync"
"time"
"unicode"
)
const (
utf8HexComma = "0x2C"
utf8Pipe = "0x7C"
tagSeparator = ","
orSeparator = "|"
tagKeySeparator = "="
structOnlyTag = "structonly"
omitempty = "omitempty"
skipValidationTag = "-"
diveTag = "dive"
existsTag = "exists"
fieldErrMsg = "Key: \"%s\" Error:Field validation for \"%s\" failed on the \"%s\" tag"
arrayIndexFieldName = "%s" + leftBracket + "%d" + rightBracket
mapIndexFieldName = "%s" + leftBracket + "%v" + rightBracket
invalidValidation = "Invalid validation tag on field %s"
undefinedValidation = "Undefined validation function on field %s"
)
var (
timeType = reflect.TypeOf(time.Time{})
timePtrType = reflect.TypeOf(&time.Time{})
errsPool = &sync.Pool{New: newValidationErrors}
tagsCache = &tagCacheMap{m: map[string][]*tagCache{}}
emptyStructPtr = new(struct{})
)
// returns new ValidationErrors to the pool
func newValidationErrors() interface{} {
return ValidationErrors{}
}
type tagCache struct {
tagVals [][]string
isOrVal bool
}
type tagCacheMap struct {
lock sync.RWMutex
m map[string][]*tagCache
}
func (s *tagCacheMap) Get(key string) ([]*tagCache, bool) {
s.lock.RLock()
defer s.lock.RUnlock()
value, ok := s.m[key]
return value, ok
}
func (s *tagCacheMap) Set(key string, value []*tagCache) {
s.lock.Lock()
defer s.lock.Unlock()
s.m[key] = value
}
// Validate contains the validator settings passed in using the Config struct
type Validate struct {
config Config
}
// Config contains the options that a Validator instance will use.
// It is passed to the New() function
type Config struct {
TagName string
ValidationFuncs map[string]Func
CustomTypeFuncs map[reflect.Type]CustomTypeFunc
hasCustomFuncs bool
}
// CustomTypeFunc allows for overriding or adding custom field type handler functions
// field = field value of the type to return a value to be validated
// example Valuer from sql drive see https://golang.org/src/database/sql/driver/types.go?s=1210:1293#L29
type CustomTypeFunc func(field reflect.Value) interface{}
// Func accepts all values needed for file and cross field validation
// topStruct = top level struct when validating by struct otherwise nil
// currentStruct = current level struct when validating by struct otherwise optional comparison value
// field = field value for validation
// param = parameter used in validation i.e. gt=0 param would be 0
type Func func(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, field reflect.Value, fieldtype reflect.Type, fieldKind reflect.Kind, param string) bool
// ValidationErrors is a type of map[string]*FieldError
// it exists to allow for multiple errors to be passed from this library
// and yet still subscribe to the error interface
type ValidationErrors map[string]*FieldError
// Error is intended for use in development + debugging and not intended to be a production error message.
// It allows ValidationErrors to subscribe to the Error interface.
// All information to create an error message specific to your application is contained within
// the FieldError found within the ValidationErrors map
func (ve ValidationErrors) Error() string {
buff := bytes.NewBufferString("")
for key, err := range ve {
buff.WriteString(fmt.Sprintf(fieldErrMsg, key, err.Field, err.Tag))
buff.WriteString("\n")
}
return strings.TrimSpace(buff.String())
}
// FieldError contains a single field's validation error along
// with other properties that may be needed for error message creation
type FieldError struct {
Field string
Tag string
Kind reflect.Kind
Type reflect.Type
Param string
Value interface{}
}
// New creates a new Validate instance for use.
func New(config Config) *Validate {
if config.CustomTypeFuncs != nil && len(config.CustomTypeFuncs) > 0 {
config.hasCustomFuncs = true
}
return &Validate{config: config}
}
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
// NOTE: if the key already exists, the previous validation function will be replaced.
// NOTE: this method is not thread-safe
func (v *Validate) RegisterValidation(key string, f Func) error {
if len(key) == 0 {
return errors.New("Function Key cannot be empty")
}
if f == nil {
return errors.New("Function cannot be empty")
}
v.config.ValidationFuncs[key] = f
return nil
}
// RegisterCustomTypeFunc registers a CustomTypeFunc against a number of types
func (v *Validate) RegisterCustomTypeFunc(fn CustomTypeFunc, types ...interface{}) {
if v.config.CustomTypeFuncs == nil {
v.config.CustomTypeFuncs = map[reflect.Type]CustomTypeFunc{}
}
for _, t := range types {
v.config.CustomTypeFuncs[reflect.TypeOf(t)] = fn
}
v.config.hasCustomFuncs = true
}
// Field validates a single field using tag style validation and returns ValidationErrors
// NOTE: it returns ValidationErrors instead of a single FieldError because this can also
// validate Array, Slice and maps fields which may contain more than one error
func (v *Validate) Field(field interface{}, tag string) ValidationErrors {
errs := errsPool.Get().(ValidationErrors)
fieldVal := reflect.ValueOf(field)
v.traverseField(fieldVal, fieldVal, fieldVal, "", errs, false, tag, "", false, false, nil)
if len(errs) == 0 {
errsPool.Put(errs)
return nil
}
return errs
}
// FieldWithValue validates a single field, against another fields value using tag style validation and returns ValidationErrors
// NOTE: it returns ValidationErrors instead of a single FieldError because this can also
// validate Array, Slice and maps fields which may contain more than one error
func (v *Validate) FieldWithValue(val interface{}, field interface{}, tag string) ValidationErrors {
errs := errsPool.Get().(ValidationErrors)
topVal := reflect.ValueOf(val)
v.traverseField(topVal, topVal, reflect.ValueOf(field), "", errs, false, tag, "", false, false, nil)
if len(errs) == 0 {
errsPool.Put(errs)
return nil
}
return errs
}
// StructPartial validates the fields that are listed by name in the map including nested structs, unless otherwise specified. Items in the map that are NOT found in the struct will cause a panic.
func (v *Validate) StructPartial(current interface{}, fields ...string) ValidationErrors {
sv, _ := v.extractType(reflect.ValueOf(current))
name := sv.Type().Name()
m := map[string]*struct{}{}
if fields != nil {
for _, k := range fields {
flds := strings.Split(k, namespaceSeparator)
if len(flds) > 0 {
key := name + namespaceSeparator
for _, s := range flds {
idx := strings.Index(s, leftBracket)
if idx != -1 {
for idx != -1 {
key += s[:idx]
m[key] = emptyStructPtr
idx2 := strings.Index(s, rightBracket)
idx2++
key += s[idx:idx2]
m[key] = emptyStructPtr
s = s[idx2:]
idx = strings.Index(s, leftBracket)
}
} else {
key += s
m[key] = emptyStructPtr
}
key += namespaceSeparator
}
}
}
}
errs := errsPool.Get().(ValidationErrors)
v.tranverseStruct(sv, sv, sv, "", errs, true, len(m) != 0, false, m)
if len(errs) == 0 {
errsPool.Put(errs)
return nil
}
return errs
}
// StructExcept validates the fields in the struct that are NOT listed by name in the map including nested structs, unless otherwise specified. Items in the map that are NOT found in the struct will cause a panic.
func (v *Validate) StructExcept(current interface{}, fields ...string) ValidationErrors {
sv, _ := v.extractType(reflect.ValueOf(current))
name := sv.Type().Name()
m := map[string]*struct{}{}
for _, key := range fields {
m[name+"."+key] = emptyStructPtr
}
errs := errsPool.Get().(ValidationErrors)
v.tranverseStruct(sv, sv, sv, "", errs, true, len(m) != 0, true, m)
if len(errs) == 0 {
errsPool.Put(errs)
return nil
}
return errs
}
// Struct validates a structs exposed fields, and automatically validates nested structs, unless otherwise specified.
func (v *Validate) Struct(current interface{}) ValidationErrors {
errs := errsPool.Get().(ValidationErrors)
sv := reflect.ValueOf(current)
v.tranverseStruct(sv, sv, sv, "", errs, true, false, false, nil)
if len(errs) == 0 {
errsPool.Put(errs)
return nil
}
return errs
}
// tranverseStruct traverses a structs fields and then passes them to be validated by traverseField
func (v *Validate) tranverseStruct(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, useStructName bool, partial bool, exclude bool, includeExclude map[string]*struct{}) {
if current.Kind() == reflect.Ptr && !current.IsNil() {
current = current.Elem()
}
if current.Kind() != reflect.Struct && current.Kind() != reflect.Interface {
panic("value passed for validation is not a struct")
}
var ok bool
typ := current.Type()
if useStructName {
errPrefix += typ.Name() + "."
}
numFields := current.NumField()
var fld reflect.StructField
for i := 0; i < numFields; i++ {
fld = typ.Field(i)
if !unicode.IsUpper(rune(fld.Name[0])) {
continue
}
if partial {
_, ok = includeExclude[errPrefix+fld.Name]
if (ok && exclude) || (!ok && !exclude) {
continue
}
}
v.traverseField(topStruct, currentStruct, current.Field(i), errPrefix, errs, true, fld.Tag.Get(v.config.TagName), fld.Name, partial, exclude, includeExclude)
}
}
// traverseField validates any field, be it a struct or single field, ensures it's validity and passes it along to be validated via it's tag options
func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, isStructField bool, tag string, name string, partial bool, exclude bool, includeExclude map[string]*struct{}) {
if tag == skipValidationTag {
return
}
current, kind := v.extractType(current)
var typ reflect.Type
switch kind {
case reflect.Ptr, reflect.Interface, reflect.Invalid:
if strings.Contains(tag, omitempty) {
return
}
if len(tag) > 0 {
tags := strings.Split(tag, tagSeparator)
var param string
vals := strings.SplitN(tags[0], tagKeySeparator, 2)
if len(vals) > 1 {
param = vals[1]
}
if kind == reflect.Invalid {
errs[errPrefix+name] = &FieldError{
Field: name,
Tag: vals[0],
Param: param,
Kind: kind,
}
return
}
errs[errPrefix+name] = &FieldError{
Field: name,
Tag: vals[0],
Param: param,
Value: current.Interface(),
Kind: kind,
Type: current.Type(),
}
return
}
// if we get here tag length is zero and we can leave
if kind == reflect.Invalid {
return
}
case reflect.Struct:
typ = current.Type()
if typ != timeType {
// required passed validation above so stop here
// if only validating the structs existance.
if strings.Contains(tag, structOnlyTag) {
return
}
v.tranverseStruct(topStruct, current, current, errPrefix+name+".", errs, false, partial, exclude, includeExclude)
return
}
}
if len(tag) == 0 {
return
}
typ = current.Type()
tags, isCached := tagsCache.Get(tag)
if !isCached {
tags = []*tagCache{}
for _, t := range strings.Split(tag, tagSeparator) {
if t == diveTag {
tags = append(tags, &tagCache{tagVals: [][]string{{t}}})
break
}
// if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C"
orVals := strings.Split(t, orSeparator)
cTag := &tagCache{isOrVal: len(orVals) > 1, tagVals: make([][]string, len(orVals))}
tags = append(tags, cTag)
var key string
var param string
for i, val := range orVals {
vals := strings.SplitN(val, tagKeySeparator, 2)
key = vals[0]
if len(key) == 0 {
panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, name)))
}
if len(vals) > 1 {
param = strings.Replace(strings.Replace(vals[1], utf8HexComma, ",", -1), utf8Pipe, "|", -1)
}
cTag.tagVals[i] = []string{key, param}
}
}
tagsCache.Set(tag, tags)
}
var dive bool
var diveSubTag string
for _, cTag := range tags {
if cTag.tagVals[0][0] == existsTag {
continue
}
if cTag.tagVals[0][0] == diveTag {
dive = true
diveSubTag = strings.TrimLeft(strings.SplitN(tag, diveTag, 2)[1], ",")
break
}
if cTag.tagVals[0][0] == omitempty {
if !hasValue(v, topStruct, currentStruct, current, typ, kind, "") {
return
}
continue
}
if v.validateField(topStruct, currentStruct, current, typ, kind, errPrefix, errs, cTag, name) {
return
}
}
if dive {
// traverse slice or map here
// or panic ;)
switch kind {
case reflect.Slice, reflect.Array:
v.traverseSlice(topStruct, currentStruct, current, errPrefix, errs, diveSubTag, name, partial, exclude, includeExclude)
case reflect.Map:
v.traverseMap(topStruct, currentStruct, current, errPrefix, errs, diveSubTag, name, partial, exclude, includeExclude)
default:
// throw error, if not a slice or map then should not have gotten here
// bad dive tag
panic("dive error! can't dive on a non slice or map")
}
}
}
// traverseSlice traverses a Slice or Array's elements and passes them to traverseField for validation
func (v *Validate) traverseSlice(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, tag string, name string, partial bool, exclude bool, includeExclude map[string]*struct{}) {
for i := 0; i < current.Len(); i++ {
v.traverseField(topStruct, currentStruct, current.Index(i), errPrefix, errs, false, tag, fmt.Sprintf(arrayIndexFieldName, name, i), partial, exclude, includeExclude)
}
}
// traverseMap traverses a map's elements and passes them to traverseField for validation
func (v *Validate) traverseMap(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, tag string, name string, partial bool, exclude bool, includeExclude map[string]*struct{}) {
for _, key := range current.MapKeys() {
v.traverseField(topStruct, currentStruct, current.MapIndex(key), errPrefix, errs, false, tag, fmt.Sprintf(mapIndexFieldName, name, key.Interface()), partial, exclude, includeExclude)
}
}
// validateField validates a field based on the provided tag's key and param values and returns true if there is an error or false if all ok
func (v *Validate) validateField(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, currentType reflect.Type, currentKind reflect.Kind, errPrefix string, errs ValidationErrors, cTag *tagCache, name string) bool {
var valFunc Func
var ok bool
if cTag.isOrVal {
errTag := ""
for _, val := range cTag.tagVals {
valFunc, ok = v.config.ValidationFuncs[val[0]]
if !ok {
panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, name)))
}
if valFunc(v, topStruct, currentStruct, current, currentType, currentKind, val[1]) {
return false
}
errTag += orSeparator + val[0]
}
errs[errPrefix+name] = &FieldError{
Field: name,
Tag: errTag[1:],
Value: current.Interface(),
Type: currentType,
Kind: currentKind,
}
return true
}
valFunc, ok = v.config.ValidationFuncs[cTag.tagVals[0][0]]
if !ok {
panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, name)))
}
if valFunc(v, topStruct, currentStruct, current, currentType, currentKind, cTag.tagVals[0][1]) {
return false
}
errs[errPrefix+name] = &FieldError{
Field: name,
Tag: cTag.tagVals[0][0],
Value: current.Interface(),
Param: cTag.tagVals[0][1],
Type: currentType,
Kind: currentKind,
}
return true
}