diff --git a/baked_in.go b/baked_in.go index 59746e3..ee6e4a9 100644 --- a/baked_in.go +++ b/baked_in.go @@ -260,7 +260,7 @@ func contains(v *Validate, topStruct reflect.Value, currentStructOrField reflect func isNeField(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - currentField, currentKind, ok := v.getStructFieldOK(currentStructOrField, param) + currentField, currentKind, ok := v.GetStructFieldOK(currentStructOrField, param) if !ok || currentKind != fieldKind { return true @@ -307,7 +307,7 @@ func isNe(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Val func isLteCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - topField, topKind, ok := v.getStructFieldOK(topStruct, param) + topField, topKind, ok := v.GetStructFieldOK(topStruct, param) if !ok || topKind != fieldKind { return false } @@ -348,7 +348,7 @@ func isLteCrossStructField(v *Validate, topStruct reflect.Value, current reflect func isLtCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - topField, topKind, ok := v.getStructFieldOK(topStruct, param) + topField, topKind, ok := v.GetStructFieldOK(topStruct, param) if !ok || topKind != fieldKind { return false } @@ -389,7 +389,7 @@ func isLtCrossStructField(v *Validate, topStruct reflect.Value, current reflect. func isGteCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - topField, topKind, ok := v.getStructFieldOK(topStruct, param) + topField, topKind, ok := v.GetStructFieldOK(topStruct, param) if !ok || topKind != fieldKind { return false } @@ -430,7 +430,7 @@ func isGteCrossStructField(v *Validate, topStruct reflect.Value, current reflect func isGtCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - topField, topKind, ok := v.getStructFieldOK(topStruct, param) + topField, topKind, ok := v.GetStructFieldOK(topStruct, param) if !ok || topKind != fieldKind { return false } @@ -471,7 +471,7 @@ func isGtCrossStructField(v *Validate, topStruct reflect.Value, current reflect. func isNeCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - topField, currentKind, ok := v.getStructFieldOK(topStruct, param) + topField, currentKind, ok := v.GetStructFieldOK(topStruct, param) if !ok || currentKind != fieldKind { return true } @@ -512,7 +512,7 @@ func isNeCrossStructField(v *Validate, topStruct reflect.Value, current reflect. func isEqCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - topField, topKind, ok := v.getStructFieldOK(topStruct, param) + topField, topKind, ok := v.GetStructFieldOK(topStruct, param) if !ok || topKind != fieldKind { return false } @@ -553,7 +553,7 @@ func isEqCrossStructField(v *Validate, topStruct reflect.Value, current reflect. func isEqField(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - currentField, currentKind, ok := v.getStructFieldOK(currentStructOrField, param) + currentField, currentKind, ok := v.GetStructFieldOK(currentStructOrField, param) if !ok || currentKind != fieldKind { return false } @@ -718,7 +718,7 @@ func hasValue(v *Validate, topStruct reflect.Value, currentStructOrField reflect func isGteField(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - currentField, currentKind, ok := v.getStructFieldOK(currentStructOrField, param) + currentField, currentKind, ok := v.GetStructFieldOK(currentStructOrField, param) if !ok || currentKind != fieldKind { return false } @@ -759,7 +759,7 @@ func isGteField(v *Validate, topStruct reflect.Value, currentStructOrField refle func isGtField(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - currentField, currentKind, ok := v.getStructFieldOK(currentStructOrField, param) + currentField, currentKind, ok := v.GetStructFieldOK(currentStructOrField, param) if !ok || currentKind != fieldKind { return false } @@ -927,7 +927,7 @@ func hasMinOf(v *Validate, topStruct reflect.Value, currentStructOrField reflect func isLteField(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - currentField, currentKind, ok := v.getStructFieldOK(currentStructOrField, param) + currentField, currentKind, ok := v.GetStructFieldOK(currentStructOrField, param) if !ok || currentKind != fieldKind { return false } @@ -968,7 +968,7 @@ func isLteField(v *Validate, topStruct reflect.Value, currentStructOrField refle func isLtField(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool { - currentField, currentKind, ok := v.getStructFieldOK(currentStructOrField, param) + currentField, currentKind, ok := v.GetStructFieldOK(currentStructOrField, param) if !ok || currentKind != fieldKind { return false } diff --git a/util.go b/util.go index 404bb78..96ab20e 100644 --- a/util.go +++ b/util.go @@ -29,7 +29,11 @@ var ( } ) -func (v *Validate) extractType(current reflect.Value) (reflect.Value, reflect.Kind) { +// ExtractType gets the actual underlying type of field value. +// It will dive into pointers, customTypes and return you the +// underlying value and it's kind. +// it is exposed for use within you Custom Functions +func (v *Validate) ExtractType(current reflect.Value) (reflect.Value, reflect.Kind) { switch current.Kind() { case reflect.Ptr: @@ -38,7 +42,7 @@ func (v *Validate) extractType(current reflect.Value) (reflect.Value, reflect.Ki return current, reflect.Ptr } - return v.extractType(current.Elem()) + return v.ExtractType(current.Elem()) case reflect.Interface: @@ -46,7 +50,7 @@ func (v *Validate) extractType(current reflect.Value) (reflect.Value, reflect.Ki return current, reflect.Interface } - return v.extractType(current.Elem()) + return v.ExtractType(current.Elem()) case reflect.Invalid: return current, reflect.Invalid @@ -55,7 +59,7 @@ func (v *Validate) extractType(current reflect.Value) (reflect.Value, reflect.Ki if v.hasCustomFuncs { if fn, ok := v.customTypeFuncs[current.Type()]; ok { - return v.extractType(reflect.ValueOf(fn(current))) + return v.ExtractType(reflect.ValueOf(fn(current))) } } @@ -63,9 +67,13 @@ func (v *Validate) extractType(current reflect.Value) (reflect.Value, reflect.Ki } } -func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool) { +// GetStructFieldOK traverses a struct to retrieve a specific field denoted by the provided namespace and +// returns the field, field kind and whether is was successful in retrieving the field at all. +// NOTE: when not successful ok will be false, this can happen when a nested struct is nil and so the field +// could not be retrived because it didnt exist. +func (v *Validate) GetStructFieldOK(current reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool) { - current, kind := v.extractType(current) + current, kind := v.ExtractType(current) if kind == reflect.Invalid { return current, kind, false @@ -108,7 +116,7 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re current = current.FieldByName(fld) - return v.getStructFieldOK(current, ns) + return v.GetStructFieldOK(current, ns) } case reflect.Array, reflect.Slice: @@ -129,7 +137,7 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re } } - return v.getStructFieldOK(current.Index(arrIdx), namespace[startIdx:]) + return v.GetStructFieldOK(current.Index(arrIdx), namespace[startIdx:]) case reflect.Map: idx := strings.Index(namespace, leftBracket) + 1 @@ -148,47 +156,47 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re switch current.Type().Key().Kind() { case reflect.Int: i, _ := strconv.Atoi(key) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:]) case reflect.Int8: i, _ := strconv.ParseInt(key, 10, 8) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(int8(i))), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(int8(i))), namespace[endIdx+1:]) case reflect.Int16: i, _ := strconv.ParseInt(key, 10, 16) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(int16(i))), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(int16(i))), namespace[endIdx+1:]) case reflect.Int32: i, _ := strconv.ParseInt(key, 10, 32) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(int32(i))), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(int32(i))), namespace[endIdx+1:]) case reflect.Int64: i, _ := strconv.ParseInt(key, 10, 64) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:]) case reflect.Uint: i, _ := strconv.ParseUint(key, 10, 0) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(uint(i))), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(uint(i))), namespace[endIdx+1:]) case reflect.Uint8: i, _ := strconv.ParseUint(key, 10, 8) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(uint8(i))), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(uint8(i))), namespace[endIdx+1:]) case reflect.Uint16: i, _ := strconv.ParseUint(key, 10, 16) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(uint16(i))), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(uint16(i))), namespace[endIdx+1:]) case reflect.Uint32: i, _ := strconv.ParseUint(key, 10, 32) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(uint32(i))), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(uint32(i))), namespace[endIdx+1:]) case reflect.Uint64: i, _ := strconv.ParseUint(key, 10, 64) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:]) case reflect.Float32: f, _ := strconv.ParseFloat(key, 32) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(float32(f))), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(float32(f))), namespace[endIdx+1:]) case reflect.Float64: f, _ := strconv.ParseFloat(key, 64) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(f)), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(f)), namespace[endIdx+1:]) case reflect.Bool: b, _ := strconv.ParseBool(key) - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(b)), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(b)), namespace[endIdx+1:]) // reflect.Type = string default: - return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(key)), namespace[endIdx+1:]) + return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(key)), namespace[endIdx+1:]) } } diff --git a/validator.go b/validator.go index c36371d..1703ddc 100644 --- a/validator.go +++ b/validator.go @@ -282,7 +282,7 @@ func (v *Validate) FieldWithValue(val interface{}, field interface{}, tag string func (v *Validate) StructPartial(current interface{}, fields ...string) error { v.initCheck() - sv, _ := v.extractType(reflect.ValueOf(current)) + sv, _ := v.ExtractType(reflect.ValueOf(current)) name := sv.Type().Name() m := map[string]*struct{}{} @@ -340,7 +340,7 @@ func (v *Validate) StructPartial(current interface{}, fields ...string) error { func (v *Validate) StructExcept(current interface{}, fields ...string) error { v.initCheck() - sv, _ := v.extractType(reflect.ValueOf(current)) + sv, _ := v.ExtractType(reflect.ValueOf(current)) name := sv.Type().Name() m := map[string]*struct{}{} @@ -435,7 +435,7 @@ func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect. v.tagsCache.Set(tag, cTag) } - current, kind := v.extractType(current) + current, kind := v.ExtractType(current) var typ reflect.Type switch kind { diff --git a/validator_test.go b/validator_test.go index 63b79c1..b32b0cd 100644 --- a/validator_test.go +++ b/validator_test.go @@ -1088,125 +1088,125 @@ func TestCrossNamespaceFieldValidation(t *testing.T) { val := reflect.ValueOf(test) - current, kind, ok := validate.getStructFieldOK(val, "Inner.CreatedAt") + current, kind, ok := validate.GetStructFieldOK(val, "Inner.CreatedAt") Equal(t, ok, true) Equal(t, kind, reflect.Struct) tm, ok := current.Interface().(time.Time) Equal(t, ok, true) Equal(t, tm, now) - current, kind, ok = validate.getStructFieldOK(val, "Inner.Slice[1]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.Slice[1]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.CrazyNonExistantField") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.CrazyNonExistantField") Equal(t, ok, false) - current, kind, ok = validate.getStructFieldOK(val, "Inner.Slice[101]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.Slice[101]") Equal(t, ok, false) - current, kind, ok = validate.getStructFieldOK(val, "Inner.Map[key3]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.Map[key3]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val3") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapMap[key2][key2-1]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapMap[key2][key2-1]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapStructs[key2].Name") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapStructs[key2].Name") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "name2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapMapStruct[key3][key3-1].Name") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapMapStruct[key3][key3-1].Name") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "name3") - current, kind, ok = validate.getStructFieldOK(val, "Inner.SliceSlice[2][0]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.SliceSlice[2][0]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "7") - current, kind, ok = validate.getStructFieldOK(val, "Inner.SliceSliceStruct[2][1].Name") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.SliceSliceStruct[2][1].Name") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "name8") - current, kind, ok = validate.getStructFieldOK(val, "Inner.SliceMap[1][key5]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.SliceMap[1][key5]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val5") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapSlice[key3][2]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapSlice[key3][2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "9") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapInt[2]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapInt[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapInt8[2]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapInt8[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapInt16[2]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapInt16[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapInt32[2]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapInt32[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapInt64[2]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapInt64[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapUint[2]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapUint[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapUint8[2]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapUint8[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapUint16[2]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapUint16[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapUint32[2]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapUint32[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapUint64[2]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapUint64[2]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapFloat32[3.03]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapFloat32[3.03]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val3") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapFloat64[2.02]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapFloat64[2.02]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val2") - current, kind, ok = validate.getStructFieldOK(val, "Inner.MapBool[true]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.MapBool[true]") Equal(t, ok, true) Equal(t, kind, reflect.String) Equal(t, current.String(), "val1") @@ -1232,19 +1232,19 @@ func TestCrossNamespaceFieldValidation(t *testing.T) { val = reflect.ValueOf(test) - current, kind, ok = validate.getStructFieldOK(val, "Inner.SliceStructs[2]") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.SliceStructs[2]") Equal(t, ok, true) Equal(t, kind, reflect.Ptr) Equal(t, current.String(), "<*validator.SliceStruct Value>") Equal(t, current.IsNil(), true) - current, kind, ok = validate.getStructFieldOK(val, "Inner.SliceStructs[2].Name") + current, kind, ok = validate.GetStructFieldOK(val, "Inner.SliceStructs[2].Name") Equal(t, ok, false) Equal(t, kind, reflect.Ptr) Equal(t, current.String(), "<*validator.SliceStruct Value>") Equal(t, current.IsNil(), true) - PanicMatches(t, func() { validate.getStructFieldOK(reflect.ValueOf(1), "crazyinput") }, "Invalid field namespace") + PanicMatches(t, func() { validate.GetStructFieldOK(reflect.ValueOf(1), "crazyinput") }, "Invalid field namespace") } func TestExistsValidation(t *testing.T) {