package context

import (
	"context"
	"sync"
	"sync/atomic"
	"time"
)

// mergeCtx is the context.Context implementation returned by Merge. It
// combines two parent contexts and an explicit cancel signal into a single
// context.
type mergeCtx struct {
	// parent1 and parent2 are the contexts being merged; both are set by Merge.
	parent1, parent2 context.Context

	// done is the channel handed out by Done; finish closes it.
	done     chan struct{}
	// doneMark is stored as 1 (atomically) by finish after doneErr has been
	// written. Err loads it atomically to decide whether doneErr can be
	// returned directly.
	doneMark uint32
	// doneOnce ensures the body of finish runs only once.
	doneOnce sync.Once
	// doneErr is the error recorded by the first call to finish.
	doneErr  error

	// cancelCh is closed by cancel; wait and Err receive from it to detect
	// an explicit cancellation.
	cancelCh   chan struct{}
	// cancelOnce ensures cancelCh is closed only once, however many times
	// cancel is called.
	cancelOnce sync.Once
}

// Merge returns a context derived from both parent1 and parent2, together
// with a CancelFunc for it.
//
// The returned context is finished when either parent is done or when the
// CancelFunc is called. Its error is the error of the parent that was
// observed as done, or context.Canceled when the CancelFunc was the trigger.
//
// If one of the parents is already done when Merge is called, the returned
// context is finished immediately with that parent's error. Otherwise Merge
// starts a goroutine that blocks until one of the events above happens;
// calling the CancelFunc is one way to let that goroutine exit.
func Merge(parent1, parent2 context.Context) (context.Context, context.CancelFunc) {
	merged := &mergeCtx{
		done:     make(chan struct{}),
		cancelCh: make(chan struct{}),
		parent1:  parent1,
		parent2:  parent2,
	}

	var err error
	select {
	case <-parent1.Done():
		err = parent1.Err()
	case <-parent2.Done():
		err = parent2.Err()
	default:
		// Neither parent is done yet: watch them in the background.
		go merged.wait()
		return merged, merged.cancel
	}

	// A parent was already done, so no watcher goroutine is needed.
	_ = merged.finish(err)
	return merged, merged.cancel
}

// finish marks the merged context as done with the given error and returns
// the error the context ended with. Only the first call has any effect;
// later calls leave the recorded error untouched and simply return it.
func (mc *mergeCtx) finish(err error) error {
	markDone := func() {
		// Record the error before publishing the mark so that a reader
		// which sees doneMark != 0 also sees the final doneErr.
		mc.doneErr = err
		atomic.StoreUint32(&mc.doneMark, 1)
		close(mc.done)
	}
	mc.doneOnce.Do(markDone)
	return mc.doneErr
}

// wait blocks until either parent is done or cancel has been called, then
// finishes the merged context with the matching error: the parent's Err for
// a parent, context.Canceled for an explicit cancel.
func (mc *mergeCtx) wait() {
	select {
	case <-mc.parent1.Done():
		_ = mc.finish(mc.parent1.Err())
	case <-mc.parent2.Done():
		_ = mc.finish(mc.parent2.Err())
	case <-mc.cancelCh:
		_ = mc.finish(context.Canceled)
	}
}

// cancel signals an explicit cancellation by closing cancelCh. It may be
// called any number of times; the channel is closed only on the first call.
func (mc *mergeCtx) cancel() {
	closeCancelCh := func() { close(mc.cancelCh) }
	mc.cancelOnce.Do(closeCancelCh)
}

// Done implements context.Context. The returned channel is closed once the
// merged context has been finished.
func (mc *mergeCtx) Done() <-chan struct{} {
	var doneCh <-chan struct{} = mc.done
	return doneCh
}

// Err implements context.Context. It returns nil while neither parent is
// done and cancel has not been called; otherwise it returns the error the
// merged context finished with.
//
// Err does not rely on another goroutine having recorded the outcome: when
// the context is not yet marked as done, it polls the parents and the cancel
// channel itself and finishes the context on the spot if one of them fired.
func (mc *mergeCtx) Err() error {
	if atomic.LoadUint32(&mc.doneMark) == 0 {
		// Not marked as done yet: do a non-blocking check of every source.
		select {
		case <-mc.parent1.Done():
			return mc.finish(mc.parent1.Err())
		case <-mc.parent2.Done():
			return mc.finish(mc.parent2.Err())
		case <-mc.cancelCh:
			return mc.finish(context.Canceled)
		default:
			return nil
		}
	}
	// The mark is stored only after doneErr has been written.
	return mc.doneErr
}

// Deadline implements context.Context. When only one parent has a deadline,
// that deadline is returned; when both have one, the earlier of the two is
// returned. When the first parent has no deadline, the second parent's
// result is passed through unchanged.
func (mc *mergeCtx) Deadline() (time.Time, bool) {
	first, hasFirst := mc.parent1.Deadline()
	second, hasSecond := mc.parent2.Deadline()

	if !hasFirst {
		return second, hasSecond
	}
	if !hasSecond {
		return first, hasFirst
	}
	// Both parents have a deadline; on a tie the second one is returned.
	if !first.Before(second) {
		return second, true
	}
	return first, true
}

// Value implements context.Context. The key is looked up in the first parent
// and, only if that yields nil, in the second parent.
func (mc *mergeCtx) Value(key interface{}) interface{} {
	val := mc.parent1.Value(key)
	if val == nil {
		val = mc.parent2.Value(key)
	}
	return val
}