/** * Package validator * * MISC: * - anonymous structs - they don't have names so expect the Struct name within StructErrors to be blank * */ package validator import ( "bytes" "errors" "fmt" "reflect" "strings" "sync" "time" "unicode" ) const ( utf8HexComma = "0x2C" utf8Pipe = "0x7C" tagSeparator = "," orSeparator = "|" tagKeySeparator = "=" structOnlyTag = "structonly" omitempty = "omitempty" skipValidationTag = "-" diveTag = "dive" existsTag = "exists" fieldErrMsg = "Key: \"%s\" Error:Field validation for \"%s\" failed on the \"%s\" tag" arrayIndexFieldName = "%s" + leftBracket + "%d" + rightBracket mapIndexFieldName = "%s" + leftBracket + "%v" + rightBracket invalidValidation = "Invalid validation tag on field %s" undefinedValidation = "Undefined validation function on field %s" ) var ( timeType = reflect.TypeOf(time.Time{}) timePtrType = reflect.TypeOf(&time.Time{}) errsPool = &sync.Pool{New: newValidationErrors} tagsCache = &tagCacheMap{m: map[string][]*tagCache{}} emptyStructPtr = new(struct{}) ) // returns new ValidationErrors to the pool func newValidationErrors() interface{} { return ValidationErrors{} } type tagCache struct { tagVals [][]string isOrVal bool } type tagCacheMap struct { lock sync.RWMutex m map[string][]*tagCache } func (s *tagCacheMap) Get(key string) ([]*tagCache, bool) { s.lock.RLock() defer s.lock.RUnlock() value, ok := s.m[key] return value, ok } func (s *tagCacheMap) Set(key string, value []*tagCache) { s.lock.Lock() defer s.lock.Unlock() s.m[key] = value } // Validate contains the validator settings passed in using the Config struct type Validate struct { config Config } // Config contains the options that a Validator instance will use. 
// It is passed to the New() function.
type Config struct {
	// TagName is the struct-tag key whose value holds the validation rules
	// (read with fld.Tag.Get(TagName) while walking struct fields).
	TagName string
	// ValidationFuncs maps a tag name to the Func that implements it.
	ValidationFuncs map[string]Func
	// CustomTypeFuncs maps a field type to the CustomTypeFunc registered for it.
	CustomTypeFuncs map[reflect.Type]CustomTypeFunc
	// hasCustomFuncs is set by New when CustomTypeFuncs is non-empty and by
	// RegisterCustomTypeFunc. NOTE(review): its reader is not in this section;
	// presumably a fast-path flag to skip CustomTypeFuncs lookups — verify.
	hasCustomFuncs bool
}

// CustomTypeFunc allows for overriding or adding custom field type handler functions.
// field = field value of the type to return a value to be validated
// example: Valuer from the sql driver package, see https://golang.org/src/database/sql/driver/types.go?s=1210:1293#L29
type CustomTypeFunc func(field reflect.Value) interface{}

// Func is the signature of a validation function. It receives every value needed
// for both single-field and cross-field validation and returns true when the
// field is valid (false records a FieldError for the tag).
//
// v             = the Validate instance performing the validation
// topStruct     = top level struct when validating by struct; when validating via
//                 Field it is the field value itself, and via FieldWithValue it is
//                 the comparison value
// currentStruct = struct directly containing the field when validating by struct;
//                 otherwise the same value as topStruct (the optional comparison value)
// field         = field value for validation
// fieldtype     = type of field after pointer/interface extraction
// fieldKind     = kind of field after pointer/interface extraction
// param         = parameter used in validation i.e. gt=0 param would be "0";
//                 "0x2C" and "0x7C" in the tag are decoded to "," and "|"
type Func func(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, field reflect.Value, fieldtype reflect.Type, fieldKind reflect.Kind, param string) bool

