add traverseMap

for #78
pull/82/head
joeybloggs 10 years ago
parent 98f4165fae
commit 14f176e8ac
  1. 107
      validator.go
  2. 16
      validator_test.go

@ -20,21 +20,22 @@ import (
) )
const ( const (
utf8HexComma = "0x2C" utf8HexComma = "0x2C"
tagSeparator = "," tagSeparator = ","
orSeparator = "|" orSeparator = "|"
noValidationTag = "-" noValidationTag = "-"
tagKeySeparator = "=" tagKeySeparator = "="
structOnlyTag = "structonly" structOnlyTag = "structonly"
omitempty = "omitempty" omitempty = "omitempty"
required = "required" required = "required"
fieldErrMsg = "Field validation for \"%s\" failed on the \"%s\" tag" fieldErrMsg = "Field validation for \"%s\" failed on the \"%s\" tag"
sliceErrMsg = "Field validation for \"%s\" failed at index \"%d\" failed with error(s): %s" sliceErrMsg = "Field validation for \"%s\" failed at index \"%d\" with error(s): %s"
mapErrMsg = "Field validation for \"%s\" failed with key \"%v\" failed with error(s): %s" mapErrMsg = "Field validation for \"%s\" failed on key \"%v\" with error(s): %s"
structErrMsg = "Struct:%s\n" structErrMsg = "Struct:%s\n"
diveTag = "dive" diveTag = "dive"
diveSplit = "," + diveTag diveSplit = "," + diveTag
indexFieldName = "%s[%d]" arrayIndexFieldName = "%s[%d]"
mapIndexFieldName = "%s[%v]"
) )
var structPool *pool var structPool *pool
@ -670,7 +671,18 @@ func (v *Validate) fieldWithNameAndValue(val interface{}, current interface{}, f
} }
} else if cField.isMap { } else if cField.isMap {
// return if error here if errs := v.traverseMap(val, current, valueField, cField); errs != nil && len(errs) > 0 {
return &FieldError{
Field: cField.name,
Kind: cField.kind,
Type: cField.typ,
Value: f,
IsPlaceholderErr: true,
IsMap: true,
MapErrs: errs,
}
}
} else { } else {
// throw error, if not a slice or map then should not have gotten here // throw error, if not a slice or map then should not have gotten here
panic("dive error! can't dive on a non slice or map") panic("dive error! can't dive on a non slice or map")
@ -680,6 +692,65 @@ func (v *Validate) fieldWithNameAndValue(val interface{}, current interface{}, f
return nil return nil
} }
func (v *Validate) traverseMap(val interface{}, current interface{}, valueField reflect.Value, cField *cachedField) map[interface{}]error {
errs := map[interface{}]error{}
for _, key := range valueField.MapKeys() {
idxField := valueField.MapIndex(key)
if cField.sliceSubKind == reflect.Ptr && !idxField.IsNil() {
idxField = idxField.Elem()
cField.sliceSubKind = idxField.Kind()
}
switch cField.sliceSubKind {
case reflect.Struct, reflect.Interface:
if cField.isTimeSubtype {
if fieldError := v.fieldWithNameAndValue(val, current, idxField.Interface(), cField.diveTag, fmt.Sprintf(mapIndexFieldName, cField.name, key.Interface()), false, nil); fieldError != nil {
errs[key.Interface()] = fieldError
}
continue
}
if idxField.Kind() == reflect.Ptr && idxField.IsNil() {
if strings.Contains(cField.tag, omitempty) {
continue
}
if strings.Contains(cField.tag, required) {
errs[key.Interface()] = &FieldError{
Field: cField.name,
Tag: required,
Value: idxField.Interface(),
Kind: reflect.Ptr,
Type: cField.sliceSubtype,
}
}
continue
}
if structErrors := v.structRecursive(val, current, idxField.Interface()); structErrors != nil {
errs[key.Interface()] = structErrors
}
default:
if fieldError := v.fieldWithNameAndValue(val, current, idxField.Interface(), cField.diveTag, fmt.Sprintf(mapIndexFieldName, cField.name, key.Interface()), false, nil); fieldError != nil {
errs[key.Interface()] = fieldError
}
}
}
return errs
}
func (v *Validate) traverseSliceOrArray(val interface{}, current interface{}, valueField reflect.Value, cField *cachedField) map[int]error { func (v *Validate) traverseSliceOrArray(val interface{}, current interface{}, valueField reflect.Value, cField *cachedField) map[int]error {
errs := map[int]error{} errs := map[int]error{}
@ -698,7 +769,7 @@ func (v *Validate) traverseSliceOrArray(val interface{}, current interface{}, va
if cField.isTimeSubtype { if cField.isTimeSubtype {
if fieldError := v.fieldWithNameAndValue(val, current, idxField.Interface(), cField.diveTag, cField.name, false, nil); fieldError != nil { if fieldError := v.fieldWithNameAndValue(val, current, idxField.Interface(), cField.diveTag, fmt.Sprintf(arrayIndexFieldName, cField.name, i), false, nil); fieldError != nil {
errs[i] = fieldError errs[i] = fieldError
} }
@ -730,7 +801,7 @@ func (v *Validate) traverseSliceOrArray(val interface{}, current interface{}, va
} }
default: default:
if fieldError := v.fieldWithNameAndValue(val, current, idxField.Interface(), cField.diveTag, fmt.Sprintf(indexFieldName, cField.name, i), false, nil); fieldError != nil { if fieldError := v.fieldWithNameAndValue(val, current, idxField.Interface(), cField.diveTag, fmt.Sprintf(arrayIndexFieldName, cField.name, i), false, nil); fieldError != nil {
errs[i] = fieldError errs[i] = fieldError
} }
} }

@ -227,6 +227,18 @@ func AssertMapFieldError(t *testing.T, s map[string]*FieldError, field string, e
} }
func TestMapDiveValidation(t *testing.T) { func TestMapDiveValidation(t *testing.T) {
type Test struct {
Errs map[int]string `validate:"gt=0,dive,required"`
}
test := &Test{
Errs: map[int]string{0: "ok", 1: "", 4: "ok"},
}
errs := validate.Struct(test)
fmt.Println(errs)
} }
func TestArrayDiveValidation(t *testing.T) { func TestArrayDiveValidation(t *testing.T) {
@ -509,7 +521,7 @@ func TestArrayDiveValidation(t *testing.T) {
Equal(t, innerSliceError1.IsPlaceholderErr, false) Equal(t, innerSliceError1.IsPlaceholderErr, false)
Equal(t, innerSliceError1.IsSliceOrArray, false) Equal(t, innerSliceError1.IsSliceOrArray, false)
Equal(t, len(innerSliceError1.SliceOrArrayErrs), 0) Equal(t, len(innerSliceError1.SliceOrArrayErrs), 0)
Equal(t, innerSliceError1.Field, "Errs[2]") Equal(t, innerSliceError1.Field, "Errs[2][1]")
Equal(t, innerSliceError1.Tag, required) Equal(t, innerSliceError1.Tag, required)
type TestMultiDimensionalTimeTime2 struct { type TestMultiDimensionalTimeTime2 struct {
@ -551,7 +563,7 @@ func TestArrayDiveValidation(t *testing.T) {
Equal(t, innerSliceError1.IsPlaceholderErr, false) Equal(t, innerSliceError1.IsPlaceholderErr, false)
Equal(t, innerSliceError1.IsSliceOrArray, false) Equal(t, innerSliceError1.IsSliceOrArray, false)
Equal(t, len(innerSliceError1.SliceOrArrayErrs), 0) Equal(t, len(innerSliceError1.SliceOrArrayErrs), 0)
Equal(t, innerSliceError1.Field, "Errs[2]") Equal(t, innerSliceError1.Field, "Errs[2][1]")
Equal(t, innerSliceError1.Tag, required) Equal(t, innerSliceError1.Tag, required)
} }

Loading…
Cancel
Save