diff --git a/pkg/sync/pipeline/README.md b/pkg/sync/pipeline/README.md
new file mode 100644
index 000000000..b056c0901
--- /dev/null
+++ b/pkg/sync/pipeline/README.md

# pkg/sync/pipeline

Provides in-memory batch aggregation: values are sharded by key, buffered per worker, and flushed in batches either when a batch reaches `MaxSize` or when `Interval` elapses.
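Example (a minimal sketch; `flush` is a hypothetical batch handler standing in for whatever you plug in as `Do`):

```golang
p := pipeline.NewPipeline(&pipeline.Config{
    MaxSize:  100,                         // flush after 100 values...
    Interval: xtime.Duration(time.Second), // ...or at least once per second
    Buffer:   1000,
    Worker:   10,
})
// Do receives the merged values of one shard.
p.Do = func(c context.Context, index int, values map[string][]interface{}) {
    flush(c, values)
}
// Split maps a key to a shard; the result is taken modulo Worker.
p.Split = func(key string) int { return len(key) }
p.Start()
// Add is non-blocking; it returns ErrFull when the shard's buffer is full.
if err := p.Add(context.Background(), "some-key", 1); err != nil {
    // degrade, retry, or drop
}
p.Close()
```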

diff --git a/pkg/sync/pipeline/fanout/README.md b/pkg/sync/pipeline/fanout/README.md
new file mode 100644
index 000000000..cf88ddd8e
--- /dev/null
+++ b/pkg/sync/pipeline/fanout/README.md

# pkg/sync/pipeline/fanout

Features:

* Consumes tasks asynchronously with a configurable number of worker goroutines
* Propagates metadata (pkg/net/metadata) into the background tasks

Example:
```golang
// name "cache", 1 worker goroutine, channel buffer of 1024
cache := fanout.New("cache", fanout.Worker(1), fanout.Buffer(1024))
cache.Do(c, func(c context.Context) { SomeFunc(c, args...) })
cache.Close()
```

diff --git a/pkg/sync/pipeline/fanout/example_test.go b/pkg/sync/pipeline/fanout/example_test.go
new file mode 100644
index 000000000..5de973199
--- /dev/null
+++ b/pkg/sync/pipeline/fanout/example_test.go

package fanout

import "context"

// addCache is an example task that populates a cache.
func addCache(c context.Context, id, value int) {
    // write id/value to the cache...
}

func Example() {
    // This is only an example; real code should use the context passed in by bm/rpc.
    var c = context.Background()
    // Create a fanout named "cache"; the name is used for metrics and logs, so keep it unique.
    // (optional) Worker(1): a single background goroutine consumes tasks.
    // (optional) Buffer(1024): the channel buffers up to 1024 tasks; once it is full,
    // further Do calls return an error. The bound mainly guards against OOM.
    cache := New("cache", Worker(1), Buffer(1024))
    // The function to execute asynchronously.
    // The metadata in c is copied into a detached context and the deadline is dropped,
    // so the context addCache receives carries no timeout.
    cache.Do(c, func(c context.Context) { addCache(c, 0, 0) })
    // Close the fanout on shutdown; it returns after the background goroutines exit.
    cache.Close()
}
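
// Do is non-blocking: when the buffered channel is full it returns ErrFull
// instead of queueing. A minimal sketch of handling that case (the fallback
// strategy is up to the caller):
func ExampleFanout_Do() {
    f := New("demo", Worker(1), Buffer(1))
    if err := f.Do(context.Background(), func(c context.Context) {}); err == ErrFull {
        // Channel full: degrade to a synchronous call, retry later, or drop the task.
    }
    f.Close()
}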

diff --git a/pkg/sync/pipeline/fanout/fanout.go b/pkg/sync/pipeline/fanout/fanout.go
new file mode 100644
index 000000000..ad20b3e12
--- /dev/null
+++ b/pkg/sync/pipeline/fanout/fanout.go

package fanout

import (
    "context"
    "errors"
    "runtime"
    "sync"

    "github.com/bilibili/kratos/pkg/log"
    "github.com/bilibili/kratos/pkg/net/metadata"
    "github.com/bilibili/kratos/pkg/net/trace"
    "github.com/bilibili/kratos/pkg/stat/prom"
)

var (
    // ErrFull is returned by Do when the task channel is full.
    ErrFull   = errors.New("fanout: chan full")
    stats     = prom.BusinessInfoCount
    traceTags = []trace.Tag{
        {Key: trace.TagSpanKind, Value: "background"},
        {Key: trace.TagComponent, Value: "sync/pipeline/fanout"},
    }
)

type options struct {
    worker int
    buffer int
}

// Option is a fanout option.
type Option func(*options)

// Worker sets the number of worker goroutines; it panics if n <= 0.
func Worker(n int) Option {
    if n <= 0 {
        panic("fanout: worker should be > 0")
    }
    return func(o *options) {
        o.worker = n
    }
}

// Buffer sets the task channel buffer size; it panics if n <= 0.
func Buffer(n int) Option {
    if n <= 0 {
        panic("fanout: buffer should be > 0")
    }
    return func(o *options) {
        o.buffer = n
    }
}

type item struct {
    f   func(c context.Context)
    ctx context.Context
}

// Fanout asynchronously consumes tasks from a buffered channel.
type Fanout struct {
    name    string
    ch      chan item
    options *options
    waiter  sync.WaitGroup

    ctx    context.Context
    cancel func()
}

// New creates a Fanout with the given name (used for metrics and logs) and
// options; it defaults to 1 worker and a buffer of 1024.
func New(name string, opts ...Option) *Fanout {
    if name == "" {
        name = "fanout"
    }
    o := &options{
        worker: 1,
        buffer: 1024,
    }
    for _, op := range opts {
        op(o)
    }
    c := &Fanout{
        ch:      make(chan item, o.buffer),
        name:    name,
        options: o,
    }
    c.ctx, c.cancel = context.WithCancel(context.Background())
    c.waiter.Add(o.worker)
    for i := 0; i < o.worker; i++ {
        go c.proc()
    }
    return c
}

func (c *Fanout) proc() {
    defer c.waiter.Done()
    for {
        select {
        case t := <-c.ch:
            wrapFunc(t.f)(t.ctx)
            stats.State(c.name+"_channel", int64(len(c.ch)))
        case <-c.ctx.Done():
            return
        }
    }
}

// wrapFunc wraps f with panic recovery and finishes the forked trace span.
func wrapFunc(f func(c context.Context)) (res func(context.Context)) {
    res = func(ctx context.Context) {
        defer func() {
            if r := recover(); r != nil {
                buf := make([]byte, 64*1024)
                buf = buf[:runtime.Stack(buf, false)]
                log.Error("panic in fanout proc, err: %s, stack: %s", r, buf)
            }
        }()
        f(ctx)
        if tr, ok := trace.FromContext(ctx); ok {
            tr.Finish(nil)
        }
    }
    return
}

// Do queues a callback func for a background worker. The metadata of ctx is
// copied into a detached context (deadlines are dropped). It returns ErrFull
// when the channel is full and the cancellation error after Close; a nil f
// is ignored.
func (c *Fanout) Do(ctx context.Context, f func(ctx context.Context)) (err error) {
    if f == nil || c.ctx.Err() != nil {
        return c.ctx.Err()
    }
    nakedCtx := metadata.WithContext(ctx)
    if tr, ok := trace.FromContext(ctx); ok {
        tr = tr.Fork("", "Fanout:Do").SetTag(traceTags...)
        nakedCtx = trace.NewContext(nakedCtx, tr)
    }
    select {
    case c.ch <- item{f: f, ctx: nakedCtx}:
    default:
        err = ErrFull
    }
    stats.State(c.name+"_channel", int64(len(c.ch)))
    return
}

// Close cancels the workers and waits for them to exit; tasks still buffered
// in the channel may be dropped.
func (c *Fanout) Close() error {
    if err := c.ctx.Err(); err != nil {
        return err
    }
    c.cancel()
    c.waiter.Wait()
    return nil
}

diff --git a/pkg/sync/pipeline/fanout/fanout_test.go b/pkg/sync/pipeline/fanout/fanout_test.go
new file mode 100644
index 000000000..ef1df1169
--- /dev/null
+++ b/pkg/sync/pipeline/fanout/fanout_test.go

package fanout

import (
    "context"
    "testing"
    "time"
)

func TestFanout_Do(t *testing.T) {
    ca := New("cache", Worker(1), Buffer(1024))
    var run bool
    ca.Do(context.Background(), func(c context.Context) {
        run = true
        panic("error")
    })
    time.Sleep(time.Millisecond * 50)
    t.Log("did not panic")
    if !run {
        t.Fatal("expected run to be true")
    }
}

func TestFanout_Close(t *testing.T) {
    ca := New("cache", Worker(1), Buffer(1024))
    ca.Close()
    err := ca.Do(context.Background(), func(c context.Context) {})
    if err == nil {
        t.Fatal("expected an error after Close")
    }
}

diff --git a/pkg/sync/pipeline/pipeline.go b/pkg/sync/pipeline/pipeline.go
new file mode 100644
index 000000000..ef9263cb9
--- /dev/null
+++ b/pkg/sync/pipeline/pipeline.go

package pipeline

import (
    "context"
    "errors"
    "sync"
    "time"

    "github.com/bilibili/kratos/pkg/net/metadata"
    xtime "github.com/bilibili/kratos/pkg/time"
)

// ErrFull is returned by Add when the shard channel is full.
var ErrFull = errors.New("channel full")

type message struct {
    key   string
    value interface{}
}

// Pipeline aggregates values by key and flushes them in batches.
type Pipeline struct {
    Do          func(c context.Context, index int, values map[string][]interface{})
    Split       func(key string) int
    chans       []chan *message
    mirrorChans []chan *message
    config      *Config
    wait        sync.WaitGroup
}

// Config is the Pipeline config.
type Config struct {
    // MaxSize is the maximum number of values merged into one batch.
    MaxSize int
    // Interval is the maximum time between two flushes.
    Interval xtime.Duration
    // Buffer is the size of each shard channel.
    Buffer int
    // Worker is the number of shard channels and merge goroutines.
    Worker int
    // Smooth staggers the workers' first flush to spread the load.
    Smooth bool
}

func (c *Config) fix() {
    if c.MaxSize <= 0 {
        c.MaxSize = 1000
    }
    if c.Interval <= 0 {
        c.Interval = xtime.Duration(time.Second)
    }
    if c.Buffer <= 0 {
        c.Buffer = 1000
    }
    if c.Worker <= 0 {
        c.Worker = 10
    }
}

// NewPipeline creates a Pipeline; zero or negative config values fall back to defaults.
func NewPipeline(config *Config) (res *Pipeline) {
    if config == nil {
        config = &Config{}
    }
    config.fix()
    res = &Pipeline{
        chans:       make([]chan *message, config.Worker),
        mirrorChans: make([]chan *message, config.Worker),
        config:      config,
    }
    for i := 0; i < config.Worker; i++ {
        res.chans[i] = make(chan *message, config.Buffer)
        res.mirrorChans[i] = make(chan *message, config.Buffer)
    }
    return
}

// Start starts all merge goroutines; it panics if Do or Split is nil.
func (p *Pipeline) Start() {
    if p.Do == nil {
        panic("pipeline: do func is nil")
    }
    if p.Split == nil {
        panic("pipeline: split func is nil")
    }
    var mirror bool
    p.wait.Add(len(p.chans) + len(p.mirrorChans))
    for i, ch := range p.chans {
        go p.mergeproc(mirror, i, ch)
    }
    mirror = true
    for i, ch := range p.mirrorChans {
        go p.mergeproc(mirror, i, ch)
    }
}

// SyncAdd adds a value to the shard channel chosen by Split, blocking while
// the channel is full.
func (p *Pipeline) SyncAdd(c context.Context, key string, value interface{}) {
    ch, msg := p.add(c, key, value)
    ch <- msg
}

// Add adds a value to the shard channel chosen by Split; it returns ErrFull
// instead of blocking when the channel is full.
func (p *Pipeline) Add(c context.Context, key string, value interface{}) (err error) {
    ch, msg := p.add(c, key, value)
    select {
    case ch <- msg:
    default:
        err = ErrFull
    }
    return
}

func (p *Pipeline) add(c context.Context, key string, value interface{}) (ch chan *message, m *message) {
    shard := p.Split(key) % p.config.Worker
    // Mirrored requests (metadata.Mirror set) are routed to the mirror channels.
    if metadata.String(c, metadata.Mirror) != "" {
        ch = p.mirrorChans[shard]
    } else {
        ch = p.chans[shard]
    }
    m = &message{key: key, value: value}
    return
}

// Close flushes the remaining buffered values and waits for all merge
// goroutines to exit.
func (p *Pipeline) Close() (err error) {
    for _, ch := range p.chans {
        ch <- nil
    }
    for _, ch := range p.mirrorChans {
        ch <- nil
    }
    p.wait.Wait()
    return
}
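
// With Smooth enabled, each worker's first tick is staggered: worker index i
// initially ticks after i*(Interval/Worker) and then switches to the full
// Interval. For example (assuming Interval=1s and Worker=10), worker 0 keeps
// the 1s ticker while worker 1 first fires after 100ms, worker 2 after 200ms,
// and so on, so flushes spread across the interval instead of firing at once.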
func (p *Pipeline) mergeproc(mirror bool, index int, ch <-chan *message) {
    defer p.wait.Done()
    var (
        m         *message
        vals      = make(map[string][]interface{}, p.config.MaxSize)
        closed    bool
        count     int
        interval  = p.config.Interval
        oldTicker = true
    )
    if p.config.Smooth && index > 0 {
        interval = xtime.Duration(int64(index) * (int64(p.config.Interval) / int64(p.config.Worker)))
    }
    ticker := time.NewTicker(time.Duration(interval))
    for {
        select {
        case m = <-ch:
            if m == nil { // nil is the close sentinel sent by Close
                closed = true
                break
            }
            count++
            vals[m.key] = append(vals[m.key], m.value)
            if count >= p.config.MaxSize {
                break
            }
            continue
        case <-ticker.C:
            if p.config.Smooth && oldTicker {
                // After the staggered first tick, fall back to the full interval.
                ticker.Stop()
                ticker = time.NewTicker(time.Duration(p.config.Interval))
                oldTicker = false
            }
        }
        if len(vals) > 0 {
            ctx := context.Background()
            if mirror {
                ctx = metadata.NewContext(ctx, metadata.MD{metadata.Mirror: "1"})
            }
            p.Do(ctx, index, vals)
            vals = make(map[string][]interface{}, p.config.MaxSize)
            count = 0
        }
        if closed {
            ticker.Stop()
            return
        }
    }
}

diff --git a/pkg/sync/pipeline/pipeline_test.go b/pkg/sync/pipeline/pipeline_test.go
new file mode 100644
index 000000000..dcfdd8902
--- /dev/null
+++ b/pkg/sync/pipeline/pipeline_test.go

package pipeline

import (
    "context"
    "reflect"
    "strconv"
    "testing"
    "time"

    "github.com/bilibili/kratos/pkg/net/metadata"
    xtime "github.com/bilibili/kratos/pkg/time"
)

func TestPipeline(t *testing.T) {
    conf := &Config{
        MaxSize:  3,
        Interval: xtime.Duration(time.Millisecond * 20),
        Buffer:   3,
        Worker:   10,
    }
    type recv struct {
        mirror string
        ch     int
        values map[string][]interface{}
    }
    var runs []recv
    do := func(c context.Context, ch int, values map[string][]interface{}) {
        runs = append(runs, recv{
            mirror: metadata.String(c, metadata.Mirror),
            values: values,
            ch:     ch,
        })
    }
    split := func(s string) int {
        n, _ := strconv.Atoi(s)
        return n
    }
    p := NewPipeline(conf)
    p.Do = do
    p.Split = split
    p.Start()
    p.Add(context.Background(), "1", 1)
    p.Add(context.Background(), "1", 2)
    p.Add(context.Background(), "11", 3)
    p.Add(context.Background(), "2", 3)
    time.Sleep(time.Millisecond * 60)
    mirrorCtx := metadata.NewContext(context.Background(), metadata.MD{metadata.Mirror: "1"})
    p.Add(mirrorCtx, "2", 3)
    time.Sleep(time.Millisecond * 60)
    p.SyncAdd(mirrorCtx, "5", 5)
    time.Sleep(time.Millisecond * 60)
    p.Close()
    expt := []recv{
        {
            mirror: "",
            ch:     1,
            values: map[string][]interface{}{
                "1":  {1, 2},
                "11": {3},
            },
        },
        {
            mirror: "",
            ch:     2,
            values: map[string][]interface{}{
                "2": {3},
            },
        },
        {
            mirror: "1",
            ch:     2,
            values: map[string][]interface{}{
                "2": {3},
            },
        },
        {
            mirror: "1",
            ch:     5,
            values: map[string][]interface{}{
                "5": {5},
            },
        },
    }
    if !reflect.DeepEqual(runs, expt) {
        t.Errorf("expected %+v,\ngot: %+v", expt, runs)
    }
}

func TestPipelineSmooth(t *testing.T) {
    conf := &Config{
        MaxSize:  100,
        Interval: xtime.Duration(time.Second),
        Buffer:   100,
        Worker:   10,
        Smooth:   true,
    }
    type result struct {
        index int
        ts    time.Time
    }
    var results []result
    do := func(c context.Context, index int, values map[string][]interface{}) {
        results = append(results, result{
            index: index,
            ts:    time.Now(),
        })
    }
    split := func(s string) int {
        n, _ := strconv.Atoi(s)
        return n
    }
    p := NewPipeline(conf)
    p.Do = do
    p.Split = split
    p.Start()
    for i := 0; i < 10; i++ {
        p.Add(context.Background(), strconv.Itoa(i), 1)
    }
    time.Sleep(time.Millisecond * 1500)
    if len(results) != conf.Worker {
        t.Errorf("expected %d results, got %d", conf.Worker, len(results))
        t.FailNow()
    }
    for i, r := range results {
        if i > 0 {
            if r.ts.Sub(results[i-1].ts) < time.Millisecond*20 {
                t.Errorf("expected flushes to be smoothed")
                t.FailNow()
            }
        }
    }
}