diff --git a/validator_instance.go b/validator_instance.go index 08638ae..10b9da2 100644 --- a/validator_instance.go +++ b/validator_instance.go @@ -189,26 +189,36 @@ func (v *Validate) RegisterAlias(alias, tags string) { // RegisterStructValidation registers a StructLevelFunc against a number of types. // +// It returns error when type being passed is a pointer. // NOTE: // - this method is not thread-safe it is intended that these all be registered prior to any validation -func (v *Validate) RegisterStructValidation(fn StructLevelFunc, types ...interface{}) { - v.RegisterStructValidationCtx(wrapStructLevelFunc(fn), types...) +func (v *Validate) RegisterStructValidation(fn StructLevelFunc, types ...interface{}) error { + return v.RegisterStructValidationCtx(wrapStructLevelFunc(fn), types...) } // RegisterStructValidationCtx registers a StructLevelFuncCtx against a number of types and allows passing // of contextual validation information via context.Context. // +// It returns error when type being passed is a pointer. // NOTE: // - this method is not thread-safe it is intended that these all be registered prior to any validation -func (v *Validate) RegisterStructValidationCtx(fn StructLevelFuncCtx, types ...interface{}) { +func (v *Validate) RegisterStructValidationCtx(fn StructLevelFuncCtx, types ...interface{}) error { if v.structLevelFuncs == nil { v.structLevelFuncs = make(map[reflect.Type]StructLevelFuncCtx) } + for _, t := range types { + if reflect.ValueOf(t).Kind() == reflect.Ptr { + return fmt.Errorf("error") + } + } + for _, t := range types { v.structLevelFuncs[reflect.TypeOf(t)] = fn } + + return nil } // RegisterCustomTypeFunc registers a CustomTypeFunc against a number of types diff --git a/validator_test.go b/validator_test.go index 2efabc5..3d6692c 100644 --- a/validator_test.go +++ b/validator_test.go @@ -8149,3 +8149,13 @@ func TestKeyOrs(t *testing.T) { AssertDeepError(t, errs, "Test2.Test1[badtestkey]", "Test2.Test1[badtestkey]", "Test1[badtestkey]", "Test1[badtestkey]", "okkey", "eq=testkey|eq=testkeyok") AssertDeepError(t, errs, "Test2.Test1[badtestkey]", "Test2.Test1[badtestkey]", "Test1[badtestkey]", "Test1[badtestkey]", "eq", "eq") } + +func TestStructLevelValidationsPointerPassing(t *testing.T) { + v1 := New() + err1 := v1.RegisterStructValidation(StructValidationTestStruct, &TestStruct{}) + NotEqual(t, err1, nil) + + v2 := New() + err2 := v2.RegisterStructValidation(StructValidationTestStruct, TestStruct{}) + Equal(t, err2, nil) +}