package wrr

import (
	"context"
	"math"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/balancer/base"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/status"

	"github.com/go-kratos/kratos/pkg/conf/env"
	"github.com/go-kratos/kratos/pkg/log"
	nmd "github.com/go-kratos/kratos/pkg/net/metadata"
	wmeta "github.com/go-kratos/kratos/pkg/net/rpc/warden/internal/metadata"
	"github.com/go-kratos/kratos/pkg/stat/metric"
)

// Compile-time checks that the types in this file satisfy the gRPC interfaces.
var _ base.PickerBuilder = &wrrPickerBuilder{}
var _ balancer.Picker = &wrrPicker{}

// var dwrrFeature feature.Feature = "dwrr"

// Name is the name of the weighted round-robin balancer.
const Name = "wrr"

// newBuilder creates a new weighted-roundrobin balancer builder.
func newBuilder() balancer.Builder {
	return base.NewBalancerBuilder(Name, &wrrPickerBuilder{})
}

// init registers the balancer builder under Name.
func init() {
	//feature.DefaultGate.Add(map[feature.Feature]feature.Spec{
	//	dwrrFeature: {Default: false},
	//})

	balancer.Register(newBuilder())
}

// serverInfo holds values describing the remote server. Both fields are
// accessed with sync/atomic.
type serverInfo struct {
	// cpu is the last positive CPU usage value parsed from the response trailer.
	cpu int64
	// success is a float64 stored as its IEEE-754 bits (see math.Float64bits).
	success uint64 // float64 bits
}

// subConn is the per-address state tracked by the picker.
type subConn struct {
	conn balancer.SubConn
	addr resolver.Address
	meta wmeta.MD

	// err counts requests (bucket.Count) and errors (sum of points) in a rolling window.
	err metric.RollingCounter
	// latency records request latency samples in a rolling window.
	latency metric.RollingGauge
	si      serverInfo
	// effective weight
	ewt int64
	// current weight
	cwt int64
	// last score
	score float64
}

// errSummary returns the number of errors and the number of requests recorded
// in the rolling error counter.
func (c *subConn) errSummary() (err int64, req int64) {
	c.err.Reduce(func(iterator metric.Iterator) float64 {
		for iterator.Next() {
			bucket := iterator.Bucket()
			req += bucket.Count
			for _, p := range bucket.Points {
				err += int64(p)
			}
		}
		return 0
	})
	return
}

// latencySummary returns the mean of the latency samples in the rolling gauge
// together with the number of samples. When the window holds no samples it
// returns (0, 0) instead of dividing by zero.
func (c *subConn) latencySummary() (latency float64, count int64) {
	c.latency.Reduce(func(iterator metric.Iterator) float64 {
		for iterator.Next() {
			bucket := iterator.Bucket()
			count += bucket.Count
			for _, p := range bucket.Points {
				latency += p
			}
		}
		return 0
	})
	if count == 0 {
		return 0, 0
	}
	return latency / float64(count), count
}

// statistics is info for log
type statistics struct {
	addr    string
	ewt     int64   // effective weight after the update
	cs      float64 // client-side success ratio
	ss      float64 // server-side success value read from serverInfo
	latency float64 // mean latency sample
	cpu     float64
	req     int64 // requests counted in the window
}

// Stats is grpc Interceptor for client to collect server stats.
//
// It attaches a copy of the kratos metadata (or a fresh one) to the context so
// that the picker can store the chosen *subConn under the "conn" key, captures
// the response trailer, and after the call stores the CPU usage reported in
// the trailer on that subConn. The error from invoker is returned unchanged.
func Stats() grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) {
		var (
			trailer metadata.MD
			md      nmd.MD
			ok      bool
		)
		if md, ok = nmd.FromContext(ctx); !ok {
			md = nmd.MD{}
		} else {
			md = md.Copy()
		}
		ctx = nmd.NewContext(ctx, md)
		opts = append(opts, grpc.Trailer(&trailer))

		err = invoker(ctx, method, req, reply, cc, opts...)

		conn, ok := md["conn"].(*subConn)
		if !ok {
			return err
		}
		// Guard against a key that is present with no values before indexing.
		if strs, ok := trailer[wmeta.CPUUsage]; ok && len(strs) > 0 {
			// Unparsable or non-positive values are ignored and the previous value is kept.
			if cpu, err2 := strconv.ParseInt(strs[0], 10, 64); err2 == nil && cpu > 0 {
				atomic.StoreInt64(&conn.si.cpu, cpu)
			}
		}
		return err
	}
}

// wrrPickerBuilder builds wrrPicker instances for the base balancer.
type wrrPickerBuilder struct{}

