package memcache

import (
	"bytes"
	"context"
	"reflect"
	"testing"
	"time"

	"github.com/bilibili/kratos/pkg/container/pool"
	xtime "github.com/bilibili/kratos/pkg/time"
)

// itempool is the fixture written by TestPoolSet under the "testpool" key and
// read back by the get/touch/CAS/delete tests.
var itempool = &Item{
	Key:        "testpool",
	Value:      []byte("testpool"),
	Expiration: 60,
}
// itempool2 seeds the "test_count" key with "0" so TestPoolIncrement can
// check the exact results of its increment and decrement.
var itempool2 = &Item{
	Key:        "test_count",
	Value:      []byte("0"),
	Expiration: 1000,
}

// testObject is the payload stored via Item.Object by the large-value tests.
type testObject struct {
	// Mid is an arbitrary identifier; the tests set it to 1000.
	Mid int64
	// Value carries the bulk bytes that make the item large.
	Value []byte
}

// largeValue is stored with FlagGOB | FlagGzip; its Object is filled in by
// TestPoolSetLargeValue before the item is written.
var largeValue = &Item{
	Key:        "large_value",
	Flags:      FlagGOB | FlagGzip,
	Expiration: 1000,
}

// largeValueBoundary is stored with FlagGOB | FlagGzip; its Object (a payload
// sized by _largeValue) is filled in by TestPoolSetLargeValueBoundary.
// It uses its own key so the boundary tests never read or overwrite the
// largeValue fixture's entry.
var largeValueBoundary = &Item{
	Key:        "large_value_boundary",
	Flags:      FlagGOB | FlagGzip,
	Expiration: 1000,
	cas:        0,
}

// TestPoolSet writes the itempool fixture through a pooled connection. It
// checks that both the Set and the explicit Close succeed.
func TestPoolSet(t *testing.T) {
	c := testPool.Get(context.Background())
	defer c.Close()

	if err := c.Set(itempool); err == nil {
		t.Logf("memcache: set value: %s", itempool.Value)
	} else {
		t.Errorf("memcache: set error(%v)", err)
	}

	if closeErr := c.Close(); closeErr != nil {
		t.Errorf("memcache: close error(%v)", closeErr)
	}
}

// TestPoolGet reads back the "testpool" key written by TestPoolSet. It also
// checks that looking up a key that was never stored yields ErrNotFound.
func TestPoolGet(t *testing.T) {
	key := "testpool"
	conn := testPool.Get(context.Background())
	defer conn.Close()
	// get
	if res, err := conn.Get(key); err != nil {
		t.Errorf("memcache: get error(%v)", err)
	} else {
		t.Logf("memcache: get value: %s", res.Value)
	}
	// A missing key must be reported with the package's ErrNotFound sentinel.
	if _, err := conn.Get("not_found"); err != ErrNotFound {
		t.Errorf("memcache: expected err is not found but got: %v", err)
	}
	if err := conn.Close(); err != nil {
		t.Errorf("memcache: close error(%v)", err)
	}
}

// TestPoolGetMulti fetches "testpool" and "test1" with one GetMulti call and
// logs how many entries came back.
func TestPoolGetMulti(t *testing.T) {
	c := testPool.Get(context.Background())
	defer c.Close()

	keys := []string{"testpool", "test1"}
	res, err := c.GetMulti(keys)
	if err == nil {
		t.Logf("memcache: gets value: %d", len(res))
	} else {
		t.Errorf("memcache: gets error(%v)", err)
	}

	if err = c.Close(); err != nil {
		t.Errorf("memcache: close error(%v)", err)
	}
}

func TestPoolTouch(t *testing.T) {
	key := "testpool"
	conn := testPool.Get(context.Background())
	defer conn.Close()
	// touch
	if err := conn.Touch(key, 10); err != nil {
		t.Errorf("memcache: touch error(%v)", err)
	}
	if err := conn.Close(); err != nil {
		t.Errorf("memcache: close error(%v)", err)
	}
}

func TestPoolIncrement(t *testing.T) {
	key := "test_count"
	conn := testPool.Get(context.Background())
	defer conn.Close()
	// set
	if err := conn.Set(itempool2); err != nil {
		t.Errorf("memcache: set error(%v)", err)
	} else {
		t.Logf("memcache: set value: 0")
	}
	// incr
	if res, err := conn.Increment(key, 1); err != nil {
		t.Errorf("memcache: incr error(%v)", err)
	} else {
		t.Logf("memcache: incr n: %d", res)
		if res != 1 {
			t.Errorf("memcache: expected res=1 but got %d", res)
		}
	}
	// decr
	if res, err := conn.Decrement(key, 1); err != nil {
		t.Errorf("memcache: decr error(%v)", err)
	} else {
		t.Logf("memcache: decr n: %d", res)
		if res != 0 {
			t.Errorf("memcache: expected res=0 but got %d", res)
		}
	}
	if err := conn.Close(); err != nil {
		t.Errorf("memcache: close error(%v)", err)
	}
}

