186 lines
3.9 KiB
186 lines
3.9 KiB
6 years ago
|
package pipeline
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/bilibili/kratos/pkg/net/metadata"
|
||
|
xtime "github.com/bilibili/kratos/pkg/time"
|
||
|
)
|
||
|
|
||
|
// ErrFull channel full error
|
||
|
var ErrFull = errors.New("channel full")
|
||
|
|
||
|
type message struct {
|
||
|
key string
|
||
|
value interface{}
|
||
|
}
|
||
|
|
||
|
// Pipeline pipeline struct
|
||
|
type Pipeline struct {
|
||
|
Do func(c context.Context, index int, values map[string][]interface{})
|
||
|
Split func(key string) int
|
||
|
chans []chan *message
|
||
|
mirrorChans []chan *message
|
||
|
config *Config
|
||
|
wait sync.WaitGroup
|
||
|
}
|
||
|
|
||
|
// Config Pipeline config
|
||
|
type Config struct {
|
||
|
// MaxSize merge size
|
||
|
MaxSize int
|
||
|
// Interval merge interval
|
||
|
Interval xtime.Duration
|
||
|
// Buffer channel size
|
||
|
Buffer int
|
||
|
// Worker channel number
|
||
|
Worker int
|
||
|
// Smooth smoothing interval
|
||
|
Smooth bool
|
||
|
}
|
||
|
|
||
|
func (c *Config) fix() {
|
||
|
if c.MaxSize <= 0 {
|
||
|
c.MaxSize = 1000
|
||
|
}
|
||
|
if c.Interval <= 0 {
|
||
|
c.Interval = xtime.Duration(time.Second)
|
||
|
}
|
||
|
if c.Buffer <= 0 {
|
||
|
c.Buffer = 1000
|
||
|
}
|
||
|
if c.Worker <= 0 {
|
||
|
c.Worker = 10
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// NewPipeline new pipline
|
||
|
func NewPipeline(config *Config) (res *Pipeline) {
|
||
|
if config == nil {
|
||
|
config = &Config{}
|
||
|
}
|
||
|
config.fix()
|
||
|
res = &Pipeline{
|
||
|
chans: make([]chan *message, config.Worker),
|
||
|
mirrorChans: make([]chan *message, config.Worker),
|
||
|
config: config,
|
||
|
}
|
||
|
for i := 0; i < config.Worker; i++ {
|
||
|
res.chans[i] = make(chan *message, config.Buffer)
|
||
|
res.mirrorChans[i] = make(chan *message, config.Buffer)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Start start all mergeproc
|
||
|
func (p *Pipeline) Start() {
|
||
|
if p.Do == nil {
|
||
|
panic("pipeline: do func is nil")
|
||
|
}
|
||
|
if p.Split == nil {
|
||
|
panic("pipeline: split func is nil")
|
||
|
}
|
||
|
var mirror bool
|
||
|
p.wait.Add(len(p.chans) + len(p.mirrorChans))
|
||
|
for i, ch := range p.chans {
|
||
|
go p.mergeproc(mirror, i, ch)
|
||
|
}
|
||
|
mirror = true
|
||
|
for i, ch := range p.mirrorChans {
|
||
|
go p.mergeproc(mirror, i, ch)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// SyncAdd sync add a value to channal, channel shard in split method
|
||
|
func (p *Pipeline) SyncAdd(c context.Context, key string, value interface{}) {
|
||
|
ch, msg := p.add(c, key, value)
|
||
|
ch <- msg
|
||
|
}
|
||
|
|
||
|
// Add async add a value to channal, channel shard in split method
|
||
|
func (p *Pipeline) Add(c context.Context, key string, value interface{}) (err error) {
|
||
|
ch, msg := p.add(c, key, value)
|
||
|
select {
|
||
|
case ch <- msg:
|
||
|
default:
|
||
|
err = ErrFull
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (p *Pipeline) add(c context.Context, key string, value interface{}) (ch chan *message, m *message) {
|
||
|
shard := p.Split(key) % p.config.Worker
|
||
|
if metadata.String(c, metadata.Mirror) != "" {
|
||
|
ch = p.mirrorChans[shard]
|
||
|
} else {
|
||
|
ch = p.chans[shard]
|
||
|
}
|
||
|
m = &message{key: key, value: value}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Close all goroutinue
|
||
|
func (p *Pipeline) Close() (err error) {
|
||
|
for _, ch := range p.chans {
|
||
|
ch <- nil
|
||
|
}
|
||
|
for _, ch := range p.mirrorChans {
|
||
|
ch <- nil
|
||
|
}
|
||
|
p.wait.Wait()
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (p *Pipeline) mergeproc(mirror bool, index int, ch <-chan *message) {
|
||
|
defer p.wait.Done()
|
||
|
var (
|
||
|
m *message
|
||
|
vals = make(map[string][]interface{}, p.config.MaxSize)
|
||
|
closed bool
|
||
|
count int
|
||
|
inteval = p.config.Interval
|
||
|
oldTicker = true
|
||
|
)
|
||
|
if p.config.Smooth && index > 0 {
|
||
|
inteval = xtime.Duration(int64(index) * (int64(p.config.Interval) / int64(p.config.Worker)))
|
||
|
}
|
||
|
ticker := time.NewTicker(time.Duration(inteval))
|
||
|
for {
|
||
|
select {
|
||
|
case m = <-ch:
|
||
|
if m == nil {
|
||
|
closed = true
|
||
|
break
|
||
|
}
|
||
|
count++
|
||
|
vals[m.key] = append(vals[m.key], m.value)
|
||
|
if count >= p.config.MaxSize {
|
||
|
break
|
||
|
}
|
||
|
continue
|
||
|
case <-ticker.C:
|
||
|
if p.config.Smooth && oldTicker {
|
||
|
ticker.Stop()
|
||
|
ticker = time.NewTicker(time.Duration(p.config.Interval))
|
||
|
oldTicker = false
|
||
|
}
|
||
|
}
|
||
|
if len(vals) > 0 {
|
||
|
ctx := context.Background()
|
||
|
if mirror {
|
||
|
ctx = metadata.NewContext(ctx, metadata.MD{metadata.Mirror: "1"})
|
||
|
}
|
||
|
p.Do(ctx, index, vals)
|
||
|
vals = make(map[string][]interface{}, p.config.MaxSize)
|
||
|
count = 0
|
||
|
}
|
||
|
if closed {
|
||
|
ticker.Stop()
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|