package ewma

import (
	"container/list"
	"context"
	"math"
	"sync"
	"sync/atomic"
	"time"

	"github.com/go-kratos/kratos/v2/errors"
	"github.com/go-kratos/kratos/v2/selector"
)

const (
	// tau is the time constant (in nanoseconds) of the exponential decay
	// applied to the moving averages: the weight of an old value halves
	// every tau*ln(2).
	tau = int64(600 * time.Millisecond)
	// penalty is the lag (in nanoseconds) assumed for an endpoint that has
	// no latency statistics yet.
	penalty = uint64(10 * time.Second)
)

// Compile-time checks that Node and Builder satisfy the selector interfaces.
var (
	_ selector.WeightedNode        = (*Node)(nil)
	_ selector.WeightedNodeBuilder = (*Builder)(nil)
)

// Node is a weighted endpoint instance. It wraps a selector.Node and keeps
// the client-side latency, success and in-flight statistics that load,
// health and Weight are computed from.
//
// NOTE(review): the 64-bit counters below are accessed through sync/atomic;
// keep field order/alignment in mind if this struct is ever rearranged.
type Node struct {
	selector.Node

	// Client-side statistics.
	// lag is the moving average of request latency in nanoseconds; it stays 0
	// until the first request completes.
	// success is the moving average of request outcomes, where each request
	// contributes 1000 (success) or 0 (failure).
	// inflight is the in-flight request counter (initialized to 1 by Build).
	// inflights holds the UnixNano start timestamp of every in-flight
	// request; it is guarded by lk.
	lag       int64
	success   uint64
	inflight  int64
	inflights *list.List
	// stamp is the UnixNano timestamp at which a request last completed.
	// predictTs is the UnixNano timestamp of the last prediction refresh in load.
	// predict is the lag predicted by load from slow in-flight requests
	// (0 when there is no prediction).
	stamp     int64
	predictTs int64
	predict   int64
	// reqs is incremented on every Pick.
	reqs int64
	// lastPick is the UnixNano timestamp of the most recent Pick.
	lastPick int64

	// errHandler, when non-nil, decides whether a request error counts as a
	// failure. lk guards inflights.
	errHandler func(err error) (isErr bool)
	lk         sync.RWMutex
}

// Builder is the EWMA node builder.
type Builder struct {
	// ErrHandler is copied into every Node created by Build. When non-nil it
	// is called with a request's error and reports whether that error should
	// count as a failure; when nil, Pick falls back to its built-in checks
	// (deadline exceeded, canceled, service unavailable, gateway timeout).
	ErrHandler func(err error) (isErr bool)
}

// Build creates a weighted node wrapping n.
//
// The new node starts with a full success score (1000), an in-flight count
// of 1, an empty in-flight list and no latency sample (lag 0), and inherits
// the builder's ErrHandler.
func (b *Builder) Build(n selector.Node) selector.WeightedNode {
	return &Node{
		Node:       n,
		errHandler: b.ErrHandler,
		inflights:  list.New(),
		inflight:   1,
		success:    1000,
	}
}

// health returns the node's current success score, read atomically.
// The score starts at 1000 (set in Build) and is updated by the DoneFunc
// returned from Pick.
func (n *Node) health() uint64 {
	return atomic.LoadUint64(&n.success)
}

// load estimates the node's current load as (effective lag) * (in-flight count).
//
// At most once per prediction interval (a fifth of the average lag, clamped
// to [5ms, 200ms]) it rescans the in-flight requests: if more than
// len/2+1 of them have already been running longer than the average lag,
// their mean age is stored as the predicted lag, otherwise the prediction is
// reset to 0. The effective lag is the larger of the average lag and the
// prediction; a node with no latency sample yet is charged `penalty` instead.
func (n *Node) load() (load uint64) {
	const (
		minInterval = int64(5 * time.Millisecond)
		maxInterval = int64(200 * time.Millisecond)
	)

	now := time.Now().UnixNano()
	avgLag := atomic.LoadInt64(&n.lag)
	prevTs := atomic.LoadInt64(&n.predictTs)

	interval := avgLag / 5
	switch {
	case interval < minInterval:
		interval = minInterval
	case interval > maxInterval:
		interval = maxInterval
	}

	// The CAS lets only one caller per interval refresh the prediction.
	if now-prevTs > interval && atomic.CompareAndSwapInt64(&n.predictTs, prevTs, now) {
		var (
			sum       int64
			slow      int
			predicted int64
		)
		n.lk.RLock()
		for e := n.inflights.Front(); e != nil; e = e.Next() {
			if elapsed := now - e.Value.(int64); elapsed > avgLag {
				slow++
				sum += elapsed
			}
		}
		if slow > n.inflights.Len()/2+1 {
			predicted = sum / int64(slow)
		}
		n.lk.RUnlock()
		atomic.StoreInt64(&n.predict, predicted)
	}

	if avgLag == 0 {
		// penalty is the lag charged to a node that has just started and has
		// no statistics yet (10 seconds, in nanoseconds).
		return penalty * uint64(atomic.LoadInt64(&n.inflight))
	}
	if p := atomic.LoadInt64(&n.predict); p > avgLag {
		avgLag = p
	}
	return uint64(avgLag) * uint64(atomic.LoadInt64(&n.inflight))
}

