diff --git a/selector/node/ewma/node.go b/selector/node/ewma/node.go index 41bcc37a5..1707f1083 100644 --- a/selector/node/ewma/node.go +++ b/selector/node/ewma/node.go @@ -9,6 +9,7 @@ import ( "time" "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/selector" ) @@ -75,45 +76,44 @@ func (n *Node) load() (load uint64) { predictInterval := avgLag / 5 if predictInterval < int64(time.Millisecond*5) { predictInterval = int64(time.Millisecond * 5) - } else if predictInterval > int64(time.Millisecond*200) { + } + if predictInterval > int64(time.Millisecond*200) { predictInterval = int64(time.Millisecond * 200) } - if now-lastPredictTs > predictInterval { - if atomic.CompareAndSwapInt64(&n.predictTs, lastPredictTs, now) { - var ( - total int64 - count int - predict int64 - ) - n.lk.RLock() - first := n.inflights.Front() - for first != nil { - lag := now - first.Value.(int64) - if lag > avgLag { - count++ - total += lag - } - first = first.Next() + if now-lastPredictTs > predictInterval && atomic.CompareAndSwapInt64(&n.predictTs, lastPredictTs, now) { + var ( + total int64 + count int + predict int64 + ) + n.lk.RLock() + first := n.inflights.Front() + for first != nil { + lag := now - first.Value.(int64) + if lag > avgLag { + count++ + total += lag } - if count > (n.inflights.Len()/2 + 1) { - predict = total / int64(count) - } - n.lk.RUnlock() - atomic.StoreInt64(&n.predict, predict) + first = first.Next() + } + if count > (n.inflights.Len()/2 + 1) { + predict = total / int64(count) } + n.lk.RUnlock() + atomic.StoreInt64(&n.predict, predict) } if avgLag == 0 { // penalty is the penalty value when there is no data when the node is just started. // The default value is 1e9 * 10 load = penalty * uint64(atomic.LoadInt64(&n.inflight)) - } else { - predict := atomic.LoadInt64(&n.predict) - if predict > avgLag { - avgLag = predict - } - load = uint64(avgLag) * uint64(atomic.LoadInt64(&n.inflight)) + return + } + predict := atomic.LoadInt64(&n.predict) + if predict > avgLag { + avgLag = predict } + load = uint64(avgLag) * uint64(atomic.LoadInt64(&n.inflight)) return } @@ -155,11 +155,10 @@ func (n *Node) Pick() selector.DoneFunc { success := uint64(1000) // error value ,if error set 1 if di.Err != nil { - if n.errHandler != nil { - if n.errHandler(di.Err) { - success = 0 - } - } else if errors.Is(context.DeadlineExceeded, di.Err) || errors.Is(context.Canceled, di.Err) || + if n.errHandler != nil && n.errHandler(di.Err) { + success = 0 + } + if errors.Is(context.DeadlineExceeded, di.Err) || errors.Is(context.Canceled, di.Err) || errors.IsServiceUnavailable(di.Err) || errors.IsGatewayTimeout(di.Err) { success = 0 }