diff --git a/baked_in.go b/baked_in.go index 8e5f648..daebec8 100644 --- a/baked_in.go +++ b/baked_in.go @@ -26,6 +26,7 @@ var BakedInValidators = map[string]Func{ "gt": isGt, "gte": isGte, "eqfield": isEqField, + "eqcsfield": isEqCrossStructField, "nefield": isNeField, "gtefield": isGteField, "gtfield": isGtField, @@ -253,68 +254,98 @@ func isNe(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, fie return !isEq(v, topStruct, currentStruct, field, fieldType, fieldKind, param) } -func isEqField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { +func isEqCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - // if current == nil { - if !current.IsValid() { - panic("struct or field value not passed for cross validation") - } + // if !topStruct.IsValid() { + // panic("struct or field value not passed for cross validation") + // } - if current.Kind() == reflect.Ptr && !current.IsNil() { - current = current.Elem() + topField, topKind, ok := v.getStructFieldOK(topStruct, param) + if !ok || topKind != fieldKind { + // fmt.Println("NOT OK:", ok) + return false } - switch current.Kind() { + // fmt.Println("HERE", fieldKind) + switch fieldKind { + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return topField.Int() == field.Int() + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return topField.Uint() == field.Uint() + + case reflect.Float32, reflect.Float64: + return topField.Float() == field.Float() + + case reflect.Slice, reflect.Map, reflect.Array: + // fmt.Println(topField.Len(), field.Len()) + return int64(topField.Len()) == int64(field.Len()) case reflect.Struct: - if current.Type() == timeType || current.Type() == timePtrType { - break + // Not Same underlying type i.e. struct and time + if fieldType != topField.Type() { + return false } - current = current.FieldByName(param) + if fieldType == timeType { - if current.Kind() == reflect.Invalid { - panic(fmt.Sprintf("Field \"%s\" not found in struct", param)) + t := field.Interface().(time.Time) + fieldTime := topField.Interface().(time.Time) + + return fieldTime.Equal(t) } } - if current.Kind() == reflect.Ptr && !current.IsNil() { - current = current.Elem() + // default reflect.String: + return topField.String() == current.String() +} + +func isEqField(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { + + // if !currentStruct.IsValid() { + // panic("struct or field value not passed for cross validation") + // } + + currentField, currentKind, ok := v.getStructFieldOK(currentStruct, param) + if !ok || currentKind != fieldKind { + return false } switch fieldKind { - case reflect.String: - return field.String() == current.String() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return field.Int() == current.Int() + return field.Int() == currentField.Int() case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return field.Uint() == current.Uint() + return field.Uint() == currentField.Uint() case reflect.Float32, reflect.Float64: - return field.Float() == current.Float() + return field.Float() == currentField.Float() case reflect.Slice, reflect.Map, reflect.Array: - return int64(field.Len()) == int64(current.Len()) + return int64(field.Len()) == int64(currentField.Len()) case reflect.Struct: - if fieldType == timeType || fieldType == timePtrType { - if current.Type() != timeType && current.Type() != timePtrType { - panic("Bad Top Level field type") - } + // Not Same underlying type i.e. struct and time + if fieldType != currentField.Type() { + return false + } - t := current.Interface().(time.Time) + if fieldType == timeType { + + t := currentField.Interface().(time.Time) fieldTime := field.Interface().(time.Time) return fieldTime.Equal(t) } + } - panic(fmt.Sprintf("Bad field type %T", field.Interface())) + // default reflect.String: + return field.String() == currentField.String() } func isEq(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { diff --git a/util.go b/util.go index 1494690..dc612d3 100644 --- a/util.go +++ b/util.go @@ -51,12 +51,28 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re current, kind := v.extractType(current) + // fmt.Println("SOK:", current, kind, namespace) + + // if len(namespace) == 0 { + // // if kind == reflect.Invalid { + // // return current, kind, false + // // } + // return current, kind, true + // } + + if kind == reflect.Invalid { + return current, kind, false + } + + if len(namespace) == 0 { + return current, kind, true + } + switch kind { - case reflect.Ptr, reflect.Interface, reflect.Invalid: + // case reflect.Invalid: + // return current, kind, false - if len(namespace) == 0 { - return current, kind, true - } + case reflect.Ptr, reflect.Interface: return current, kind, false @@ -64,6 +80,7 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re typ := current.Type() fld := namespace + ns := namespace if typ != timeType && typ != timePtrType { @@ -71,23 +88,34 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re if idx != -1 { fld = namespace[:idx] + ns = namespace[idx+1:] + } else { + ns = "" + idx = len(namespace) } - ns := namespace[idx+1:] + // ns := namespace[idx+1:] bracketIdx := strings.Index(fld, leftBracket) if bracketIdx != -1 { fld = fld[:bracketIdx] - if idx == -1 { - ns = namespace[bracketIdx:] - } else { - ns = namespace[bracketIdx:] - } + ns = namespace[bracketIdx:] + // if idx == -1 { + // ns = namespace[bracketIdx:] + // } else { + // ns = namespace[bracketIdx:] + // } } current = current.FieldByName(fld) + // if current.Kind() == reflect.Invalid { + // return current, reflect.Invalid, false + // } + + // fmt.Println("NS:", ns, idx) + return v.getStructFieldOK(current, ns) } @@ -126,5 +154,7 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(namespace[idx:idx2])), namespace[endIdx+1:]) } - return current, kind, true + // if got here there was more namespace, cannot go any deeper + panic("Invalid field namespace") + // return current, kind, false } diff --git a/validator.go b/validator.go index 13ea451..e3866f4 100644 --- a/validator.go +++ b/validator.go @@ -311,7 +311,7 @@ func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect. case reflect.Struct: typ = current.Type() - if typ != timeType && typ != timePtrType { + if typ != timeType { // required passed validation above so stop here // if only validating the structs existance. diff --git a/validator_test.go b/validator_test.go index 2c9302b..defdb22 100644 --- a/validator_test.go +++ b/validator_test.go @@ -192,6 +192,110 @@ func ValidateValuerType(field reflect.Value) interface{} { return nil } +func TestCrossStructEqFieldValidation(t *testing.T) { + + type Inner struct { + CreatedAt *time.Time + } + + type Test struct { + Inner *Inner + CreatedAt *time.Time `validate:"eqcsfield=Inner.CreatedAt"` + } + + now := time.Now().UTC() + + inner := &Inner{ + CreatedAt: &now, + } + + test := &Test{ + Inner: inner, + CreatedAt: &now, + } + + errs := validate.Struct(test) + Equal(t, errs, nil) + + newTime := time.Now().UTC() + test.CreatedAt = &newTime + + errs = validate.Struct(test) + NotEqual(t, errs, nil) + AssertError(t, errs, "Test.CreatedAt", "CreatedAt", "eqcsfield") + + var j uint64 + var k float64 + s := "abcd" + i := 1 + j = 1 + k = 1.543 + arr := []string{"test"} + + var j2 uint64 + var k2 float64 + s2 := "abcd" + i2 := 1 + j2 = 1 + k2 = 1.543 + arr2 := []string{"test"} + arr3 := []string{"test", "test2"} + now2 := now + + errs = validate.FieldWithValue(s, s2, "eqcsfield") + Equal(t, errs, nil) + + errs = validate.FieldWithValue(i2, i, "eqcsfield") + Equal(t, errs, nil) + + errs = validate.FieldWithValue(j2, j, "eqcsfield") + Equal(t, errs, nil) + + errs = validate.FieldWithValue(k2, k, "eqcsfield") + Equal(t, errs, nil) + + errs = validate.FieldWithValue(arr2, arr, "eqcsfield") + Equal(t, errs, nil) + + errs = validate.FieldWithValue(now2, now, "eqcsfield") + Equal(t, errs, nil) + + errs = validate.FieldWithValue(arr3, arr, "eqcsfield") + NotEqual(t, errs, nil) + AssertError(t, errs, "", "", "eqcsfield") + + type SInner struct { + Name string + } + + type TStruct struct { + Inner *SInner + CreatedAt *time.Time `validate:"eqcsfield=Inner"` + } + + sinner := &SInner{ + Name: "NAME", + } + + test2 := &TStruct{ + Inner: sinner, + CreatedAt: &now, + } + + errs = validate.Struct(test2) + NotEqual(t, errs, nil) + AssertError(t, errs, "TStruct.CreatedAt", "CreatedAt", "eqcsfield") + + test2.Inner = nil + errs = validate.Struct(test2) + NotEqual(t, errs, nil) + AssertError(t, errs, "TStruct.CreatedAt", "CreatedAt", "eqcsfield") + + errs = validate.FieldWithValue(nil, 1, "eqcsfield") + NotEqual(t, errs, nil) + AssertError(t, errs, "", "", "eqcsfield") +} + func TestCrossNamespaceFieldValidation(t *testing.T) { type SliceStruct struct { @@ -328,6 +432,14 @@ func TestCrossNamespaceFieldValidation(t *testing.T) { Equal(t, kind, reflect.Ptr) Equal(t, current.String(), "<*validator.SliceStruct Value>") Equal(t, current.IsNil(), true) + + current, kind, ok = validate.getStructFieldOK(val, "Inner.SliceStructs[2].Name") + Equal(t, ok, false) + Equal(t, kind, reflect.Ptr) + Equal(t, current.String(), "<*validator.SliceStruct Value>") + Equal(t, current.IsNil(), true) + + PanicMatches(t, func() { validate.getStructFieldOK(reflect.ValueOf(1), "crazyinput") }, "Invalid field namespace") } func TestExistsValidation(t *testing.T) { @@ -2063,9 +2175,11 @@ func TestIsNeFieldValidation(t *testing.T) { errs = validate.Struct(sv) Equal(t, errs, nil) + errs = validate.FieldWithValue(nil, 1, "nefield") + Equal(t, errs, nil) + channel := make(chan string) - PanicMatches(t, func() { validate.FieldWithValue(nil, 1, "nefield") }, "struct or field value not passed for cross validation") PanicMatches(t, func() { validate.FieldWithValue(5, channel, "nefield") }, "Bad field type chan string") PanicMatches(t, func() { validate.FieldWithValue(5, now, "nefield") }, "Bad Top Level field type") @@ -2182,9 +2296,12 @@ func TestIsEqFieldValidation(t *testing.T) { NotEqual(t, errs, nil) AssertError(t, errs, "Test.Start", "Start", "eqfield") + errs = validate.FieldWithValue(nil, 1, "eqfield") + NotEqual(t, errs, nil) + AssertError(t, errs, "", "", "eqfield") + channel := make(chan string) - PanicMatches(t, func() { validate.FieldWithValue(nil, 1, "eqfield") }, "struct or field value not passed for cross validation") PanicMatches(t, func() { validate.FieldWithValue(5, channel, "eqfield") }, "Bad field type chan string") PanicMatches(t, func() { validate.FieldWithValue(5, now, "eqfield") }, "Bad Top Level field type") @@ -2199,6 +2316,28 @@ func TestIsEqFieldValidation(t *testing.T) { } PanicMatches(t, func() { validate.Struct(sv2) }, "Field \"NonExistantField\" not found in struct") + + type Inner struct { + Name string + } + + type TStruct struct { + Inner *Inner + CreatedAt *time.Time `validate:"eqfield=Inner"` + } + + inner := &Inner{ + Name: "NAME", + } + + test := &TStruct{ + Inner: inner, + CreatedAt: &now, + } + + errs = validate.Struct(test) + NotEqual(t, errs, nil) + AssertError(t, errs, "TStruct.CreatedAt", "CreatedAt", "eqfield") } func TestIsEqValidation(t *testing.T) {