finish eqcsfield + test coverage

pull/161/head
joeybloggs 9 years ago
parent d19088f865
commit 2fe52ca08f
  1. 87
      baked_in.go
  2. 52
      util.go
  3. 2
      validator.go
  4. 143
      validator_test.go

@ -26,6 +26,7 @@ var BakedInValidators = map[string]Func{
"gt": isGt, "gt": isGt,
"gte": isGte, "gte": isGte,
"eqfield": isEqField, "eqfield": isEqField,
"eqcsfield": isEqCrossStructField,
"nefield": isNeField, "nefield": isNeField,
"gtefield": isGteField, "gtefield": isGteField,
"gtfield": isGtField, "gtfield": isGtField,
@ -253,68 +254,98 @@ func isNe(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, fie
return !isEq(v, topStruct, currentStruct, field, fieldType, fieldKind, param) return !isEq(v, topStruct, currentStruct, field, fieldType, fieldKind, param)
} }
func isEqField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { func isEqCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {
// if current == nil { // if !topStruct.IsValid() {
if !current.IsValid() { // panic("struct or field value not passed for cross validation")
panic("struct or field value not passed for cross validation") // }
}
if current.Kind() == reflect.Ptr && !current.IsNil() { topField, topKind, ok := v.getStructFieldOK(topStruct, param)
current = current.Elem() if !ok || topKind != fieldKind {
// fmt.Println("NOT OK:", ok)
return false
} }
switch current.Kind() { // fmt.Println("HERE", fieldKind)
switch fieldKind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return topField.Int() == field.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return topField.Uint() == field.Uint()
case reflect.Float32, reflect.Float64:
return topField.Float() == field.Float()
case reflect.Slice, reflect.Map, reflect.Array:
// fmt.Println(topField.Len(), field.Len())
return int64(topField.Len()) == int64(field.Len())
case reflect.Struct: case reflect.Struct:
if current.Type() == timeType || current.Type() == timePtrType { // Not Same underlying type i.e. struct and time
break if fieldType != topField.Type() {
return false
} }
current = current.FieldByName(param) if fieldType == timeType {
if current.Kind() == reflect.Invalid { t := field.Interface().(time.Time)
panic(fmt.Sprintf("Field \"%s\" not found in struct", param)) fieldTime := topField.Interface().(time.Time)
return fieldTime.Equal(t)
} }
} }
if current.Kind() == reflect.Ptr && !current.IsNil() { // default reflect.String:
current = current.Elem() return topField.String() == current.String()
}
func isEqField(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {
// if !currentStruct.IsValid() {
// panic("struct or field value not passed for cross validation")
// }
currentField, currentKind, ok := v.getStructFieldOK(currentStruct, param)
if !ok || currentKind != fieldKind {
return false
} }
switch fieldKind { switch fieldKind {
case reflect.String:
return field.String() == current.String()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return field.Int() == current.Int() return field.Int() == currentField.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return field.Uint() == current.Uint() return field.Uint() == currentField.Uint()
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
return field.Float() == current.Float() return field.Float() == currentField.Float()
case reflect.Slice, reflect.Map, reflect.Array: case reflect.Slice, reflect.Map, reflect.Array:
return int64(field.Len()) == int64(current.Len()) return int64(field.Len()) == int64(currentField.Len())
case reflect.Struct: case reflect.Struct:
if fieldType == timeType || fieldType == timePtrType {
if current.Type() != timeType && current.Type() != timePtrType { // Not Same underlying type i.e. struct and time
panic("Bad Top Level field type") if fieldType != currentField.Type() {
} return false
}
t := current.Interface().(time.Time) if fieldType == timeType {
t := currentField.Interface().(time.Time)
fieldTime := field.Interface().(time.Time) fieldTime := field.Interface().(time.Time)
return fieldTime.Equal(t) return fieldTime.Equal(t)
} }
} }
panic(fmt.Sprintf("Bad field type %T", field.Interface())) // default reflect.String:
return field.String() == currentField.String()
} }
func isEq(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { func isEq(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

@ -51,12 +51,28 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re
current, kind := v.extractType(current) current, kind := v.extractType(current)
// fmt.Println("SOK:", current, kind, namespace)
// if len(namespace) == 0 {
// // if kind == reflect.Invalid {
// // return current, kind, false
// // }
// return current, kind, true
// }
if kind == reflect.Invalid {
return current, kind, false
}
if len(namespace) == 0 {
return current, kind, true
}
switch kind { switch kind {
case reflect.Ptr, reflect.Interface, reflect.Invalid: // case reflect.Invalid:
// return current, kind, false
if len(namespace) == 0 { case reflect.Ptr, reflect.Interface:
return current, kind, true
}
return current, kind, false return current, kind, false
@ -64,6 +80,7 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re
typ := current.Type() typ := current.Type()
fld := namespace fld := namespace
ns := namespace
if typ != timeType && typ != timePtrType { if typ != timeType && typ != timePtrType {
@ -71,23 +88,34 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re
if idx != -1 { if idx != -1 {
fld = namespace[:idx] fld = namespace[:idx]
ns = namespace[idx+1:]
} else {
ns = ""
idx = len(namespace)
} }
ns := namespace[idx+1:] // ns := namespace[idx+1:]
bracketIdx := strings.Index(fld, leftBracket) bracketIdx := strings.Index(fld, leftBracket)
if bracketIdx != -1 { if bracketIdx != -1 {
fld = fld[:bracketIdx] fld = fld[:bracketIdx]
if idx == -1 { ns = namespace[bracketIdx:]
ns = namespace[bracketIdx:] // if idx == -1 {
} else { // ns = namespace[bracketIdx:]
ns = namespace[bracketIdx:] // } else {
} // ns = namespace[bracketIdx:]
// }
} }
current = current.FieldByName(fld) current = current.FieldByName(fld)
// if current.Kind() == reflect.Invalid {
// return current, reflect.Invalid, false
// }
// fmt.Println("NS:", ns, idx)
return v.getStructFieldOK(current, ns) return v.getStructFieldOK(current, ns)
} }
@ -126,5 +154,7 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(namespace[idx:idx2])), namespace[endIdx+1:]) return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(namespace[idx:idx2])), namespace[endIdx+1:])
} }
return current, kind, true // if got here there was more namespace, cannot go any deeper
panic("Invalid field namespace")
// return current, kind, false
} }

