diff --git a/baked_in.go b/baked_in.go index 0c6ca4f..2ab23e4 100644 --- a/baked_in.go +++ b/baked_in.go @@ -669,126 +669,86 @@ func hasMinOf(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, return isGte(v, topStruct, currentStruct, field, fieldType, fieldKind, param) } -func isLteField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { +func isLteField(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - if !current.IsValid() { - panic("struct not passed for cross validation") - } - - if current.Kind() == reflect.Ptr && !current.IsNil() { - current = current.Elem() - } - - switch current.Kind() { - - case reflect.Struct: - - if current.Type() == timeType || current.Type() == timePtrType { - break - } - - current = current.FieldByName(param) - - if current.Kind() == reflect.Invalid { - panic(fmt.Sprintf("Field \"%s\" not found in struct", param)) - } - } - - if current.Kind() == reflect.Ptr && !current.IsNil() { - current = current.Elem() + currentField, currentKind, ok := v.getStructFieldOK(currentStruct, param) + if !ok || currentKind != fieldKind { + return false } switch fieldKind { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return field.Int() <= current.Int() + return field.Int() <= currentField.Int() case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return field.Uint() <= current.Uint() + return field.Uint() <= currentField.Uint() case reflect.Float32, reflect.Float64: - return field.Float() <= current.Float() + return field.Float() <= currentField.Float() case reflect.Struct: - if field.Type() == timeType || field.Type() == timePtrType { + // Not Same underlying type i.e. struct and time + if fieldType != currentField.Type() { + return false + } - if current.Type() != timeType && current.Type() != timePtrType { - panic("Bad Top Level field type") - } + if fieldType == timeType { - t := current.Interface().(time.Time) + t := currentField.Interface().(time.Time) fieldTime := field.Interface().(time.Time) return fieldTime.Before(t) || fieldTime.Equal(t) } } - panic(fmt.Sprintf("Bad field type %T", field.Interface())) + // default reflect.String + return len(field.String()) <= len(currentField.String()) } -func isLtField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { +func isLtField(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - if !current.IsValid() { - panic("struct not passed for cross validation") - } - - if current.Kind() == reflect.Ptr && !current.IsNil() { - current = current.Elem() - } - - switch current.Kind() { - - case reflect.Struct: - - if current.Type() == timeType || current.Type() == timePtrType { - break - } - - current = current.FieldByName(param) - - if current.Kind() == reflect.Invalid { - panic(fmt.Sprintf("Field \"%s\" not found in struct", param)) - } - } - - if current.Kind() == reflect.Ptr && !current.IsNil() { - current = current.Elem() + currentField, currentKind, ok := v.getStructFieldOK(currentStruct, param) + if !ok || currentKind != fieldKind { + return false } switch fieldKind { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return field.Int() < current.Int() + return field.Int() < currentField.Int() case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return field.Uint() < current.Uint() + return field.Uint() < currentField.Uint() case reflect.Float32, reflect.Float64: - return field.Float() < current.Float() + return field.Float() < currentField.Float() case reflect.Struct: - if field.Type() == timeType || field.Type() == timePtrType { + // Not Same underlying type i.e. struct and time + if fieldType != currentField.Type() { + return false + } - if current.Type() != timeType && current.Type() != timePtrType { - panic("Bad Top Level field type") - } + if fieldType == timeType { - t := current.Interface().(time.Time) + t := currentField.Interface().(time.Time) fieldTime := field.Interface().(time.Time) return fieldTime.Before(t) } } - panic(fmt.Sprintf("Bad field type %T", field.Interface())) + // default reflect.String + return len(field.String()) < len(currentField.String()) } func isLte(v *Validate, topStruct reflect.Value, currentStruct reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { diff --git a/validator_test.go b/validator_test.go index e9a9fb6..93cdb55 100644 --- a/validator_test.go +++ b/validator_test.go @@ -2604,6 +2604,13 @@ func TestLtField(t *testing.T) { NotEqual(t, errs, nil) AssertError(t, errs, "", "", "ltfield") + errs = validate.FieldWithValue(timeTest, &end, "ltfield") + NotEqual(t, errs, nil) + AssertError(t, errs, "", "", "ltfield") + + errs = validate.FieldWithValue("test", "tes", "ltfield") + Equal(t, errs, nil) + type IntTest struct { Val1 int `validate:"required"` Val2 int `validate:"required,ltfield=Val1"` @@ -2691,7 +2698,10 @@ func TestLtField(t *testing.T) { NotEqual(t, errs, nil) AssertError(t, errs, "", "", "ltfield") - PanicMatches(t, func() { validate.FieldWithValue(nil, 5, "ltfield") }, "struct not passed for cross validation") + errs = validate.FieldWithValue(nil, 5, "ltfield") + NotEqual(t, errs, nil) + AssertError(t, errs, "", "", "ltfield") + PanicMatches(t, func() { validate.FieldWithValue(1, "T", "ltfield") }, "Bad field type string") PanicMatches(t, func() { validate.FieldWithValue(1, end, "ltfield") }, "Bad Top Level field type") @@ -2743,6 +2753,16 @@ func TestLteField(t *testing.T) { NotEqual(t, errs, nil) AssertError(t, errs, "", "", "ltefield") + errs = validate.FieldWithValue(timeTest, &end, "ltefield") + NotEqual(t, errs, nil) + AssertError(t, errs, "", "", "ltefield") + + errs = validate.FieldWithValue("test", "tes", "ltefield") + Equal(t, errs, nil) + + errs = validate.FieldWithValue("test", "test", "ltefield") + Equal(t, errs, nil) + type IntTest struct { Val1 int `validate:"required"` Val2 int `validate:"required,ltefield=Val1"` @@ -2830,7 +2850,10 @@ func TestLteField(t *testing.T) { NotEqual(t, errs, nil) AssertError(t, errs, "", "", "ltefield") - PanicMatches(t, func() { validate.FieldWithValue(nil, 5, "ltefield") }, "struct not passed for cross validation") + errs = validate.FieldWithValue(nil, 5, "ltefield") + NotEqual(t, errs, nil) + AssertError(t, errs, "", "", "ltefield") + PanicMatches(t, func() { validate.FieldWithValue(1, "T", "ltefield") }, "Bad field type string") PanicMatches(t, func() { validate.FieldWithValue(1, end, "ltefield") }, "Bad Top Level field type") @@ -2979,7 +3002,10 @@ func TestGteField(t *testing.T) { NotEqual(t, errs, nil) AssertError(t, errs, "", "", "gtefield") - PanicMatches(t, func() { validate.FieldWithValue(nil, 1, "gtefield") }, "struct not passed for cross validation") + errs = validate.FieldWithValue(nil, 1, "gtefield") + NotEqual(t, errs, nil) + AssertError(t, errs, "", "", "gtefield") + PanicMatches(t, func() { validate.FieldWithValue(5, "T", "gtefield") }, "Bad field type string") PanicMatches(t, func() { validate.FieldWithValue(5, start, "gtefield") }, "Bad Top Level field type")