// Pick marks the node as selected for a new request and returns the callback
// that must be invoked when that request finishes.
//
// On pick it records the pick time, bumps the in-flight and request counters
// and appends the request's start timestamp to the in-flight list. The
// returned DoneFunc removes that entry, decrements the in-flight counter and
// folds the observed latency and outcome into the node's moving averages of
// lag and success.
//
// A request counts as a failure when errHandler (if set) reports so; without
// an errHandler, deadline-exceeded, canceled, service-unavailable and
// gateway-timeout errors count as failures.
func (n *Node) Pick() selector.DoneFunc {
	now := time.Now().UnixNano()
	atomic.StoreInt64(&n.lastPick, now)
	atomic.AddInt64(&n.inflight, 1)
	atomic.AddInt64(&n.reqs, 1)
	n.lk.Lock()
	e := n.inflights.PushBack(now)
	n.lk.Unlock()
	return func(ctx context.Context, di selector.DoneInfo) {
		n.lk.Lock()
		n.inflights.Remove(e)
		n.lk.Unlock()
		atomic.AddInt64(&n.inflight, -1)

		now := time.Now().UnixNano()
		// w is the weight kept by the previous average; it decays
		// exponentially with the time since the last completed request.
		stamp := atomic.SwapInt64(&n.stamp, now)
		td := now - stamp
		if td < 0 {
			td = 0
		}
		w := math.Exp(float64(-td) / float64(tau))

		start := e.Value.(int64)
		lag := now - start
		if lag < 0 {
			lag = 0
		}
		oldLag := atomic.LoadInt64(&n.lag)
		if oldLag == 0 {
			// No latency sample yet: let this sample replace the averages
			// outright (w is shared with the success average below).
			w = 0.0
		}
		lag = int64(float64(oldLag)*w + float64(lag)*(1.0-w))
		atomic.StoreInt64(&n.lag, lag)

		// 1000 means success; it drops to 0 when the request counts as a failure.
		success := uint64(1000)
		if di.Err != nil {
			if n.errHandler != nil {
				if n.errHandler(di.Err) {
					success = 0
				}
			} else if errors.Is(di.Err, context.DeadlineExceeded) || errors.Is(di.Err, context.Canceled) ||
				errors.IsServiceUnavailable(di.Err) || errors.IsGatewayTimeout(di.Err) {
				// errors.Is takes (err, target); with the request error first,
				// wrapped deadline/cancel errors are recognized as well.
				success = 0
			}
		}
		oldSuc := atomic.LoadUint64(&n.success)
		success = uint64(float64(oldSuc)*w + float64(success)*(1.0-w))
		atomic.StoreUint64(&n.success, success)
	}
}

// Weight returns the node's effective weight: its health score scaled by one
// second (in nanoseconds) and divided by its current load, so healthier and
// less loaded nodes weigh more.
func (n *Node) Weight() (weight float64) {
	scaledHealth := float64(n.health() * uint64(time.Second))
	return scaledHealth / float64(n.load())
}

// PickElapsed returns the time elapsed since the node was last picked.
// lastPick is only written by Pick, so for a node that has never been picked
// (lastPick still 0) the result is the time since the Unix epoch.
func (n *Node) PickElapsed() time.Duration {
	now := time.Now().UnixNano()
	last := atomic.LoadInt64(&n.lastPick)
	return time.Duration(now - last)
}

// Raw returns the selector.Node embedded in this weighted node.
func (n *Node) Raw() selector.Node {
	return n.Node
}