// TestPoolErr closes a pooled connection and then expects Err to report a
// non-nil error for it.
func TestPoolErr(t *testing.T) {
	conn := testPool.Get(context.Background())
	defer conn.Close()
	if err := conn.Close(); err != nil {
		t.Errorf("memcache: close error(%v)", err)
	}
	// The test requires a closed connection to surface an error from Err.
	if err := conn.Err(); err == nil {
		t.Errorf("memcache: expected non-nil err after close, got nil")
	} else {
		t.Logf("memcache: err: %v", err)
	}
}

// TestPoolCompareAndSwap fetches the "testpool" item, replaces its value via
// CompareAndSwap using the item returned by Get, and verifies that a fresh
// Get observes the new value with the same key and zero flags.
func TestPoolCompareAndSwap(t *testing.T) {
	conn := testPool.Get(context.Background())
	defer conn.Close()
	key := "testpool"
	want := []byte("shit")

	r, err := conn.Get(key)
	if err != nil {
		t.Errorf("conn.Get() error(%v)", err)
		return
	}
	r.Value = want
	if err := conn.CompareAndSwap(r); err != nil {
		t.Errorf("conn.CompareAndSwap() error(%v)", err)
	}

	// Check the read-back error: a nil item here would otherwise panic below.
	got, err := conn.Get(key)
	if err != nil {
		t.Errorf("conn.Get() error(%v)", err)
		return
	}
	if got.Key != key || !bytes.Equal(got.Value, want) || got.Flags != 0 {
		t.Error("conn.Get() error, value")
	}
	if err := conn.Close(); err != nil {
		t.Errorf("memcache: close error(%v)", err)
	}
}

func TestPoolDel(t *testing.T) {
	key := "testpool"
	conn := testPool.Get(context.Background())
	defer conn.Close()
	// delete
	if err := conn.Delete(key); err != nil {
		t.Errorf("memcache: delete error(%v)", err)
	} else {
		t.Logf("memcache: delete key: %s", key)
	}
	if err := conn.Close(); err != nil {
		t.Errorf("memcache: close error(%v)", err)
	}
}

// BenchmarkMemcache measures the cost of acquiring a pooled connection and
// immediately returning it, from parallel goroutines.
func BenchmarkMemcache(b *testing.B) {
	c := &Config{
		Name:         "test",
		Proto:        "tcp",
		Addr:         testMemcacheAddr,
		DialTimeout:  xtime.Duration(time.Second),
		ReadTimeout:  xtime.Duration(time.Second),
		WriteTimeout: xtime.Duration(time.Second),
	}
	c.Config = &pool.Config{
		Active:      10,
		Idle:        5,
		IdleTimeout: xtime.Duration(90 * time.Second),
	}
	// Use a dedicated pool rather than reassigning the package-level testPool.
	// Reassigning would leak the previous pool, and closing it at the end
	// would leave the shared fixture closed.
	p := NewPool(c)
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			conn := p.Get(context.Background())
			if err := conn.Close(); err != nil {
				b.Errorf("memcache: close error(%v)", err)
			}
		}
	})
	// Keep pool teardown out of the measured time.
	b.StopTimer()
	if err := p.Close(); err != nil {
		b.Errorf("memcache: close error(%v)", err)
	}
}

func TestPoolSetLargeValue(t *testing.T) {
	var b bytes.Buffer
	for i := 0; i < 4000000; i++ {
		b.WriteByte(1)
	}
	obj := &testObject{}
	obj.Mid = 1000
	obj.Value = b.Bytes()
	largeValue.Object = obj
	conn := testPool.Get(context.Background())
	defer conn.Close()
	// set
	if err := conn.Set(largeValue); err != nil {
		t.Errorf("memcache: set error(%v)", err)
	}
	if err := conn.Close(); err != nil {
		t.Errorf("memcache: close error(%v)", err)
	}
}

func TestPoolGetLargeValue(t *testing.T) {
	key := largeValue.Key
	conn := testPool.Get(context.Background())
	defer conn.Close()
	// get
	var err error
	if _, err = conn.Get(key); err != nil {
		t.Errorf("memcache: large get error(%+v)", err)
	}
}