@ -311,7 +311,7 @@ func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect.
case reflect.Struct: case reflect.Struct:
typ = current.Type() typ = current.Type()
if typ != timeType && typ != timePtrType { if typ != timeType {
// required passed validation above so stop here // required passed validation above so stop here
// if only validating the structs existance. // if only validating the structs existance.

@ -192,6 +192,110 @@ func ValidateValuerType(field reflect.Value) interface{} {
return nil return nil
} }
func TestCrossStructEqFieldValidation(t *testing.T) {
type Inner struct {
CreatedAt *time.Time
}
type Test struct {
Inner *Inner
CreatedAt *time.Time `validate:"eqcsfield=Inner.CreatedAt"`
}
now := time.Now().UTC()
inner := &Inner{
CreatedAt: &now,
}
test := &Test{
Inner: inner,
CreatedAt: &now,
}
errs := validate.Struct(test)
Equal(t, errs, nil)
newTime := time.Now().UTC()
test.CreatedAt = &newTime
errs = validate.Struct(test)
NotEqual(t, errs, nil)
AssertError(t, errs, "Test.CreatedAt", "CreatedAt", "eqcsfield")
var j uint64
var k float64
s := "abcd"
i := 1
j = 1
k = 1.543
arr := []string{"test"}
var j2 uint64
var k2 float64
s2 := "abcd"
i2 := 1
j2 = 1
k2 = 1.543
arr2 := []string{"test"}
arr3 := []string{"test", "test2"}
now2 := now
errs = validate.FieldWithValue(s, s2, "eqcsfield")
Equal(t, errs, nil)
errs = validate.FieldWithValue(i2, i, "eqcsfield")
Equal(t, errs, nil)
errs = validate.FieldWithValue(j2, j, "eqcsfield")
Equal(t, errs, nil)
errs = validate.FieldWithValue(k2, k, "eqcsfield")
Equal(t, errs, nil)
errs = validate.FieldWithValue(arr2, arr, "eqcsfield")
Equal(t, errs, nil)
errs = validate.FieldWithValue(now2, now, "eqcsfield")
Equal(t, errs, nil)
errs = validate.FieldWithValue(arr3, arr, "eqcsfield")
NotEqual(t, errs, nil)
AssertError(t, errs, "", "", "eqcsfield")
type SInner struct {
Name string
}
type TStruct struct {
Inner *SInner
CreatedAt *time.Time `validate:"eqcsfield=Inner"`
}
sinner := &SInner{
Name: "NAME",
}
test2 := &TStruct{
Inner: sinner,
CreatedAt: &now,
}
errs = validate.Struct(test2)
NotEqual(t, errs, nil)
AssertError(t, errs, "TStruct.CreatedAt", "CreatedAt", "eqcsfield")
test2.Inner = nil
errs = validate.Struct(test2)
NotEqual(t, errs, nil)
AssertError(t, errs, "TStruct.CreatedAt", "CreatedAt", "eqcsfield")
errs = validate.FieldWithValue(nil, 1, "eqcsfield")
NotEqual(t, errs, nil)
AssertError(t, errs, "", "", "eqcsfield")
}
func TestCrossNamespaceFieldValidation(t *testing.T) { func TestCrossNamespaceFieldValidation(t *testing.T) {
type SliceStruct struct { type SliceStruct struct {
@ -328,6 +432,14 @@ func TestCrossNamespaceFieldValidation(t *testing.T) {
Equal(t, kind, reflect.Ptr) Equal(t, kind, reflect.Ptr)
Equal(t, current.String(), "<*validator.SliceStruct Value>") Equal(t, current.String(), "<*validator.SliceStruct Value>")
Equal(t, current.IsNil(), true) Equal(t, current.IsNil(), true)
current, kind, ok = validate.getStructFieldOK(val, "Inner.SliceStructs[2].Name")
Equal(t, ok, false)
Equal(t, kind, reflect.Ptr)
Equal(t, current.String(), "<*validator.SliceStruct Value>")
Equal(t, current.IsNil(), true)
PanicMatches(t, func() { validate.getStructFieldOK(reflect.ValueOf(1), "crazyinput") }, "Invalid field namespace")
} }
func TestExistsValidation(t *testing.T) { func TestExistsValidation(t *testing.T) {
@ -2063,9 +2175,11 @@ func TestIsNeFieldValidation(t *testing.T) {
errs = validate.Struct(sv) errs = validate.Struct(sv)
Equal(t, errs, nil) Equal(t, errs, nil)
errs = validate.FieldWithValue(nil, 1, "nefield")
Equal(t, errs, nil)
channel := make(chan string) channel := make(chan string)
PanicMatches(t, func() { validate.FieldWithValue(nil, 1, "nefield") }, "struct or field value not passed for cross validation")
PanicMatches(t, func() { validate.FieldWithValue(5, channel, "nefield") }, "Bad field type chan string") PanicMatches(t, func() { validate.FieldWithValue(5, channel, "nefield") }, "Bad field type chan string")
PanicMatches(t, func() { validate.FieldWithValue(5, now, "nefield") }, "Bad Top Level field type") PanicMatches(t, func() { validate.FieldWithValue(5, now, "nefield") }, "Bad Top Level field type")
@ -2182,9 +2296,12 @@ func TestIsEqFieldValidation(t *testing.T) {
NotEqual(t, errs, nil) NotEqual(t, errs, nil)
AssertError(t, errs, "Test.Start", "Start", "eqfield") AssertError(t, errs, "Test.Start", "Start", "eqfield")
errs = validate.FieldWithValue(nil, 1, "eqfield")
NotEqual(t, errs, nil)
AssertError(t, errs, "", "", "eqfield")
channel := make(chan string) channel := make(chan string)
PanicMatches(t, func() { validate.FieldWithValue(nil, 1, "eqfield") }, "struct or field value not passed for cross validation")
PanicMatches(t, func() { validate.FieldWithValue(5, channel, "eqfield") }, "Bad field type chan string") PanicMatches(t, func() { validate.FieldWithValue(5, channel, "eqfield") }, "Bad field type chan string")
PanicMatches(t, func() { validate.FieldWithValue(5, now, "eqfield") }, "Bad Top Level field type") PanicMatches(t, func() { validate.FieldWithValue(5, now, "eqfield") }, "Bad Top Level field type")
@ -2199,6 +2316,28 @@ func TestIsEqFieldValidation(t *testing.T) {
} }
PanicMatches(t, func() { validate.Struct(sv2) }, "Field \"NonExistantField\" not found in struct") PanicMatches(t, func() { validate.Struct(sv2) }, "Field \"NonExistantField\" not found in struct")
type Inner struct {
Name string
}
type TStruct struct {
Inner *Inner
CreatedAt *time.Time `validate:"eqfield=Inner"`
}
inner := &Inner{
Name: "NAME",
}
test := &TStruct{
Inner: inner,
CreatedAt: &now,
}
errs = validate.Struct(test)
NotEqual(t, errs, nil)
AssertError(t, errs, "TStruct.CreatedAt", "CreatedAt", "eqfield")
} }
func TestIsEqValidation(t *testing.T) { func TestIsEqValidation(t *testing.T) {

Loading…
Cancel
Save