// Build creates a picker from the ready sub-connections. Addresses whose
// metadata is not a wmeta.MD get a default weight of 10. Sub-connections with
// an empty color go to the returned picker; the others are grouped into one
// child picker per color.
func (*wrrPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker {
	p := &wrrPicker{
		colors: make(map[string]*wrrPicker),
	}
	for addr, sc := range readySCs {
		meta, ok := addr.Metadata.(wmeta.MD)
		if !ok {
			meta = wmeta.MD{
				Weight: 10,
			}
		}
		subc := &subConn{
			conn: sc,
			addr: addr,

			meta:  meta,
			ewt:   int64(meta.Weight),
			score: -1, // non-positive score marks "no score computed yet"

			err: metric.NewRollingCounter(metric.RollingCounterOpts{
				Size:           10,
				BucketDuration: time.Millisecond * 100,
			}),
			latency: metric.NewRollingGauge(metric.RollingGaugeOpts{
				Size:           10,
				BucketDuration: time.Millisecond * 100,
			}),

			si: serverInfo{cpu: 500, success: math.Float64bits(1)},
		}
		if meta.Color == "" {
			p.subConns = append(p.subConns, subc)
			continue
		}
		// if color not empty, use color picker
		cp, ok := p.colors[meta.Color]
		if !ok {
			cp = &wrrPicker{}
			p.colors[meta.Color] = cp
		}
		cp.subConns = append(cp.subConns, subc)
	}
	return p
}

// wrrPicker selects a sub-connection with a weighted round-robin algorithm.
type wrrPicker struct {
	// subConns is the snapshot of the weighted-roundrobin balancer when this
	// picker was created. Build is the only place in this file that appends to
	// it; the weights of its elements are adjusted by pick and its done callback.
	subConns []*subConn
	// colors maps a color to the child picker holding that color's sub-connections.
	colors map[string]*wrrPicker
	// updateAt is the UnixNano time of the last weight update; accessed atomically.
	updateAt int64

	// mu guards the cwt/ewt updates performed in pick and its done callback.
	mu sync.Mutex
}

// Pick chooses a sub-connection for the call. The color is read from the
// request metadata, falling back to env.Color; when a child picker exists for
// that color it is used, otherwise the pick is made from this picker.
func (p *wrrPicker) Pick(ctx context.Context, opts balancer.PickInfo) (balancer.SubConn, func(balancer.DoneInfo), error) {
	// FIXME refactor to unify the color logic
	color := nmd.String(ctx, nmd.Color)
	if color == "" && env.Color != "" {
		color = env.Color
	}
	if color != "" {
		if cp, ok := p.colors[color]; ok {
			return cp.pick(ctx, opts)
		}
	}
	return p.pick(ctx, opts)
}

// pick selects the sub-connection with the highest current weight and returns
// a done callback that records the call's outcome and latency and, at most
// once per second, recomputes the effective weights of all sub-connections.
// It returns balancer.ErrNoSubConnAvailable when the picker has no sub-connections.
func (p *wrrPicker) pick(ctx context.Context, opts balancer.PickInfo) (balancer.SubConn, func(balancer.DoneInfo), error) {
	var (
		conn        *subConn
		totalWeight int64
	)
	if len(p.subConns) <= 0 {
		return nil, nil, balancer.ErrNoSubConnAvailable
	}
	p.mu.Lock()
	// nginx wrr load balancing algorithm: http://blog.csdn.net/zhangskd/article/details/50194069
	for _, sc := range p.subConns {
		totalWeight += sc.ewt
		sc.cwt += sc.ewt
		if conn == nil || conn.cwt < sc.cwt {
			conn = sc
		}
	}
	conn.cwt -= totalWeight
	p.mu.Unlock()
	start := time.Now()
	// Expose the chosen subConn to the Stats interceptor through the metadata map.
	if cmd, ok := nmd.FromContext(ctx); ok {
		cmd["conn"] = conn
	}
	//if !feature.DefaultGate.Enabled(dwrrFeature) {
	//	return conn.conn, nil, nil
	//}
	return conn.conn, func(di balancer.DoneInfo) {
		ev := int64(0) // error value, set to 1 when the error is counted
		if di.Err != nil {
			if st, ok := status.FromError(di.Err); ok {
				// only counter the local grpc error, ignore any business error
				if st.Code() != codes.Unknown && st.Code() != codes.OK {
					ev = 1
				}
			}
		}
		conn.err.Add(ev)

		now := time.Now()
		// Latency is recorded in units of 1e5 ns (0.1 ms).
		conn.latency.Add(now.Sub(start).Nanoseconds() / 1e5)
		u := atomic.LoadInt64(&p.updateAt)
		if now.UnixNano()-u < int64(time.Second) {
			return
		}
		// Only the caller that wins the CAS performs the weight update.
		if !atomic.CompareAndSwapInt64(&p.updateAt, u, now.UnixNano()) {
			return
		}
		var (
			stats = make([]statistics, len(p.subConns))
			count int
			total float64
		)
		for i, sc := range p.subConns {
			cpu := float64(atomic.LoadInt64(&sc.si.cpu))
			ss := math.Float64frombits(atomic.LoadUint64(&sc.si.success))
			errc, req := sc.errSummary()
			lagv, lagc := sc.latencySummary()

			// Without fresh samples the previous score is kept.
			if req > 0 && lagc > 0 && lagv > 0 {
				// client-side success ratio
				cs := 1 - (float64(errc) / float64(req))
				if cs <= 0 {
					cs = 0.1
				} else if cs <= 0.2 && req <= 5 {
					cs = 0.2
				}
				sc.score = math.Sqrt((cs * ss * ss * 1e9) / (lagv * cpu))
				stats[i] = statistics{cs: cs, ss: ss, latency: lagv, cpu: cpu, req: req}
			}
			stats[i].addr = sc.addr.Addr

			if sc.score > 0 {
				total += sc.score
				count++
			}
		}
		// count must be greater than 1,otherwise will lead ewt to 0
		if count < 2 {
			return
		}
		avgscore := total / float64(count)
		p.mu.Lock()
		for i, sc := range p.subConns {
			// Sub-connections that never got a score receive the average.
			if sc.score <= 0 {
				sc.score = avgscore
			}
			sc.ewt = int64(sc.score * float64(sc.meta.Weight))
			stats[i].ewt = sc.ewt
		}
		p.mu.Unlock()
		log.Info("warden wrr(%s): %+v", conn.addr.ServerName, stats)
	}, nil
}