// ValidationErrors is a map[string]*FieldError keyed by the namespaced field name
// (struct-name prefix plus field name). It exists to allow multiple errors to be
// returned from this library while still satisfying the error interface.
type ValidationErrors map[string]*FieldError

// Error is intended for use in development + debugging and not intended to be a production error message.
// It allows ValidationErrors to subscribe to the Error interface.
// All information to create an error message specific to your application is contained within // the FieldError found within the ValidationErrors map func (ve ValidationErrors) Error() string { buff := bytes.NewBufferString("") for key, err := range ve { buff.WriteString(fmt.Sprintf(fieldErrMsg, key, err.Field, err.Tag)) buff.WriteString("\n") } return strings.TrimSpace(buff.String()) } // FieldError contains a single field's validation error along // with other properties that may be needed for error message creation type FieldError struct { Field string Tag string Kind reflect.Kind Type reflect.Type Param string Value interface{} } // New creates a new Validate instance for use. func New(config Config) *Validate { if config.CustomTypeFuncs != nil && len(config.CustomTypeFuncs) > 0 { config.hasCustomFuncs = true } return &Validate{config: config} } // RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key // NOTE: if the key already exists, the previous validation function will be replaced. 
// NOTE: this method is not thread-safe func (v *Validate) RegisterValidation(key string, f Func) error { if len(key) == 0 { return errors.New("Function Key cannot be empty") } if f == nil { return errors.New("Function cannot be empty") } v.config.ValidationFuncs[key] = f return nil } // RegisterCustomTypeFunc registers a CustomTypeFunc against a number of types func (v *Validate) RegisterCustomTypeFunc(fn CustomTypeFunc, types ...interface{}) { if v.config.CustomTypeFuncs == nil { v.config.CustomTypeFuncs = map[reflect.Type]CustomTypeFunc{} } for _, t := range types { v.config.CustomTypeFuncs[reflect.TypeOf(t)] = fn } v.config.hasCustomFuncs = true } // Field validates a single field using tag style validation and returns ValidationErrors // NOTE: it returns ValidationErrors instead of a single FieldError because this can also // validate Array, Slice and maps fields which may contain more than one error func (v *Validate) Field(field interface{}, tag string) ValidationErrors { errs := errsPool.Get().(ValidationErrors) fieldVal := reflect.ValueOf(field) v.traverseField(fieldVal, fieldVal, fieldVal, "", errs, false, tag, "", false, false, nil) if len(errs) == 0 { errsPool.Put(errs) return nil } return errs } // FieldWithValue validates a single field, against another fields value using tag style validation and returns ValidationErrors // NOTE: it returns ValidationErrors instead of a single FieldError because this can also // validate Array, Slice and maps fields which may contain more than one error func (v *Validate) FieldWithValue(val interface{}, field interface{}, tag string) ValidationErrors { errs := errsPool.Get().(ValidationErrors) topVal := reflect.ValueOf(val) v.traverseField(topVal, topVal, reflect.ValueOf(field), "", errs, false, tag, "", false, false, nil) if len(errs) == 0 { errsPool.Put(errs) return nil } return errs } // StructPartial validates the fields that are listed by name in the map including nested structs, unless otherwise specified. 
Items in the map that are NOT found in the struct will cause a panic. func (v *Validate) StructPartial(current interface{}, fields ...string) ValidationErrors { sv, _ := v.extractType(reflect.ValueOf(current)) name := sv.Type().Name() m := map[string]*struct{}{} var i int if fields != nil { for _, k := range fields { flds := strings.Split(k, ".") if len(flds) > 0 { key := name for _, s := range flds { idx := strings.Index(s, "[") if idx != -1 { for idx != -1 { i++ key += s[:idx] m[key] = emptyStructPtr idx2 := strings.Index(s, "]") idx2++ key += s[idx:idx2] m[key] = emptyStructPtr s = s[idx2:] idx = strings.Index(s, "[") if i == 10 { idx = -1 } } } else { key += s m[key] = emptyStructPtr } key += "." } } } } errs := errsPool.Get().(ValidationErrors) v.tranverseStruct(sv, sv, sv, "", errs, true, len(m) != 0, false, m) if len(errs) == 0 { errsPool.Put(errs) return nil } return errs } // StructExcept validates the fields in the struct that are NOT listed by name in the map including nested structs, unless otherwise specified. Items in the map that are NOT found in the struct will cause a panic. func (v *Validate) StructExcept(current interface{}, fields ...string) ValidationErrors { sv, _ := v.extractType(reflect.ValueOf(current)) name := sv.Type().Name() m := map[string]*struct{}{} for _, key := range fields { m[name+"."+key] = emptyStructPtr } errs := errsPool.Get().(ValidationErrors) v.tranverseStruct(sv, sv, sv, "", errs, true, len(m) != 0, true, m) if len(errs) == 0 { errsPool.Put(errs) return nil } return errs } // Struct validates a structs exposed fields, and automatically validates nested structs, unless otherwise specified. 
func (v *Validate) Struct(current interface{}) ValidationErrors { errs := errsPool.Get().(ValidationErrors) sv := reflect.ValueOf(current) v.tranverseStruct(sv, sv, sv, "", errs, true, false, false, nil) if len(errs) == 0 { errsPool.Put(errs) return nil } return errs } // tranverseStruct traverses a structs fields and then passes them to be validated by traverseField func (v *Validate) tranverseStruct(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, useStructName bool, partial bool, exclude bool, includeExclude map[string]*struct{}) { if current.Kind() == reflect.Ptr && !current.IsNil() { current = current.Elem() } if current.Kind() != reflect.Struct && current.Kind() != reflect.Interface { panic("value passed for validation is not a struct") } var ok bool typ := current.Type() if useStructName { errPrefix += typ.Name() + "." } numFields := current.NumField() var fld reflect.StructField for i := 0; i < numFields; i++ { fld = typ.Field(i) if !unicode.IsUpper(rune(fld.Name[0])) { continue } if partial { _, ok = includeExclude[errPrefix+fld.Name] if (ok && exclude) || (!ok && !exclude) { continue } } v.traverseField(topStruct, currentStruct, current.Field(i), errPrefix, errs, true, fld.Tag.Get(v.config.TagName), fld.Name, partial, exclude, includeExclude) } } // traverseField validates any field, be it a struct or single field, ensures it's validity and passes it along to be validated via it's tag options func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, isStructField bool, tag string, name string, partial bool, exclude bool, includeExclude map[string]*struct{}) { if tag == skipValidationTag { return } current, kind := v.extractType(current) var typ reflect.Type switch kind { case reflect.Ptr, reflect.Interface, reflect.Invalid: if strings.Contains(tag, omitempty) { return } if len(tag) > 0 { tags := 
strings.Split(tag, tagSeparator) var param string vals := strings.SplitN(tags[0], tagKeySeparator, 2) if len(vals) > 1 { param = vals[1] } if kind == reflect.Invalid { errs[errPrefix+name] = &FieldError{ Field: name, Tag: vals[0], Param: param, Kind: kind, } return } errs[errPrefix+name] = &FieldError{ Field: name, Tag: vals[0], Param: param, Value: current.Interface(), Kind: kind, Type: current.Type(), } return } // if we get here tag length is zero and we can leave if kind == reflect.Invalid { return } case reflect.Struct: typ = current.Type() if typ != timeType { // required passed validation above so stop here // if only validating the structs existance. if strings.Contains(tag, structOnlyTag) { return } v.tranverseStruct(topStruct, current, current, errPrefix+name+".", errs, false, partial, exclude, includeExclude) return } } if len(tag) == 0 { return } typ = current.Type() tags, isCached := tagsCache.Get(tag) if !isCached { tags = []*tagCache{} for _, t := range strings.Split(tag, tagSeparator) { if t == diveTag { tags = append(tags, &tagCache{tagVals: [][]string{{t}}}) break } // if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C" orVals := strings.Split(t, orSeparator) cTag := &tagCache{isOrVal: len(orVals) > 1, tagVals: make([][]string, len(orVals))} tags = append(tags, cTag) var key string var param string for i, val := range orVals { vals := strings.SplitN(val, tagKeySeparator, 2) key = vals[0] if len(key) == 0 { panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, name))) } if len(vals) > 1 { param = strings.Replace(strings.Replace(vals[1], utf8HexComma, ",", -1), utf8Pipe, "|", -1) } cTag.tagVals[i] = []string{key, param} } } tagsCache.Set(tag, tags) } var dive bool var diveSubTag string for _, cTag := range tags { if cTag.tagVals[0][0] == existsTag { continue } if cTag.tagVals[0][0] == diveTag { dive = true diveSubTag = strings.TrimLeft(strings.SplitN(tag, diveTag, 2)[1], ",") break } if cTag.tagVals[0][0] 
== omitempty { if !hasValue(v, topStruct, currentStruct, current, typ, kind, "") { return } continue } if v.validateField(topStruct, currentStruct, current, typ, kind, errPrefix, errs, cTag, name) { return } } if dive { // traverse slice or map here // or panic ;) switch kind { case reflect.Slice, reflect.Array: v.traverseSlice(topStruct, currentStruct, current, errPrefix, errs, diveSubTag, name, partial, exclude, includeExclude) case reflect.Map: v.traverseMap(topStruct, currentStruct, current, errPrefix, errs, diveSubTag, name, partial, exclude, includeExclude) default: // throw error, if not a slice or map then should not have gotten here // bad dive tag panic("dive error! can't dive on a non slice or map") } } } // traverseSlice traverses a Slice or Array's elements and passes them to traverseField for validation func (v *Validate) traverseSlice(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, tag string, name string, partial bool, exclude bool, includeExclude map[string]*struct{}) { for i := 0; i < current.Len(); i++ { v.traverseField(topStruct, currentStruct, current.Index(i), errPrefix, errs, false, tag, fmt.Sprintf(arrayIndexFieldName, name, i), partial, exclude, includeExclude) } } // traverseMap traverses a map's elements and passes them to traverseField for validation func (v *Validate) traverseMap(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, tag string, name string, partial bool, exclude bool, includeExclude map[string]*struct{}) { for _, key := range current.MapKeys() { v.traverseField(topStruct, currentStruct, current.MapIndex(key), errPrefix, errs, false, tag, fmt.Sprintf(mapIndexFieldName, name, key.Interface()), partial, exclude, includeExclude) } } // validateField validates a field based on the provided tag's key and param values and returns true if there is an error or false if all ok func (v *Validate) 
validateField(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, currentType reflect.Type, currentKind reflect.Kind, errPrefix string, errs ValidationErrors, cTag *tagCache, name string) bool { var valFunc Func var ok bool if cTag.isOrVal { errTag := "" for _, val := range cTag.tagVals { valFunc, ok = v.config.ValidationFuncs[val[0]] if !ok { panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, name))) } if valFunc(v, topStruct, currentStruct, current, currentType, currentKind, val[1]) { return false } errTag += orSeparator + val[0] } errs[errPrefix+name] = &FieldError{ Field: name, Tag: errTag[1:], Value: current.Interface(), Type: currentType, Kind: currentKind, } return true } valFunc, ok = v.config.ValidationFuncs[cTag.tagVals[0][0]] if !ok { panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, name))) } if valFunc(v, topStruct, currentStruct, current, currentType, currentKind, cTag.tagVals[0][1]) { return false } errs[errPrefix+name] = &FieldError{ Field: name, Tag: cTag.tagVals[0][0], Value: current.Interface(), Param: cTag.tagVals[0][1], Type: currentType, Kind: currentKind, } return true }