// TestPoolGetMultiLargeValue requests the large-value key twice in a single
// GetMulti call and logs how many entries came back.
func TestPoolGetMultiLargeValue(t *testing.T) {
	c := testPool.Get(context.Background())
	defer c.Close()

	keys := []string{largeValue.Key, largeValue.Key}
	res, err := c.GetMulti(keys)
	if err == nil {
		t.Logf("memcache: gets value: %d", len(res))
	} else {
		t.Errorf("memcache: gets error(%v)", err)
	}

	if err = c.Close(); err != nil {
		t.Errorf("memcache: close error(%v)", err)
	}
}

func TestPoolSetLargeValueBoundary(t *testing.T) {
	var b bytes.Buffer
	for i := 0; i < _largeValue; i++ {
		b.WriteByte(1)
	}
	obj := &testObject{}
	obj.Mid = 1000
	obj.Value = b.Bytes()
	largeValueBoundary.Object = obj
	conn := testPool.Get(context.Background())
	defer conn.Close()
	// set
	if err := conn.Set(largeValueBoundary); err != nil {
		t.Errorf("memcache: set error(%v)", err)
	}
	if err := conn.Close(); err != nil {
		t.Errorf("memcache: close error(%v)", err)
	}
}

func TestPoolGetLargeValueBoundary(t *testing.T) {
	key := largeValueBoundary.Key
	conn := testPool.Get(context.Background())
	defer conn.Close()
	// get
	var err error
	if _, err = conn.Get(key); err != nil {
		t.Errorf("memcache: large get error(%v)", err)
	}
}

// TestPoolAdd checks Add's store-only-if-absent contract. After clearing the
// key, the first Add must succeed and a second Add of the same key must fail
// with ErrNotStored.
func TestPoolAdd(t *testing.T) {
	var (
		key  = "test_add"
		item = &Item{
			Key:        key,
			Value:      []byte("0"),
			Flags:      0,
			Expiration: 60,
			cas:        0,
		}
		conn = testPool.Get(context.Background())
	)
	defer conn.Close()
	// Best-effort cleanup: the key is usually absent, so ErrNotFound is
	// expected. Any other failure is logged rather than silently dropped.
	if err := conn.Delete(key); err != nil && err != ErrNotFound {
		t.Logf("memcache: pre-test delete error(%v)", err)
	}
	if err := conn.Add(item); err != nil {
		t.Errorf("memcache: add error(%v)", err)
	}
	if err := conn.Add(item); err != ErrNotStored {
		t.Errorf("memcache: expected ErrNotStored on duplicate add but got: %v", err)
	}
}

// TestNewPool calls NewPool with a table of configurations. It checks whether
// each call panics as expected; a call that returns normally must yield a
// non-nil pool.
func TestNewPool(t *testing.T) {
	type args struct {
		cfg *Config
	}
	tests := []struct {
		name      string
		args      args
		wantPanic bool
	}{
		{
			"NewPoolIllegalDialTimeout",
			args{
				&Config{
					Name:         "test_illegal_dial_timeout",
					Proto:        "tcp",
					Addr:         testMemcacheAddr,
					DialTimeout:  xtime.Duration(-time.Second),
					ReadTimeout:  xtime.Duration(time.Second),
					WriteTimeout: xtime.Duration(time.Second),
				},
			},
			true,
		},
		{
			"NewPoolIllegalReadTimeout",
			args{
				&Config{
					Name:         "test_illegal_read_timeout",
					Proto:        "tcp",
					Addr:         testMemcacheAddr,
					DialTimeout:  xtime.Duration(time.Second),
					ReadTimeout:  xtime.Duration(-time.Second),
					WriteTimeout: xtime.Duration(time.Second),
				},
			},
			true,
		},
		{
			"NewPoolIllegalWriteTimeout",
			args{
				&Config{
					Name:         "test_illegal_write_timeout",
					Proto:        "tcp",
					Addr:         testMemcacheAddr,
					DialTimeout:  xtime.Duration(time.Second),
					ReadTimeout:  xtime.Duration(time.Second),
					WriteTimeout: xtime.Duration(-time.Second),
				},
			},
			true,
		},
		{
			// NOTE(review): every timeout here is valid, yet a panic is still
			// expected. Presumably this is because the embedded pool.Config is
			// left nil; confirm against NewPool.
			"NewPool",
			args{
				&Config{
					Name:         "test_new",
					Proto:        "tcp",
					Addr:         testMemcacheAddr,
					DialTimeout:  xtime.Duration(time.Second),
					ReadTimeout:  xtime.Duration(time.Second),
					WriteTimeout: xtime.Duration(time.Second),
				},
			},
			true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// recover works here because the deferred closure runs in the same
			// goroutine as the NewPool call.
			defer func() {
				r := recover()
				if (r != nil) != tt.wantPanic {
					t.Errorf("wantPanic recover = %v, wantPanic = %v", r, tt.wantPanic)
				}
			}()

			if gotP := NewPool(tt.args.cfg); gotP == nil {
				t.Error("NewPool() failed, got nil")
			}
		})
	}
}

