修复container/group在相同的key并发get时,可能会初始化多次的bug (#606)

* 修复container/group在相同的key并发get时,可能会初始化多次的bug

* update unit test case

Co-authored-by: demons <lu.xu@zenjoy.net>
pull/651/head
Demons 4 years ago committed by GitHub
parent 4160e34fb6
commit e502a9f491
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 29
      pkg/container/group/group.go
  2. 16
      pkg/container/group/group_test.go

@ -8,7 +8,7 @@ import "sync"
// Group is a lazy load container. // Group is a lazy load container.
type Group struct { type Group struct {
new func() interface{} new func() interface{}
objs sync.Map objs map[string]interface{}
sync.RWMutex sync.RWMutex
} }
@ -19,19 +19,29 @@ func NewGroup(new func() interface{}) *Group {
} }
return &Group{ return &Group{
new: new, new: new,
objs: make(map[string]interface{}),
} }
} }
// Get gets the object by the given key. // Get gets the object by the given key.
func (g *Group) Get(key string) interface{} { func (g *Group) Get(key string) interface{} {
g.RLock() g.RLock()
new := g.new obj, ok := g.objs[key]
if ok {
g.RUnlock() g.RUnlock()
obj, ok := g.objs.Load(key) return obj
if !ok { }
obj = new() g.RUnlock()
g.objs.Store(key, obj)
// double check
g.Lock()
defer g.Unlock()
obj, ok = g.objs[key]
if ok {
return obj
} }
obj = g.new()
g.objs[key] = obj
return obj return obj
} }
@ -48,8 +58,7 @@ func (g *Group) Reset(new func() interface{}) {
// Clear deletes all objects. // Clear deletes all objects.
func (g *Group) Clear() { func (g *Group) Clear() {
g.objs.Range(func(key, value interface{}) bool { g.Lock()
g.objs.Delete(key) g.objs = make(map[string]interface{})
return true g.Unlock()
})
} }

@ -36,10 +36,10 @@ func TestGroupReset(t *testing.T) {
}) })
length := 0 length := 0
g.objs.Range(func(_, _ interface{}) bool { for _,_ = range g.objs {
length++ length++
return true }
})
assert.Equal(t, 0, length) assert.Equal(t, 0, length)
g.Get("/x/internal/dummy/user") g.Get("/x/internal/dummy/user")
@ -52,18 +52,16 @@ func TestGroupClear(t *testing.T) {
}) })
g.Get("/x/internal/dummy/user") g.Get("/x/internal/dummy/user")
length := 0 length := 0
g.objs.Range(func(_, _ interface{}) bool { for _,_ = range g.objs {
length++ length++
return true }
})
assert.Equal(t, 1, length) assert.Equal(t, 1, length)
g.Clear() g.Clear()
length = 0 length = 0
g.objs.Range(func(_, _ interface{}) bool { for _,_ = range g.objs {
length++ length++
return true }
})
assert.Equal(t, 0, length) assert.Equal(t, 0, length)
} }

Loading…
Cancel
Save