// TestPool_Get acquires n connections from a pool with Active: 3. Within the
// Active limit a live Conn is expected. Once the limit is exceeded, the
// "wantErr" case expects Get to return an errConn.
func TestPool_Get(t *testing.T) {

	type args struct {
		ctx context.Context
	}
	tests := []struct {
		name    string
		p       *Pool
		args    args
		wantErr bool
		n       int
	}{
		{
			"Get",
			NewPool(&Config{
				Config: &pool.Config{
					Active: 3,
					Idle:   2,
				},
				Name:         "test_get",
				Proto:        "tcp",
				Addr:         testMemcacheAddr,
				DialTimeout:  xtime.Duration(time.Second),
				ReadTimeout:  xtime.Duration(time.Second),
				WriteTimeout: xtime.Duration(time.Second),
			}),
			args{context.TODO()},
			false,
			3,
		},
		{
			"GetExceededPoolSize",
			NewPool(&Config{
				Config: &pool.Config{
					Active: 3,
					Idle:   2,
				},
				Name:         "test_get_out",
				Proto:        "tcp",
				Addr:         testMemcacheAddr,
				DialTimeout:  xtime.Duration(time.Second),
				ReadTimeout:  xtime.Duration(time.Second),
				WriteTimeout: xtime.Duration(time.Second),
			}),
			args{context.TODO()},
			true,
			6,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Registered first, so it runs after all connection closes below.
			// Teardown failures are logged, because pool teardown is not what
			// this test asserts.
			defer func() {
				if err := tt.p.Close(); err != nil {
					t.Logf("memcache: pool close error(%v)", err)
				}
			}()
			for i := 1; i <= tt.n; i++ {
				got := tt.p.Get(tt.args.ctx)
				if reflect.TypeOf(got) == reflect.TypeOf(errConn{}) {
					if !tt.wantErr {
						t.Errorf("got errConn, expect Conn")
					}
					return
				}
				// Deferring inside the loop is intentional. Every live
				// connection must stay checked out until the subtest ends so
				// the pool's active count keeps growing. Close errors during
				// cleanup are not under test.
				defer got.Close()
				if tt.wantErr && i > tt.p.c.Active {
					t.Errorf("got Conn, expect errConn")
				}
			}
		})
	}
}

// TestPool_Close repeatedly acquires and closes connections from a pool with
// Active: 1. For the first c of the g iterations it also closes the same
// connection a second time and requires that repeated Close to succeed. An
// error from the first Close is tolerated only when wantErr is set.
func TestPool_Close(t *testing.T) {

	type args struct {
		ctx context.Context
	}
	tests := []struct {
		name    string
		p       *Pool
		args    args
		wantErr bool
		g       int
		c       int
	}{
		{
			"Close",
			NewPool(&Config{
				Config: &pool.Config{
					Active: 1,
					Idle:   1,
				},
				Name:         "test_get",
				Proto:        "tcp",
				Addr:         testMemcacheAddr,
				DialTimeout:  xtime.Duration(time.Second),
				ReadTimeout:  xtime.Duration(time.Second),
				WriteTimeout: xtime.Duration(time.Second),
			}),
			args{context.TODO()},
			false,
			3,
			3,
		},
		{
			"CloseExceededPoolSize",
			NewPool(&Config{
				Config: &pool.Config{
					Active: 1,
					Idle:   1,
				},
				Name:         "test_get_out",
				Proto:        "tcp",
				Addr:         testMemcacheAddr,
				DialTimeout:  xtime.Duration(time.Second),
				ReadTimeout:  xtime.Duration(time.Second),
				WriteTimeout: xtime.Duration(time.Second),
			}),
			args{context.TODO()},
			true,
			5,
			3,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Release the pool built for this case. Teardown failures are
			// logged, because pool teardown is not what this test asserts.
			defer func() {
				if err := tt.p.Close(); err != nil {
					t.Logf("memcache: pool close error(%v)", err)
				}
			}()
			for i := 1; i <= tt.g; i++ {
				got := tt.p.Get(tt.args.ctx)
				if err := got.Close(); err != nil && !tt.wantErr {
					t.Error(err)
				}
				// Second Close of the same connection; it must not report an
				// error.
				if i <= tt.c {
					if err := got.Close(); err != nil {
						t.Error(err)
					}
				}
			}
		})
	}
}