From f5d204daae24c14ee47144eca6808a9f68053eb2 Mon Sep 17 00:00:00 2001 From: Windfarer Date: Sat, 12 Oct 2019 15:25:13 +0800 Subject: [PATCH] redis and pipeline --- pkg/cache/redis/commandinfo_test.go | 27 + pkg/cache/redis/conn.go | 62 ++- pkg/cache/redis/conn_test.go | 670 +++++++++++++++++++++++ pkg/cache/redis/log.go | 18 +- pkg/cache/redis/main_test.go | 67 +++ pkg/cache/redis/metrics.go | 4 +- pkg/cache/redis/mock.go | 4 +- pkg/cache/redis/pipeline.go | 85 +++ pkg/cache/redis/pipeline_test.go | 96 ++++ pkg/cache/redis/pool.go | 31 +- pkg/cache/redis/pool_test.go | 540 ++++++++++++++++++ pkg/cache/redis/pubsub_test.go | 146 +++++ pkg/cache/redis/redis.go | 65 ++- pkg/cache/redis/redis_test.go | 324 +++++++++++ pkg/cache/redis/reply_test.go | 179 ++++++ pkg/cache/redis/scan_test.go | 438 +++++++++++++++ pkg/cache/redis/script_test.go | 103 ++++ pkg/cache/redis/test/docker-compose.yaml | 12 + pkg/cache/redis/trace.go | 44 +- pkg/cache/redis/trace_test.go | 192 +++++++ pkg/cache/redis/util.go | 17 + pkg/cache/redis/util_test.go | 37 ++ 22 files changed, 3084 insertions(+), 77 deletions(-) create mode 100644 pkg/cache/redis/commandinfo_test.go create mode 100644 pkg/cache/redis/conn_test.go create mode 100644 pkg/cache/redis/main_test.go create mode 100644 pkg/cache/redis/pipeline.go create mode 100644 pkg/cache/redis/pipeline_test.go create mode 100644 pkg/cache/redis/pool_test.go create mode 100644 pkg/cache/redis/pubsub_test.go create mode 100644 pkg/cache/redis/redis_test.go create mode 100644 pkg/cache/redis/reply_test.go create mode 100644 pkg/cache/redis/scan_test.go create mode 100644 pkg/cache/redis/script_test.go create mode 100644 pkg/cache/redis/test/docker-compose.yaml create mode 100644 pkg/cache/redis/trace_test.go create mode 100644 pkg/cache/redis/util.go create mode 100644 pkg/cache/redis/util_test.go diff --git a/pkg/cache/redis/commandinfo_test.go b/pkg/cache/redis/commandinfo_test.go new file mode 100644 index 000000000..d8f4e5214 --- /dev/null +++ b/pkg/cache/redis/commandinfo_test.go @@ -0,0 +1,27 @@ +package redis + +import "testing" + +func TestLookupCommandInfo(t *testing.T) { + for _, n := range []string{"watch", "WATCH", "wAtch"} { + if LookupCommandInfo(n) == (CommandInfo{}) { + t.Errorf("LookupCommandInfo(%q) = CommandInfo{}, expected non-zero value", n) + } + } +} + +func benchmarkLookupCommandInfo(b *testing.B, names ...string) { + for i := 0; i < b.N; i++ { + for _, c := range names { + LookupCommandInfo(c) + } + } +} + +func BenchmarkLookupCommandInfoCorrectCase(b *testing.B) { + benchmarkLookupCommandInfo(b, "watch", "WATCH", "monitor", "MONITOR") +} + +func BenchmarkLookupCommandInfoMixedCase(b *testing.B) { + benchmarkLookupCommandInfo(b, "wAtch", "WeTCH", "monItor", "MONiTOR") +} diff --git a/pkg/cache/redis/conn.go b/pkg/cache/redis/conn.go index 3f5a5a454..949a4bd55 100644 --- a/pkg/cache/redis/conn.go +++ b/pkg/cache/redis/conn.go @@ -30,6 +30,33 @@ import ( "github.com/pkg/errors" ) +// Conn represents a connection to a Redis server. +type Conn interface { + // Close closes the connection. + Close() error + + // Err returns a non-nil value if the connection is broken. The returned + // value is either the first non-nil value returned from the underlying + // network connection or a protocol parsing error. Applications should + // close broken connections. + Err() error + + // Do sends a command to the server and returns the received reply. + Do(commandName string, args ...interface{}) (reply interface{}, err error) + + // Send writes the command to the client's output buffer. + Send(commandName string, args ...interface{}) error + + // Flush flushes the output buffer to the Redis server. + Flush() error + + // Receive receives a single reply from the Redis server + Receive() (reply interface{}, err error) + + // WithContext returns Conn with the input ctx. + WithContext(ctx context.Context) Conn +} + // conn is the low-level implementation of Conn type conn struct { // Shared @@ -38,6 +65,8 @@ type conn struct { err error conn net.Conn + ctx context.Context + // Read readTimeout time.Duration br *bufio.Reader @@ -226,6 +255,7 @@ func NewConn(c *Config) (cn Conn, err error) { func (c *conn) Close() error { c.mu.Lock() + c.ctx = nil err := c.err if c.err == nil { c.err = errors.New("redigo: closed") @@ -295,7 +325,7 @@ func (c *conn) writeFloat64(n float64) error { func (c *conn) writeCommand(cmd string, args []interface{}) (err error) { if c.writeTimeout != 0 { - c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + c.conn.SetWriteDeadline(shrinkDeadline(c.ctx, c.writeTimeout)) } c.writeLen('*', 1+len(args)) err = c.writeString(cmd) @@ -478,7 +508,7 @@ func (c *conn) Send(cmd string, args ...interface{}) (err error) { func (c *conn) Flush() (err error) { if c.writeTimeout != 0 { - c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + c.conn.SetWriteDeadline(shrinkDeadline(c.ctx, c.writeTimeout)) } if err = c.bw.Flush(); err != nil { c.fatal(err) @@ -488,7 +518,7 @@ func (c *conn) Flush() (err error) { func (c *conn) Receive() (reply interface{}, err error) { if c.readTimeout != 0 { - c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) + c.conn.SetReadDeadline(shrinkDeadline(c.ctx, c.readTimeout)) } if reply, err = c.readReply(); err != nil { return nil, c.fatal(err) @@ -511,7 +541,7 @@ func (c *conn) Receive() (reply interface{}, err error) { return } -func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { +func (c *conn) Do(cmd string, args ...interface{}) (reply interface{}, err error) { c.mu.Lock() pending := c.pending c.pending = 0 @@ -519,7 +549,7 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { if cmd == "" && pending == 0 { return nil, nil } - var err error + if cmd != "" { err = c.writeCommand(cmd, args) } @@ -530,7 +560,7 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { return nil, c.fatal(err) } if c.readTimeout != 0 { - c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) + c.conn.SetReadDeadline(shrinkDeadline(c.ctx, c.readTimeout)) } if cmd == "" { reply := make([]interface{}, pending) @@ -548,7 +578,6 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { return reply, nil } - var reply interface{} for i := 0; i <= pending; i++ { var e error if reply, e = c.readReply(); e != nil { @@ -561,5 +590,20 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { return reply, err } -// WithContext FIXME: implement WithContext -func (c *conn) WithContext(ctx context.Context) Conn { return c } +func (c *conn) copy() *conn { + return &conn{ + pending: c.pending, + err: c.err, + conn: c.conn, + bw: c.bw, + br: c.br, + readTimeout: c.readTimeout, + writeTimeout: c.writeTimeout, + } +} + +func (c *conn) WithContext(ctx context.Context) Conn { + c2 := c.copy() + c2.ctx = ctx + return c2 +} diff --git a/pkg/cache/redis/conn_test.go b/pkg/cache/redis/conn_test.go new file mode 100644 index 000000000..3e37e882c --- /dev/null +++ b/pkg/cache/redis/conn_test.go @@ -0,0 +1,670 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "bytes" + "context" + "io" + "math" + "net" + "os" + "reflect" + "strings" + "testing" + "time" +) + +type tConn struct { + io.Reader + io.Writer +} + +func (*tConn) Close() error { return nil } +func (*tConn) LocalAddr() net.Addr { return nil } +func (*tConn) RemoteAddr() net.Addr { return nil } +func (*tConn) SetDeadline(t time.Time) error { return nil } +func (*tConn) SetReadDeadline(t time.Time) error { return nil } +func (*tConn) SetWriteDeadline(t time.Time) error { return nil } + +func dialTestConn(r io.Reader, w io.Writer) DialOption { + return DialNetDial(func(net, addr string) (net.Conn, error) { + return &tConn{Reader: r, Writer: w}, nil + }) +} + +var writeTests = []struct { + args []interface{} + expected string +}{ + { + []interface{}{"SET", "key", "value"}, + "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n", + }, + { + []interface{}{"SET", "key", "value"}, + "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n", + }, + { + []interface{}{"SET", "key", byte(100)}, + "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$3\r\n100\r\n", + }, + { + []interface{}{"SET", "key", 100}, + "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$3\r\n100\r\n", + }, + { + []interface{}{"SET", "key", int64(math.MinInt64)}, + "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$20\r\n-9223372036854775808\r\n", + }, + { + []interface{}{"SET", "key", float64(1349673917.939762)}, + "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$21\r\n1.349673917939762e+09\r\n", + }, + { + []interface{}{"SET", "key", ""}, + "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n", + }, + { + []interface{}{"SET", "key", nil}, + "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n", + }, + { + []interface{}{"ECHO", true, false}, + "*3\r\n$4\r\nECHO\r\n$1\r\n1\r\n$1\r\n0\r\n", + }, +} + +func TestWrite(t *testing.T) { + for _, tt := range writeTests { + var buf bytes.Buffer + c, _ := Dial("", "", dialTestConn(nil, &buf)) + err := c.Send(tt.args[0].(string), tt.args[1:]...) + if err != nil { + t.Errorf("Send(%v) returned error %v", tt.args, err) + continue + } + c.Flush() + actual := buf.String() + if actual != tt.expected { + t.Errorf("Send(%v) = %q, want %q", tt.args, actual, tt.expected) + } + } +} + +var errorSentinel = &struct{}{} + +var readTests = []struct { + reply string + expected interface{} +}{ + { + "+OK\r\n", + "OK", + }, + { + "+PONG\r\n", + "PONG", + }, + { + "@OK\r\n", + errorSentinel, + }, + { + "$6\r\nfoobar\r\n", + []byte("foobar"), + }, + { + "$-1\r\n", + nil, + }, + { + ":1\r\n", + int64(1), + }, + { + ":-2\r\n", + int64(-2), + }, + { + "*0\r\n", + []interface{}{}, + }, + { + "*-1\r\n", + nil, + }, + { + "*4\r\n$3\r\nfoo\r\n$3\r\nbar\r\n$5\r\nHello\r\n$5\r\nWorld\r\n", + []interface{}{[]byte("foo"), []byte("bar"), []byte("Hello"), []byte("World")}, + }, + { + "*3\r\n$3\r\nfoo\r\n$-1\r\n$3\r\nbar\r\n", + []interface{}{[]byte("foo"), nil, []byte("bar")}, + }, + + { + // "x" is not a valid length + "$x\r\nfoobar\r\n", + errorSentinel, + }, + { + // -2 is not a valid length + "$-2\r\n", + errorSentinel, + }, + { + // "x" is not a valid integer + ":x\r\n", + errorSentinel, + }, + { + // missing \r\n following value + "$6\r\nfoobar", + errorSentinel, + }, + { + // short value + "$6\r\nxx", + errorSentinel, + }, + { + // long value + "$6\r\nfoobarx\r\n", + errorSentinel, + }, +} + +func TestRead(t *testing.T) { + for _, tt := range readTests { + c, _ := Dial("", "", dialTestConn(strings.NewReader(tt.reply), nil)) + actual, err := c.Receive() + if tt.expected == errorSentinel { + if err == nil { + t.Errorf("Receive(%q) did not return expected error", tt.reply) + } + } else { + if err != nil { + t.Errorf("Receive(%q) returned error %v", tt.reply, err) + continue + } + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("Receive(%q) = %v, want %v", tt.reply, actual, tt.expected) + } + } + } +} + +var testCommands = []struct { + args []interface{} + expected interface{} +}{ + { + []interface{}{"PING"}, + "PONG", + }, + { + []interface{}{"SET", "foo", "bar"}, + "OK", + }, + { + []interface{}{"GET", "foo"}, + []byte("bar"), + }, + { + []interface{}{"GET", "nokey"}, + nil, + }, + { + []interface{}{"MGET", "nokey", "foo"}, + []interface{}{nil, []byte("bar")}, + }, + { + []interface{}{"INCR", "mycounter"}, + int64(1), + }, + { + []interface{}{"LPUSH", "mylist", "foo"}, + int64(1), + }, + { + []interface{}{"LPUSH", "mylist", "bar"}, + int64(2), + }, + { + []interface{}{"LRANGE", "mylist", 0, -1}, + []interface{}{[]byte("bar"), []byte("foo")}, + }, + { + []interface{}{"MULTI"}, + "OK", + }, + { + []interface{}{"LRANGE", "mylist", 0, -1}, + "QUEUED", + }, + { + []interface{}{"PING"}, + "QUEUED", + }, + { + []interface{}{"EXEC"}, + []interface{}{ + []interface{}{[]byte("bar"), []byte("foo")}, + "PONG", + }, + }, +} + +func TestDoCommands(t *testing.T) { + c, err := DialDefaultServer() + if err != nil { + t.Fatalf("error connection to database, %v", err) + } + defer c.Close() + + for _, cmd := range testCommands { + actual, err := c.Do(cmd.args[0].(string), cmd.args[1:]...) + if err != nil { + t.Errorf("Do(%v) returned error %v", cmd.args, err) + continue + } + if !reflect.DeepEqual(actual, cmd.expected) { + t.Errorf("Do(%v) = %v, want %v", cmd.args, actual, cmd.expected) + } + } +} + +func TestPipelineCommands(t *testing.T) { + c, err := DialDefaultServer() + if err != nil { + t.Fatalf("error connection to database, %v", err) + } + defer c.Close() + + for _, cmd := range testCommands { + if err := c.Send(cmd.args[0].(string), cmd.args[1:]...); err != nil { + t.Fatalf("Send(%v) returned error %v", cmd.args, err) + } + } + if err := c.Flush(); err != nil { + t.Errorf("Flush() returned error %v", err) + } + for _, cmd := range testCommands { + actual, err := c.Receive() + if err != nil { + t.Fatalf("Receive(%v) returned error %v", cmd.args, err) + } + if !reflect.DeepEqual(actual, cmd.expected) { + t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected) + } + } +} + +func TestBlankCommmand(t *testing.T) { + c, err := DialDefaultServer() + if err != nil { + t.Fatalf("error connection to database, %v", err) + } + defer c.Close() + + for _, cmd := range testCommands { + if err = c.Send(cmd.args[0].(string), cmd.args[1:]...); err != nil { + t.Fatalf("Send(%v) returned error %v", cmd.args, err) + } + } + reply, err := Values(c.Do("")) + if err != nil { + t.Fatalf("Do() returned error %v", err) + } + if len(reply) != len(testCommands) { + t.Fatalf("len(reply)=%d, want %d", len(reply), len(testCommands)) + } + for i, cmd := range testCommands { + actual := reply[i] + if !reflect.DeepEqual(actual, cmd.expected) { + t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected) + } + } +} + +func TestRecvBeforeSend(t *testing.T) { + c, err := DialDefaultServer() + if err != nil { + t.Fatalf("error connection to database, %v", err) + } + defer c.Close() + done := make(chan struct{}) + go func() { + c.Receive() + close(done) + }() + time.Sleep(time.Millisecond) + c.Send("PING") + c.Flush() + <-done + _, err = c.Do("") + if err != nil { + t.Fatalf("error=%v", err) + } +} + +func TestError(t *testing.T) { + c, err := DialDefaultServer() + if err != nil { + t.Fatalf("error connection to database, %v", err) + } + defer c.Close() + + c.Do("SET", "key", "val") + _, err = c.Do("HSET", "key", "fld", "val") + if err == nil { + t.Errorf("Expected err for HSET on string key.") + } + if c.Err() != nil { + t.Errorf("Conn has Err()=%v, expect nil", c.Err()) + } + _, err = c.Do("SET", "key", "val") + if err != nil { + t.Errorf("Do(SET, key, val) returned error %v, expected nil.", err) + } +} + +func TestReadTimeout(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen returned %v", err) + } + defer l.Close() + + go func() { + for { + c, err1 := l.Accept() + if err1 != nil { + return + } + go func() { + time.Sleep(time.Second) + c.Write([]byte("+OK\r\n")) + c.Close() + }() + } + }() + + // Do + + c1, err := Dial(l.Addr().Network(), l.Addr().String(), DialReadTimeout(time.Millisecond)) + if err != nil { + t.Fatalf("Dial returned %v", err) + } + defer c1.Close() + + _, err = c1.Do("PING") + if err == nil { + t.Fatalf("c1.Do() returned nil, expect error") + } + if c1.Err() == nil { + t.Fatalf("c1.Err() = nil, expect error") + } + + // Send/Flush/Receive + + c2, err := Dial(l.Addr().Network(), l.Addr().String(), DialReadTimeout(time.Millisecond)) + if err != nil { + t.Fatalf("Dial returned %v", err) + } + defer c2.Close() + + c2.Send("PING") + c2.Flush() + _, err = c2.Receive() + if err == nil { + t.Fatalf("c2.Receive() returned nil, expect error") + } + if c2.Err() == nil { + t.Fatalf("c2.Err() = nil, expect error") + } +} + +var dialErrors = []struct { + rawurl string + expectedError string +}{ + { + "localhost", + "invalid redis URL scheme", + }, + // The error message for invalid hosts is diffferent in different + // versions of Go, so just check that there is an error message. + { + "redis://weird url", + "", + }, + { + "redis://foo:bar:baz", + "", + }, + { + "http://www.google.com", + "invalid redis URL scheme: http", + }, + { + "redis://localhost:6379/abc123", + "invalid database: abc123", + }, +} + +func TestDialURLErrors(t *testing.T) { + for _, d := range dialErrors { + _, err := DialURL(d.rawurl) + if err == nil || !strings.Contains(err.Error(), d.expectedError) { + t.Errorf("DialURL did not return expected error (expected %v to contain %s)", err, d.expectedError) + } + } +} + +func TestDialURLPort(t *testing.T) { + checkPort := func(network, address string) (net.Conn, error) { + if address != "localhost:6379" { + t.Errorf("DialURL did not set port to 6379 by default (got %v)", address) + } + return nil, nil + } + _, err := DialURL("redis://localhost", DialNetDial(checkPort)) + if err != nil { + t.Error("dial error:", err) + } +} + +func TestDialURLHost(t *testing.T) { + checkHost := func(network, address string) (net.Conn, error) { + if address != "localhost:6379" { + t.Errorf("DialURL did not set host to localhost by default (got %v)", address) + } + return nil, nil + } + _, err := DialURL("redis://:6379", DialNetDial(checkHost)) + if err != nil { + t.Error("dial error:", err) + } +} + +func TestDialURLPassword(t *testing.T) { + var buf bytes.Buffer + _, err := DialURL("redis://x:abc123@localhost", dialTestConn(strings.NewReader("+OK\r\n"), &buf)) + if err != nil { + t.Error("dial error:", err) + } + expected := "*2\r\n$4\r\nAUTH\r\n$6\r\nabc123\r\n" + actual := buf.String() + if actual != expected { + t.Errorf("commands = %q, want %q", actual, expected) + } +} + +func TestDialURLDatabase(t *testing.T) { + var buf bytes.Buffer + _, err := DialURL("redis://localhost/3", dialTestConn(strings.NewReader("+OK\r\n"), &buf)) + if err != nil { + t.Error("dial error:", err) + } + expected := "*2\r\n$6\r\nSELECT\r\n$1\r\n3\r\n" + actual := buf.String() + if actual != expected { + t.Errorf("commands = %q, want %q", actual, expected) + } +} + +// Connect to local instance of Redis running on the default port. +func ExampleDial() { + c, err := Dial("tcp", ":6379") + if err != nil { + // handle error + } + defer c.Close() +} + +// Connect to remote instance of Redis using a URL. +func ExampleDialURL() { + c, err := DialURL(os.Getenv("REDIS_URL")) + if err != nil { + // handle connection error + } + defer c.Close() +} + +// TextExecError tests handling of errors in a transaction. See +// http://io/topics/transactions for information on how Redis handles +// errors in a transaction. +func TestExecError(t *testing.T) { + c, err := DialDefaultServer() + if err != nil { + t.Fatalf("error connection to database, %v", err) + } + defer c.Close() + + // Execute commands that fail before EXEC is called. + + c.Do("DEL", "k0") + c.Do("ZADD", "k0", 0, 0) + c.Send("MULTI") + c.Send("NOTACOMMAND", "k0", 0, 0) + c.Send("ZINCRBY", "k0", 0, 0) + v, err := c.Do("EXEC") + if err == nil { + t.Fatalf("EXEC returned values %v, expected error", v) + } + + // Execute commands that fail after EXEC is called. The first command + // returns an error. + + c.Do("DEL", "k1") + c.Do("ZADD", "k1", 0, 0) + c.Send("MULTI") + c.Send("HSET", "k1", 0, 0) + c.Send("ZINCRBY", "k1", 0, 0) + v, err = c.Do("EXEC") + if err != nil { + t.Fatalf("EXEC returned error %v", err) + } + + vs, err := Values(v, nil) + if err != nil { + t.Fatalf("Values(v) returned error %v", err) + } + + if len(vs) != 2 { + t.Fatalf("len(vs) == %d, want 2", len(vs)) + } + + if _, ok := vs[0].(error); !ok { + t.Fatalf("first result is type %T, expected error", vs[0]) + } + + if _, ok := vs[1].([]byte); !ok { + t.Fatalf("second result is type %T, expected []byte", vs[1]) + } + + // Execute commands that fail after EXEC is called. The second command + // returns an error. + + c.Do("ZADD", "k2", 0, 0) + c.Send("MULTI") + c.Send("ZINCRBY", "k2", 0, 0) + c.Send("HSET", "k2", 0, 0) + v, err = c.Do("EXEC") + if err != nil { + t.Fatalf("EXEC returned error %v", err) + } + + vs, err = Values(v, nil) + if err != nil { + t.Fatalf("Values(v) returned error %v", err) + } + + if len(vs) != 2 { + t.Fatalf("len(vs) == %d, want 2", len(vs)) + } + + if _, ok := vs[0].([]byte); !ok { + t.Fatalf("first result is type %T, expected []byte", vs[0]) + } + + if _, ok := vs[1].(error); !ok { + t.Fatalf("second result is type %T, expected error", vs[2]) + } +} + +func BenchmarkDoEmpty(b *testing.B) { + c, err := DialDefaultServer() + if err != nil { + b.Fatal(err) + } + defer c.Close() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := c.Do(""); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDoPing(b *testing.B) { + c, err := DialDefaultServer() + if err != nil { + b.Fatal(err) + } + defer c.Close() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := c.Do("PING"); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkConn(b *testing.B) { + for i := 0; i < b.N; i++ { + c, err := DialDefaultServer() + if err != nil { + b.Fatal(err) + } + c2 := c.WithContext(context.TODO()) + if _, err := c2.Do("PING"); err != nil { + b.Fatal(err) + } + c2.Close() + } +} diff --git a/pkg/cache/redis/log.go b/pkg/cache/redis/log.go index 129b86d67..487a1408f 100644 --- a/pkg/cache/redis/log.go +++ b/pkg/cache/redis/log.go @@ -16,16 +16,18 @@ package redis import ( "bytes" + "context" "fmt" "log" ) // NewLoggingConn returns a logging wrapper around a connection. +// ATTENTION: ONLY use loggingConn in developing, DO NOT use this in production. func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn { if prefix != "" { prefix = prefix + "." } - return &loggingConn{conn, logger, prefix} + return &loggingConn{Conn: conn, logger: logger, prefix: prefix} } type loggingConn struct { @@ -98,16 +100,16 @@ func (c *loggingConn) print(method, commandName string, args []interface{}, repl c.logger.Output(3, buf.String()) } -func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, error) { - reply, err := c.Conn.Do(commandName, args...) +func (c *loggingConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) { + reply, err = c.Conn.Do(commandName, args...) c.print("Do", commandName, args, reply, err) return reply, err } -func (c *loggingConn) Send(commandName string, args ...interface{}) error { - err := c.Conn.Send(commandName, args...) +func (c *loggingConn) Send(commandName string, args ...interface{}) (err error) { + err = c.Conn.Send(commandName, args...) c.print("Send", commandName, args, nil, err) - return err + return } func (c *loggingConn) Receive() (interface{}, error) { @@ -115,3 +117,7 @@ func (c *loggingConn) Receive() (interface{}, error) { c.print("Receive", "", nil, reply, err) return reply, err } + +func (c *loggingConn) WithContext(ctx context.Context) Conn { + return c +} diff --git a/pkg/cache/redis/main_test.go b/pkg/cache/redis/main_test.go new file mode 100644 index 000000000..3164628f2 --- /dev/null +++ b/pkg/cache/redis/main_test.go @@ -0,0 +1,67 @@ +package redis + +import ( + "flag" + "os" + "testing" + "time" + + "github.com/bilibili/kratos/pkg/container/pool" + "github.com/bilibili/kratos/pkg/testing/lich" + xtime "github.com/bilibili/kratos/pkg/time" +) + +var ( + testRedisAddr string + testPool *Pool + testConfig *Config +) + +func setupTestConfig(addr string) { + c := getTestConfig(addr) + c.Config = &pool.Config{ + Active: 20, + Idle: 2, + IdleTimeout: xtime.Duration(90 * time.Second), + } + testConfig = c +} + +func getTestConfig(addr string) *Config { + return &Config{ + Name: "test", + Proto: "tcp", + Addr: addr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + } +} + +func setupTestPool() { + testPool = NewPool(testConfig) +} + +// DialDefaultServer starts the test server if not already started and dials a +// connection to the server. +func DialDefaultServer() (Conn, error) { + c, err := Dial("tcp", testRedisAddr, DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second)) + if err != nil { + return nil, err + } + c.Do("FLUSHDB") + return c, nil +} + +func TestMain(m *testing.M) { + flag.Set("f", "./test/docker-compose.yaml") + if err := lich.Setup(); err != nil { + panic(err) + } + defer lich.Teardown() + testRedisAddr = "localhost:6379" + setupTestConfig(testRedisAddr) + setupTestPool() + ret := m.Run() + os.Exit(ret) +} diff --git a/pkg/cache/redis/metrics.go b/pkg/cache/redis/metrics.go index e48795957..2037ce48a 100644 --- a/pkg/cache/redis/metrics.go +++ b/pkg/cache/redis/metrics.go @@ -1,6 +1,8 @@ package redis -import "github.com/bilibili/kratos/pkg/stat/metric" +import ( + "github.com/bilibili/kratos/pkg/stat/metric" +) const namespace = "redis_client" diff --git a/pkg/cache/redis/mock.go b/pkg/cache/redis/mock.go index da75817f3..fc9d5a3da 100644 --- a/pkg/cache/redis/mock.go +++ b/pkg/cache/redis/mock.go @@ -1,8 +1,6 @@ package redis -import ( - "context" -) +import "context" // MockErr for unit test. type MockErr struct { diff --git a/pkg/cache/redis/pipeline.go b/pkg/cache/redis/pipeline.go new file mode 100644 index 000000000..0a23205e5 --- /dev/null +++ b/pkg/cache/redis/pipeline.go @@ -0,0 +1,85 @@ +package redis + +import ( + "context" + "errors" +) + +type Pipeliner interface { + // Send writes the command to the client's output buffer. + Send(commandName string, args ...interface{}) + + // Exec executes all commands and get replies. + Exec(ctx context.Context) (rs *Replies, err error) +} + +var ( + ErrNoReply = errors.New("redis: no reply in result set") +) + +type pipeliner struct { + pool *Pool + cmds []*cmd +} + +type Replies struct { + replies []*reply +} + +type reply struct { + reply interface{} + err error +} + +func (rs *Replies) Next() bool { + return len(rs.replies) > 0 +} + +func (rs *Replies) Scan() (reply interface{}, err error) { + if !rs.Next() { + return nil, ErrNoReply + } + reply, err = rs.replies[0].reply, rs.replies[0].err + rs.replies = rs.replies[1:] + return +} + +type cmd struct { + commandName string + args []interface{} +} + +func (p *pipeliner) Send(commandName string, args ...interface{}) { + p.cmds = append(p.cmds, &cmd{commandName: commandName, args: args}) + return +} + +func (p *pipeliner) Exec(ctx context.Context) (rs *Replies, err error) { + n := len(p.cmds) + if n == 0 { + return &Replies{}, nil + } + c := p.pool.Get(ctx) + defer c.Close() + for len(p.cmds) > 0 { + cmd := p.cmds[0] + p.cmds = p.cmds[1:] + if err := c.Send(cmd.commandName, cmd.args...); err != nil { + p.cmds = p.cmds[:0] + return nil, err + } + } + if err = c.Flush(); err != nil { + p.cmds = p.cmds[:0] + return nil, err + } + rps := make([]*reply, 0, n) + for i := 0; i < n; i++ { + rp, err := c.Receive() + rps = append(rps, &reply{reply: rp, err: err}) + } + rs = &Replies{ + replies: rps, + } + return +} diff --git a/pkg/cache/redis/pipeline_test.go b/pkg/cache/redis/pipeline_test.go new file mode 100644 index 000000000..d3d95d260 --- /dev/null +++ b/pkg/cache/redis/pipeline_test.go @@ -0,0 +1,96 @@ +package redis + +import ( + "context" + "fmt" + "reflect" + "testing" + "time" + + "github.com/bilibili/kratos/pkg/container/pool" + xtime "github.com/bilibili/kratos/pkg/time" +) + +func TestRedis_Pipeline(t *testing.T) { + conf := &Config{ + Name: "test", + Proto: "tcp", + Addr: testRedisAddr, + DialTimeout: xtime.Duration(1 * time.Second), + ReadTimeout: xtime.Duration(1 * time.Second), + WriteTimeout: xtime.Duration(1 * time.Second), + } + conf.Config = &pool.Config{ + Active: 10, + Idle: 2, + IdleTimeout: xtime.Duration(90 * time.Second), + } + + r := NewRedis(conf) + r.Do(context.TODO(), "FLUSHDB") + + p := r.Pipeline() + + for _, cmd := range testCommands { + p.Send(cmd.args[0].(string), cmd.args[1:]...) + } + + replies, err := p.Exec(context.TODO()) + + i := 0 + for replies.Next() { + cmd := testCommands[i] + actual, err := replies.Scan() + if err != nil { + t.Fatalf("Receive(%v) returned error %v", cmd.args, err) + } + if !reflect.DeepEqual(actual, cmd.expected) { + t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected) + } + i++ + } + err = r.Close() + if err != nil { + t.Errorf("Close() error %v", err) + } +} + +func ExamplePipeliner() { + r := NewRedis(testConfig) + defer r.Close() + + pip := r.Pipeline() + pip.Send("SET", "hello", "world") + pip.Send("GET", "hello") + replies, err := pip.Exec(context.TODO()) + if err != nil { + fmt.Printf("%#v\n", err) + } + for replies.Next() { + s, err := String(replies.Scan()) + if err != nil { + fmt.Printf("err %#v\n", err) + } + fmt.Printf("%#v\n", s) + } + // Output: + // "OK" + // "world" +} + +func BenchmarkRedisPipelineExec(b *testing.B) { + r := NewRedis(testConfig) + defer r.Close() + + r.Do(context.TODO(), "SET", "abcde", "fghiasdfasdf") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + p := r.Pipeline() + p.Send("GET", "abcde") + _, err := p.Exec(context.TODO()) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/pkg/cache/redis/pool.go b/pkg/cache/redis/pool.go index 9da107919..cd64de6a5 100644 --- a/pkg/cache/redis/pool.go +++ b/pkg/cache/redis/pool.go @@ -45,24 +45,14 @@ type Pool struct { statfunc func(name, addr, cmd string, t time.Time, err error) func() } -// Config client settings. -type Config struct { - *pool.Config - - Name string // redis name, for trace - Proto string - Addr string - Auth string - DialTimeout xtime.Duration - ReadTimeout xtime.Duration - WriteTimeout xtime.Duration -} - // NewPool creates a new pool. func NewPool(c *Config, options ...DialOption) (p *Pool) { if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 { panic("must config redis timeout") } + if c.SlowLog <= 0 { + c.SlowLog = xtime.Duration(250 * time.Millisecond) + } ops := []DialOption{ DialConnectTimeout(time.Duration(c.DialTimeout)), DialReadTimeout(time.Duration(c.ReadTimeout)), @@ -71,12 +61,18 @@ func NewPool(c *Config, options ...DialOption) (p *Pool) { } ops = append(ops, options...) p1 := pool.NewSlice(c.Config) + + // new pool p1.New = func(ctx context.Context) (io.Closer, error) { conn, err := Dial(c.Proto, c.Addr, ops...) if err != nil { return nil, err } - return &traceConn{Conn: conn, connTags: []trace.Tag{trace.TagString(trace.TagPeerAddress, c.Addr)}}, nil + return &traceConn{ + Conn: conn, + connTags: []trace.Tag{trace.TagString(trace.TagPeerAddress, c.Addr)}, + slowLogThreshold: time.Duration(c.SlowLog), + }, nil } p = &Pool{Slice: p1, c: c, statfunc: pstat} return @@ -93,7 +89,7 @@ func (p *Pool) Get(ctx context.Context) Conn { return errorConnection{err} } c1, _ := c.(Conn) - return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx, now: beginTime} + return &pooledConnection{p: p, c: c1.WithContext(ctx), rc: c1, now: beginTime} } // Close releases the resources used by the pool. @@ -103,12 +99,12 @@ func (p *Pool) Close() error { type pooledConnection struct { p *Pool + rc Conn c Conn state int now time.Time cmds []string - ctx context.Context } var ( @@ -180,7 +176,7 @@ func (pc *pooledConnection) Close() error { } } _, err := c.Do("") - pc.p.Slice.Put(context.Background(), c, pc.state != 0 || c.Err() != nil) + pc.p.Slice.Put(context.Background(), pc.rc, pc.state != 0 || c.Err() != nil) return err } @@ -227,7 +223,6 @@ func (pc *pooledConnection) Receive() (reply interface{}, err error) { } func (pc *pooledConnection) WithContext(ctx context.Context) Conn { - pc.ctx = ctx return pc } diff --git a/pkg/cache/redis/pool_test.go b/pkg/cache/redis/pool_test.go new file mode 100644 index 000000000..fdea2a337 --- /dev/null +++ b/pkg/cache/redis/pool_test.go @@ -0,0 +1,540 @@ +// Copyright 2011 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "context" + "errors" + "io" + "reflect" + "sync" + "testing" + "time" + + "github.com/bilibili/kratos/pkg/container/pool" +) + +type poolTestConn struct { + d *poolDialer + err error + c Conn + ctx context.Context +} + +func (c *poolTestConn) Flush() error { + return c.c.Flush() +} + +func (c *poolTestConn) Receive() (reply interface{}, err error) { + return c.c.Receive() +} + +func (c *poolTestConn) WithContext(ctx context.Context) Conn { + c.c.WithContext(ctx) + c.ctx = ctx + return c +} + +func (c *poolTestConn) Close() error { + c.d.mu.Lock() + c.d.open-- + c.d.mu.Unlock() + return c.c.Close() +} + +func (c *poolTestConn) Err() error { return c.err } + +func (c *poolTestConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) { + if commandName == "ERR" { + c.err = args[0].(error) + commandName = "PING" + } + if commandName != "" { + c.d.commands = append(c.d.commands, commandName) + } + return c.c.Do(commandName, args...) +} + +func (c *poolTestConn) Send(commandName string, args ...interface{}) error { + c.d.commands = append(c.d.commands, commandName) + return c.c.Send(commandName, args...) +} + +type poolDialer struct { + mu sync.Mutex + t *testing.T + dialed int + open int + commands []string + dialErr error +} + +func (d *poolDialer) dial() (Conn, error) { + d.mu.Lock() + d.dialed += 1 + dialErr := d.dialErr + d.mu.Unlock() + if dialErr != nil { + return nil, d.dialErr + } + c, err := DialDefaultServer() + if err != nil { + return nil, err + } + d.mu.Lock() + d.open += 1 + d.mu.Unlock() + return &poolTestConn{d: d, c: c}, nil +} + +func (d *poolDialer) check(message string, p *Pool, dialed, open int) { + d.mu.Lock() + if d.dialed != dialed { + d.t.Errorf("%s: dialed=%d, want %d", message, d.dialed, dialed) + } + if d.open != open { + d.t.Errorf("%s: open=%d, want %d", message, d.open, open) + } + // if active := p.ActiveCount(); active != open { + // d.t.Errorf("%s: active=%d, want %d", message, active, open) + // } + d.mu.Unlock() +} + +func TestPoolReuse(t *testing.T) { + d := poolDialer{t: t} + p := NewPool(testConfig) + p.Slice.New = func(ctx context.Context) (io.Closer, error) { + return d.dial() + } + var err error + + for i := 0; i < 10; i++ { + c1 := p.Get(context.TODO()) + c1.Do("PING") + c2 := p.Get(context.TODO()) + c2.Do("PING") + c1.Close() + c2.Close() + + } + + d.check("before close", p, 2, 2) + err = p.Close() + if err != nil { + t.Fatal(err) + } + d.check("after close", p, 2, 0) +} + +func TestPoolMaxIdle(t *testing.T) { + d := poolDialer{t: t} + p := NewPool(testConfig) + p.Slice.New = func(ctx context.Context) (io.Closer, error) { + return d.dial() + } + defer p.Close() + + for i := 0; i < 10; i++ { + c1 := p.Get(context.TODO()) + c1.Do("PING") + c2 := p.Get(context.TODO()) + c2.Do("PING") + c3 := p.Get(context.TODO()) + c3.Do("PING") + c1.Close() + c2.Close() + c3.Close() + } + d.check("before close", p, 12, 2) + p.Close() + d.check("after close", p, 12, 0) +} + +func TestPoolError(t *testing.T) { + d := poolDialer{t: t} + p := NewPool(testConfig) + p.Slice.New = func(ctx context.Context) (io.Closer, error) { + return d.dial() + } + defer p.Close() + + c := p.Get(context.TODO()) + c.Do("ERR", io.EOF) + if c.Err() == nil { + t.Errorf("expected c.Err() != nil") + } + c.Close() + + c = p.Get(context.TODO()) + c.Do("ERR", io.EOF) + c.Close() + + d.check(".", p, 2, 0) +} + +func TestPoolClose(t *testing.T) { + d := poolDialer{t: t} + p := NewPool(testConfig) + p.Slice.New = func(ctx context.Context) (io.Closer, error) { + return d.dial() + } + defer p.Close() + + c1 := p.Get(context.TODO()) + c1.Do("PING") + c2 := p.Get(context.TODO()) + c2.Do("PING") + c3 := p.Get(context.TODO()) + c3.Do("PING") + + c1.Close() + if _, err := c1.Do("PING"); err == nil { + t.Errorf("expected error after connection closed") + } + + c2.Close() + c2.Close() + + p.Close() + + d.check("after pool close", p, 3, 1) + + if _, err := c1.Do("PING"); err == nil { + t.Errorf("expected error after connection and pool closed") + } + + c3.Close() + + d.check("after conn close", p, 3, 0) + + c1 = p.Get(context.TODO()) + if _, err := c1.Do("PING"); err == nil { + t.Errorf("expected error after pool closed") + } +} + +func TestPoolConcurrenSendReceive(t *testing.T) { + p := NewPool(testConfig) + p.Slice.New = func(ctx context.Context) (io.Closer, error) { + return DialDefaultServer() + } + defer p.Close() + + c := p.Get(context.TODO()) + done := make(chan error, 1) + go func() { + _, err := c.Receive() + done <- err + }() + c.Send("PING") + c.Flush() + err := <-done + if err != nil { + t.Fatalf("Receive() returned error %v", err) + } + _, err = c.Do("") + if err != nil { + t.Fatalf("Do() returned error %v", err) + } + c.Close() +} + +func TestPoolMaxActive(t *testing.T) { + d := poolDialer{t: t} + conf := getTestConfig(testRedisAddr) + conf.Config = &pool.Config{ + Active: 2, + Idle: 2, + } + p := NewPool(conf) + p.Slice.New = func(ctx context.Context) (io.Closer, error) { + return d.dial() + } + defer p.Close() + + c1 := p.Get(context.TODO()) + c1.Do("PING") + c2 := p.Get(context.TODO()) + c2.Do("PING") + + d.check("1", p, 2, 2) + + c3 := p.Get(context.TODO()) + if _, err := c3.Do("PING"); err != pool.ErrPoolExhausted { + t.Errorf("expected pool exhausted") + } + + c3.Close() + d.check("2", p, 2, 2) + c2.Close() + d.check("3", p, 2, 2) + + c3 = p.Get(context.TODO()) + if _, err := c3.Do("PING"); err != nil { + t.Errorf("expected good channel, err=%v", err) + } + c3.Close() + + d.check("4", p, 2, 2) +} + +func TestPoolMonitorCleanup(t *testing.T) { + d := poolDialer{t: t} + p := NewPool(testConfig) + p.Slice.New = func(ctx context.Context) (io.Closer, error) { + return d.dial() + } + defer p.Close() + c := p.Get(context.TODO()) + c.Send("MONITOR") + c.Close() + + d.check("", p, 1, 0) +} + +func TestPoolPubSubCleanup(t *testing.T) { + d := poolDialer{t: t} + p := NewPool(testConfig) + p.Slice.New = func(ctx context.Context) (io.Closer, error) { + return d.dial() + } + defer p.Close() + + c := p.Get(context.TODO()) + c.Send("SUBSCRIBE", "x") + c.Close() + + want := []string{"SUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"} + if !reflect.DeepEqual(d.commands, want) { + t.Errorf("got commands %v, want %v", d.commands, want) + } + d.commands = nil + + c = p.Get(context.TODO()) + c.Send("PSUBSCRIBE", "x*") + c.Close() + + want = []string{"PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"} + if !reflect.DeepEqual(d.commands, want) { + t.Errorf("got commands %v, want %v", d.commands, want) + } + d.commands = nil +} + +func TestPoolTransactionCleanup(t *testing.T) { + d := poolDialer{t: t} + p := NewPool(testConfig) + p.Slice.New = func(ctx context.Context) (io.Closer, error) { + return d.dial() + } + defer p.Close() + + c := p.Get(context.TODO()) + c.Do("WATCH", "key") + c.Do("PING") + c.Close() + + want := []string{"WATCH", "PING", "UNWATCH"} + if !reflect.DeepEqual(d.commands, want) { + t.Errorf("got commands %v, want %v", d.commands, want) + } + d.commands = nil + + c = p.Get(context.TODO()) + c.Do("WATCH", "key") + c.Do("UNWATCH") + c.Do("PING") + c.Close() + + want = []string{"WATCH", "UNWATCH", "PING"} + if !reflect.DeepEqual(d.commands, want) { + t.Errorf("got commands %v, want %v", d.commands, want) + } + d.commands = nil + + c = p.Get(context.TODO()) + c.Do("WATCH", "key") + c.Do("MULTI") + c.Do("PING") + c.Close() + + want = []string{"WATCH", "MULTI", "PING", "DISCARD"} + if !reflect.DeepEqual(d.commands, want) { + t.Errorf("got commands %v, want %v", d.commands, want) + } + d.commands = nil + + c = p.Get(context.TODO()) + c.Do("WATCH", "key") + c.Do("MULTI") + c.Do("DISCARD") + c.Do("PING") + c.Close() + + want = []string{"WATCH", "MULTI", "DISCARD", "PING"} + if !reflect.DeepEqual(d.commands, want) { + t.Errorf("got commands %v, want %v", d.commands, want) + } + d.commands = nil + + c = p.Get(context.TODO()) + c.Do("WATCH", "key") + c.Do("MULTI") + c.Do("EXEC") + c.Do("PING") + c.Close() + + want = []string{"WATCH", "MULTI", "EXEC", "PING"} + if !reflect.DeepEqual(d.commands, want) { + t.Errorf("got commands %v, want %v", d.commands, want) + } + d.commands = nil +} + +func startGoroutines(p *Pool, cmd string, args ...interface{}) chan error { + errs := make(chan error, 10) + for i := 0; i < cap(errs); i++ { + go func() { + c := p.Get(context.TODO()) + _, err := c.Do(cmd, args...) + errs <- err + c.Close() + }() + } + + // Wait for goroutines to block. + time.Sleep(time.Second / 4) + + return errs +} + +func TestWaitPoolDialError(t *testing.T) { + testErr := errors.New("test") + d := poolDialer{t: t} + config1 := testConfig + config1.Config = &pool.Config{ + Active: 1, + Idle: 1, + Wait: true, + } + p := NewPool(config1) + p.Slice.New = func(ctx context.Context) (io.Closer, error) { + return d.dial() + } + defer p.Close() + + c := p.Get(context.TODO()) + errs := startGoroutines(p, "ERR", testErr) + d.check("before close", p, 1, 1) + + d.dialErr = errors.New("dial") + c.Close() + + nilCount := 0 + errCount := 0 + timeout := time.After(2 * time.Second) + for i := 0; i < cap(errs); i++ { + select { + case err := <-errs: + switch err { + case nil: + nilCount++ + case d.dialErr: + errCount++ + default: + t.Fatalf("expected dial error or nil, got %v", err) + } + case <-timeout: + t.Logf("Wait all the time and timeout %d", i) + return + } + } + if nilCount != 1 { + t.Errorf("expected one nil error, got %d", nilCount) + } + if errCount != cap(errs)-1 { + t.Errorf("expected %d dial erors, got %d", cap(errs)-1, errCount) + } + d.check("done", p, cap(errs), 0) +} + +func BenchmarkPoolGet(b *testing.B) { + b.StopTimer() + p := NewPool(testConfig) + c := p.Get(context.Background()) + if err := c.Err(); err != nil { + b.Fatal(err) + } + c.Close() + defer p.Close() + b.StartTimer() + for i := 0; i < b.N; i++ { + c := p.Get(context.Background()) + c.Close() + } +} + +func BenchmarkPoolGetErr(b *testing.B) { + b.StopTimer() + p := NewPool(testConfig) + c := p.Get(context.Background()) + if err := c.Err(); err != nil { + b.Fatal(err) + } + c.Close() + defer p.Close() + b.StartTimer() + for i := 0; i < b.N; i++ { + c = p.Get(context.Background()) + if err := c.Err(); err != nil { + b.Fatal(err) + } + c.Close() + } +} + +func BenchmarkPoolGetPing(b *testing.B) { + b.StopTimer() + p := NewPool(testConfig) + c := p.Get(context.Background()) + if err := c.Err(); err != nil { + b.Fatal(err) + } + c.Close() + defer p.Close() + b.StartTimer() + for i := 0; i < b.N; i++ { + c := p.Get(context.Background()) + if _, err := c.Do("PING"); err != nil { + b.Fatal(err) + } + c.Close() + } +} + +func BenchmarkPooledConn(b *testing.B) { + p := NewPool(testConfig) + defer p.Close() + for i := 0; i < b.N; i++ { + ctx := context.TODO() + c := p.Get(ctx) + c2 := c.WithContext(context.TODO()) + if _, err := c2.Do("PING"); err != nil { + b.Fatal(err) + } + c2.Close() + } +} diff --git a/pkg/cache/redis/pubsub_test.go b/pkg/cache/redis/pubsub_test.go new file mode 100644 index 000000000..69c66ffd8 --- /dev/null +++ b/pkg/cache/redis/pubsub_test.go @@ -0,0 +1,146 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "fmt" + "reflect" + "sync" + "testing" +) + +func publish(channel, value interface{}) { + c, err := dial() + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + c.Do("PUBLISH", channel, value) +} + +// Applications can receive pushed messages from one goroutine and manage subscriptions from another goroutine. +func ExamplePubSubConn() { + c, err := dial() + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + var wg sync.WaitGroup + wg.Add(2) + + psc := PubSubConn{Conn: c} + + // This goroutine receives and prints pushed notifications from the server. + // The goroutine exits when the connection is unsubscribed from all + // channels or there is an error. + go func() { + defer wg.Done() + for { + switch n := psc.Receive().(type) { + case Message: + fmt.Printf("Message: %s %s\n", n.Channel, n.Data) + case PMessage: + fmt.Printf("PMessage: %s %s %s\n", n.Pattern, n.Channel, n.Data) + case Subscription: + fmt.Printf("Subscription: %s %s %d\n", n.Kind, n.Channel, n.Count) + if n.Count == 0 { + return + } + case error: + fmt.Printf("error: %v\n", n) + return + } + } + }() + + // This goroutine manages subscriptions for the connection. + go func() { + defer wg.Done() + + psc.Subscribe("example") + psc.PSubscribe("p*") + + // The following function calls publish a message using another + // connection to the Redis server. + publish("example", "hello") + publish("example", "world") + publish("pexample", "foo") + publish("pexample", "bar") + + // Unsubscribe from all connections. This will cause the receiving + // goroutine to exit. + psc.Unsubscribe() + psc.PUnsubscribe() + }() + + wg.Wait() + + // Output: + // Subscription: subscribe example 1 + // Subscription: psubscribe p* 2 + // Message: example hello + // Message: example world + // PMessage: p* pexample foo + // PMessage: p* pexample bar + // Subscription: unsubscribe example 1 + // Subscription: punsubscribe p* 0 +} + +func expectPushed(t *testing.T, c PubSubConn, message string, expected interface{}) { + actual := c.Receive() + if !reflect.DeepEqual(actual, expected) { + t.Errorf("%s = %v, want %v", message, actual, expected) + } +} + +func TestPushed(t *testing.T) { + pc, err := DialDefaultServer() + if err != nil { + t.Fatalf("error connection to database, %v", err) + } + defer pc.Close() + + sc, err := DialDefaultServer() + if err != nil { + t.Fatalf("error connection to database, %v", err) + } + defer sc.Close() + + c := PubSubConn{Conn: sc} + + c.Subscribe("c1") + expectPushed(t, c, "Subscribe(c1)", Subscription{Kind: "subscribe", Channel: "c1", Count: 1}) + c.Subscribe("c2") + expectPushed(t, c, "Subscribe(c2)", Subscription{Kind: "subscribe", Channel: "c2", Count: 2}) + c.PSubscribe("p1") + expectPushed(t, c, "PSubscribe(p1)", Subscription{Kind: "psubscribe", Channel: "p1", Count: 3}) + c.PSubscribe("p2") + expectPushed(t, c, "PSubscribe(p2)", Subscription{Kind: "psubscribe", Channel: "p2", Count: 4}) + c.PUnsubscribe() + expectPushed(t, c, "Punsubscribe(p1)", Subscription{Kind: "punsubscribe", Channel: "p1", Count: 3}) + expectPushed(t, c, "Punsubscribe()", Subscription{Kind: "punsubscribe", Channel: "p2", Count: 2}) + + pc.Do("PUBLISH", "c1", "hello") + expectPushed(t, c, "PUBLISH c1 hello", Message{Channel: "c1", Data: []byte("hello")}) + + c.Ping("hello") + expectPushed(t, c, `Ping("hello")`, Pong{"hello"}) + + c.Conn.Send("PING") + c.Conn.Flush() + expectPushed(t, c, `Send("PING")`, Pong{}) +} diff --git a/pkg/cache/redis/redis.go b/pkg/cache/redis/redis.go index 638cd9ab2..f912372b0 100644 --- a/pkg/cache/redis/redis.go +++ b/pkg/cache/redis/redis.go @@ -16,6 +16,9 @@ package redis import ( "context" + + "github.com/bilibili/kratos/pkg/container/pool" + xtime "github.com/bilibili/kratos/pkg/time" ) // Error represents an error returned in a command reply. @@ -23,29 +26,53 @@ type Error string func (err Error) Error() string { return string(err) } -// Conn represents a connection to a Redis server. -type Conn interface { - // Close closes the connection. - Close() error +// Config client settings. +type Config struct { + *pool.Config + + Name string // redis name, for trace + Proto string + Addr string + Auth string + DialTimeout xtime.Duration + ReadTimeout xtime.Duration + WriteTimeout xtime.Duration + SlowLog xtime.Duration +} - // Err returns a non-nil value if the connection is broken. The returned - // value is either the first non-nil value returned from the underlying - // network connection or a protocol parsing error. Applications should - // close broken connections. - Err() error +type Redis struct { + pool *Pool + conf *Config +} - // Do sends a command to the server and returns the received reply. - Do(commandName string, args ...interface{}) (reply interface{}, err error) +func NewRedis(c *Config, options ...DialOption) *Redis { + return &Redis{ + pool: NewPool(c, options...), + conf: c, + } +} - // Send writes the command to the client's output buffer. - Send(commandName string, args ...interface{}) error +// Do gets a new conn from pool, then execute Do with this conn, finally close this conn. +// ATTENTION: Don't use this method with transaction command like MULTI etc. Because every Do will close conn automatically, use r.Conn to get a raw conn for this situation. +func (r *Redis) Do(ctx context.Context, commandName string, args ...interface{}) (reply interface{}, err error) { + conn := r.pool.Get(ctx) + defer conn.Close() + reply, err = conn.Do(commandName, args...) + return +} - // Flush flushes the output buffer to the Redis server. - Flush() error +// Close closes connection pool +func (r *Redis) Close() error { + return r.pool.Close() +} - // Receive receives a single reply from the Redis server - Receive() (reply interface{}, err error) +// Conn direct gets a connection +func (r *Redis) Conn(ctx context.Context) Conn { + return r.pool.Get(ctx) +} - // WithContext - WithContext(ctx context.Context) Conn +func (r *Redis) Pipeline() (p Pipeliner) { + return &pipeliner{ + pool: r.pool, + } } diff --git a/pkg/cache/redis/redis_test.go b/pkg/cache/redis/redis_test.go new file mode 100644 index 000000000..464cbff8f --- /dev/null +++ b/pkg/cache/redis/redis_test.go @@ -0,0 +1,324 @@ +package redis + +import ( + "context" + "reflect" + "testing" + "time" + + "github.com/bilibili/kratos/pkg/container/pool" + xtime "github.com/bilibili/kratos/pkg/time" +) + +func TestRedis(t *testing.T) { + testSet(t, testPool) + testSend(t, testPool) + testGet(t, testPool) + testErr(t, testPool) + if err := testPool.Close(); err != nil { + t.Errorf("redis: close error(%v)", err) + } + conn, err := NewConn(testConfig) + if err != nil { + t.Errorf("redis: new conn error(%v)", err) + } + if err := conn.Close(); err != nil { + t.Errorf("redis: close error(%v)", err) + } +} + +func testSet(t *testing.T, p *Pool) { + var ( + key = "test" + value = "test" + conn = p.Get(context.TODO()) + ) + defer conn.Close() + if reply, err := conn.Do("set", key, value); err != nil { + t.Errorf("redis: conn.Do(SET, %s, %s) error(%v)", key, value, err) + } else { + t.Logf("redis: set status: %s", reply) + } +} + +func testSend(t *testing.T, p *Pool) { + var ( + key = "test" + value = "test" + expire = 1000 + conn = p.Get(context.TODO()) + ) + defer conn.Close() + if err := conn.Send("SET", key, value); err != nil { + t.Errorf("redis: conn.Send(SET, %s, %s) error(%v)", key, value, err) + } + if err := conn.Send("EXPIRE", key, expire); err != nil { + t.Errorf("redis: conn.Send(EXPIRE key(%s) expire(%d)) error(%v)", key, expire, err) + } + if err := conn.Flush(); err != nil { + t.Errorf("redis: conn.Flush error(%v)", err) + } + for i := 0; i < 2; i++ { + if _, err := conn.Receive(); err != nil { + t.Errorf("redis: conn.Receive error(%v)", err) + return + } + } + t.Logf("redis: set value: %s", value) +} + +func testGet(t *testing.T, p *Pool) { + var ( + key = "test" + conn = p.Get(context.TODO()) + ) + defer conn.Close() + if reply, err := conn.Do("GET", key); err != nil { + t.Errorf("redis: conn.Do(GET, %s) error(%v)", key, err) + } else { + t.Logf("redis: get value: %s", reply) + } +} + +func testErr(t *testing.T, p *Pool) { + conn := p.Get(context.TODO()) + if err := conn.Close(); err != nil { + t.Errorf("redis: close error(%v)", err) + } + if err := conn.Err(); err == nil { + t.Errorf("redis: err not nil") + } else { + t.Logf("redis: err: %v", err) + } +} + +func BenchmarkRedis(b *testing.B) { + conf := &Config{ + Name: "test", + Proto: "tcp", + Addr: testRedisAddr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + } + conf.Config = &pool.Config{ + Active: 10, + Idle: 5, + IdleTimeout: xtime.Duration(90 * time.Second), + } + benchmarkPool := NewPool(conf) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn := benchmarkPool.Get(context.TODO()) + if err := conn.Close(); err != nil { + b.Errorf("redis: close error(%v)", err) + } + } + }) + if err := benchmarkPool.Close(); err != nil { + b.Errorf("redis: close error(%v)", err) + } +} + +var testRedisCommands = []struct { + args []interface{} + expected interface{} +}{ + { + []interface{}{"PING"}, + "PONG", + }, + { + []interface{}{"SET", "foo", "bar"}, + "OK", + }, + { + []interface{}{"GET", "foo"}, + []byte("bar"), + }, + { + []interface{}{"GET", "nokey"}, + nil, + }, + { + []interface{}{"MGET", "nokey", "foo"}, + []interface{}{nil, []byte("bar")}, + }, + { + []interface{}{"INCR", "mycounter"}, + int64(1), + }, + { + []interface{}{"LPUSH", "mylist", "foo"}, + int64(1), + }, + { + []interface{}{"LPUSH", "mylist", "bar"}, + int64(2), + }, + { + []interface{}{"LRANGE", "mylist", 0, -1}, + []interface{}{[]byte("bar"), []byte("foo")}, + }, +} + +func TestNewRedis(t *testing.T) { + type args struct { + c *Config + options []DialOption + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "new_redis", + args{ + testConfig, + make([]DialOption, 0), + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := NewRedis(tt.args.c, tt.args.options...) + if r == nil { + t.Errorf("NewRedis() error, got nil") + return + } + err := r.Close() + if err != nil { + t.Errorf("Close() error %v", err) + } + }) + } +} + +func TestRedis_Do(t *testing.T) { + r := NewRedis(testConfig) + r.Do(context.TODO(), "FLUSHDB") + + for _, cmd := range testRedisCommands { + actual, err := r.Do(context.TODO(), cmd.args[0].(string), cmd.args[1:]...) + if err != nil { + t.Errorf("Do(%v) returned error %v", cmd.args, err) + continue + } + if !reflect.DeepEqual(actual, cmd.expected) { + t.Errorf("Do(%v) = %v, want %v", cmd.args, actual, cmd.expected) + } + } + err := r.Close() + if err != nil { + t.Errorf("Close() error %v", err) + } +} + +func TestRedis_Conn(t *testing.T) { + + type args struct { + ctx context.Context + } + tests := []struct { + name string + p *Redis + args args + wantErr bool + g int + c int + }{ + { + "Close", + NewRedis(&Config{ + Config: &pool.Config{ + Active: 1, + Idle: 1, + }, + Name: "test_get", + Proto: "tcp", + Addr: testRedisAddr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + }), + args{context.TODO()}, + false, + 3, + 3, + }, + { + "CloseExceededPoolSize", + NewRedis(&Config{ + Config: &pool.Config{ + Active: 1, + Idle: 1, + }, + Name: "test_get_out", + Proto: "tcp", + Addr: testRedisAddr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + }), + args{context.TODO()}, + true, + 5, + 3, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for i := 1; i <= tt.g; i++ { + got := tt.p.Conn(tt.args.ctx) + if err := got.Close(); err != nil { + if !tt.wantErr { + t.Error(err) + } + } + if i <= tt.c { + if err := got.Close(); err != nil { + t.Error(err) + } + } + } + }) + } +} + +func BenchmarkRedisDoPing(b *testing.B) { + r := NewRedis(testConfig) + defer r.Close() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := r.Do(context.Background(), "PING"); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkRedisDoSET(b *testing.B) { + r := NewRedis(testConfig) + defer r.Close() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := r.Do(context.Background(), "SET", "a", "b"); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkRedisDoGET(b *testing.B) { + r := NewRedis(testConfig) + defer r.Close() + r.Do(context.Background(), "SET", "a", "b") + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := r.Do(context.Background(), "GET", "b"); err != nil { + b.Fatal(err) + } + } +} diff --git a/pkg/cache/redis/reply_test.go b/pkg/cache/redis/reply_test.go new file mode 100644 index 000000000..d3b1b9551 --- /dev/null +++ b/pkg/cache/redis/reply_test.go @@ -0,0 +1,179 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "fmt" + "reflect" + "testing" + + "github.com/pkg/errors" +) + +type valueError struct { + v interface{} + err error +} + +func ve(v interface{}, err error) valueError { + return valueError{v, err} +} + +var replyTests = []struct { + name interface{} + actual valueError + expected valueError +}{ + { + "ints([v1, v2])", + ve(Ints([]interface{}{[]byte("4"), []byte("5")}, nil)), + ve([]int{4, 5}, nil), + }, + { + "ints(nil)", + ve(Ints(nil, nil)), + ve([]int(nil), ErrNil), + }, + { + "strings([v1, v2])", + ve(Strings([]interface{}{[]byte("v1"), []byte("v2")}, nil)), + ve([]string{"v1", "v2"}, nil), + }, + { + "strings(nil)", + ve(Strings(nil, nil)), + ve([]string(nil), ErrNil), + }, + { + "byteslices([v1, v2])", + ve(ByteSlices([]interface{}{[]byte("v1"), []byte("v2")}, nil)), + ve([][]byte{[]byte("v1"), []byte("v2")}, nil), + }, + { + "byteslices(nil)", + ve(ByteSlices(nil, nil)), + ve([][]byte(nil), ErrNil), + }, + { + "values([v1, v2])", + ve(Values([]interface{}{[]byte("v1"), []byte("v2")}, nil)), + ve([]interface{}{[]byte("v1"), []byte("v2")}, nil), + }, + { + "values(nil)", + ve(Values(nil, nil)), + ve([]interface{}(nil), ErrNil), + }, + { + "float64(1.0)", + ve(Float64([]byte("1.0"), nil)), + ve(float64(1.0), nil), + }, + { + "float64(nil)", + ve(Float64(nil, nil)), + ve(float64(0.0), ErrNil), + }, + { + "uint64(1)", + ve(Uint64(int64(1), nil)), + ve(uint64(1), nil), + }, + { + "uint64(-1)", + ve(Uint64(int64(-1), nil)), + ve(uint64(0), errNegativeInt), + }, +} + +func TestReply(t *testing.T) { + for _, rt := range replyTests { + if errors.Cause(rt.actual.err) != rt.expected.err { + t.Errorf("%s returned err %v, want %v", rt.name, rt.actual.err, rt.expected.err) + continue + } + if !reflect.DeepEqual(rt.actual.v, rt.expected.v) { + t.Errorf("%s=%+v, want %+v", rt.name, rt.actual.v, rt.expected.v) + } + } +} + +// dial wraps DialDefaultServer() with a more suitable function name for examples. +func dial() (Conn, error) { + return DialDefaultServer() +} + +func ExampleBool() { + c, err := dial() + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + + c.Do("SET", "foo", 1) + exists, _ := Bool(c.Do("EXISTS", "foo")) + fmt.Printf("%#v\n", exists) + // Output: + // true +} + +func ExampleInt() { + c, err := dial() + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + + c.Do("SET", "k1", 1) + n, _ := Int(c.Do("GET", "k1")) + fmt.Printf("%#v\n", n) + n, _ = Int(c.Do("INCR", "k1")) + fmt.Printf("%#v\n", n) + // Output: + // 1 + // 2 +} + +func ExampleInts() { + c, err := dial() + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + + c.Do("SADD", "set_with_integers", 4, 5, 6) + ints, _ := Ints(c.Do("SMEMBERS", "set_with_integers")) + fmt.Printf("%#v\n", ints) + // Output: + // []int{4, 5, 6} +} + +func ExampleString() { + c, err := dial() + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + + c.Do("SET", "hello", "world") + s, _ := String(c.Do("GET", "hello")) + fmt.Printf("%#v", s) + // Output: + // "world" +} diff --git a/pkg/cache/redis/scan_test.go b/pkg/cache/redis/scan_test.go new file mode 100644 index 000000000..fba605d77 --- /dev/null +++ b/pkg/cache/redis/scan_test.go @@ -0,0 +1,438 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "fmt" + "math" + "reflect" + "testing" +) + +var scanConversionTests = []struct { + src interface{} + dest interface{} +}{ + {[]byte("-inf"), math.Inf(-1)}, + {[]byte("+inf"), math.Inf(1)}, + {[]byte("0"), float64(0)}, + {[]byte("3.14159"), float64(3.14159)}, + {[]byte("3.14"), float32(3.14)}, + {[]byte("-100"), int(-100)}, + {[]byte("101"), int(101)}, + {int64(102), int(102)}, + {[]byte("103"), uint(103)}, + {int64(104), uint(104)}, + {[]byte("105"), int8(105)}, + {int64(106), int8(106)}, + {[]byte("107"), uint8(107)}, + {int64(108), uint8(108)}, + {[]byte("0"), false}, + {int64(0), false}, + {[]byte("f"), false}, + {[]byte("1"), true}, + {int64(1), true}, + {[]byte("t"), true}, + {"hello", "hello"}, + {[]byte("hello"), "hello"}, + {[]byte("world"), []byte("world")}, + {[]interface{}{[]byte("foo")}, []interface{}{[]byte("foo")}}, + {[]interface{}{[]byte("foo")}, []string{"foo"}}, + {[]interface{}{[]byte("hello"), []byte("world")}, []string{"hello", "world"}}, + {[]interface{}{[]byte("bar")}, [][]byte{[]byte("bar")}}, + {[]interface{}{[]byte("1")}, []int{1}}, + {[]interface{}{[]byte("1"), []byte("2")}, []int{1, 2}}, + {[]interface{}{[]byte("1"), []byte("2")}, []float64{1, 2}}, + {[]interface{}{[]byte("1")}, []byte{1}}, + {[]interface{}{[]byte("1")}, []bool{true}}, +} + +func TestScanConversion(t *testing.T) { + for _, tt := range scanConversionTests { + values := []interface{}{tt.src} + dest := reflect.New(reflect.TypeOf(tt.dest)) + values, err := Scan(values, dest.Interface()) + if err != nil { + t.Errorf("Scan(%v) returned error %v", tt, err) + continue + } + if !reflect.DeepEqual(tt.dest, dest.Elem().Interface()) { + t.Errorf("Scan(%v) returned %v values: %v, want %v", tt, dest.Elem().Interface(), values, tt.dest) + } + } +} + +var scanConversionErrorTests = []struct { + src interface{} + dest interface{} +}{ + {[]byte("1234"), byte(0)}, + {int64(1234), byte(0)}, + {[]byte("-1"), byte(0)}, + {int64(-1), byte(0)}, + {[]byte("junk"), false}, + {Error("blah"), false}, +} + +func TestScanConversionError(t *testing.T) { + for _, tt := range scanConversionErrorTests { + values := []interface{}{tt.src} + dest := reflect.New(reflect.TypeOf(tt.dest)) + values, err := Scan(values, dest.Interface()) + if err == nil { + t.Errorf("Scan(%v) did not return error values: %v", tt, values) + } + } +} + +func ExampleScan() { + c, err := dial() + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + + c.Send("HMSET", "album:1", "title", "Red", "rating", 5) + c.Send("HMSET", "album:2", "title", "Earthbound", "rating", 1) + c.Send("HMSET", "album:3", "title", "Beat") + c.Send("LPUSH", "albums", "1") + c.Send("LPUSH", "albums", "2") + c.Send("LPUSH", "albums", "3") + values, err := Values(c.Do("SORT", "albums", + "BY", "album:*->rating", + "GET", "album:*->title", + "GET", "album:*->rating")) + if err != nil { + fmt.Println(err) + return + } + + for len(values) > 0 { + var title string + rating := -1 // initialize to illegal value to detect nil. + values, err = Scan(values, &title, &rating) + if err != nil { + fmt.Println(err) + return + } + if rating == -1 { + fmt.Println(title, "not-rated") + } else { + fmt.Println(title, rating) + } + } + // Output: + // Beat not-rated + // Earthbound 1 + // Red 5 +} + +type s0 struct { + X int + Y int `redis:"y"` + Bt bool +} + +type s1 struct { + X int `redis:"-"` + I int `redis:"i"` + U uint `redis:"u"` + S string `redis:"s"` + P []byte `redis:"p"` + B bool `redis:"b"` + Bt bool + Bf bool + s0 +} + +var scanStructTests = []struct { + title string + reply []string + value interface{} +}{ + {"basic", + []string{"i", "-1234", "u", "5678", "s", "hello", "p", "world", "b", "t", "Bt", "1", "Bf", "0", "X", "123", "y", "456"}, + &s1{I: -1234, U: 5678, S: "hello", P: []byte("world"), B: true, Bt: true, Bf: false, s0: s0{X: 123, Y: 456}}, + }, +} + +func TestScanStruct(t *testing.T) { + for _, tt := range scanStructTests { + + var reply []interface{} + for _, v := range tt.reply { + reply = append(reply, []byte(v)) + } + + value := reflect.New(reflect.ValueOf(tt.value).Type().Elem()) + + if err := ScanStruct(reply, value.Interface()); err != nil { + t.Fatalf("ScanStruct(%s) returned error %v", tt.title, err) + } + + if !reflect.DeepEqual(value.Interface(), tt.value) { + t.Fatalf("ScanStruct(%s) returned %v, want %v", tt.title, value.Interface(), tt.value) + } + } +} + +func TestBadScanStructArgs(t *testing.T) { + x := []interface{}{"A", "b"} + test := func(v interface{}) { + if err := ScanStruct(x, v); err == nil { + t.Errorf("Expect error for ScanStruct(%T, %T)", x, v) + } + } + + test(nil) + + var v0 *struct{} + test(v0) + + var v1 int + test(&v1) + + x = x[:1] + v2 := struct{ A string }{} + test(&v2) +} + +var scanSliceTests = []struct { + src []interface{} + fieldNames []string + ok bool + dest interface{} +}{ + { + []interface{}{[]byte("1"), nil, []byte("-1")}, + nil, + true, + []int{1, 0, -1}, + }, + { + []interface{}{[]byte("1"), nil, []byte("2")}, + nil, + true, + []uint{1, 0, 2}, + }, + { + []interface{}{[]byte("-1")}, + nil, + false, + []uint{1}, + }, + { + []interface{}{[]byte("hello"), nil, []byte("world")}, + nil, + true, + [][]byte{[]byte("hello"), nil, []byte("world")}, + }, + { + []interface{}{[]byte("hello"), nil, []byte("world")}, + nil, + true, + []string{"hello", "", "world"}, + }, + { + []interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")}, + nil, + true, + []struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}}, + }, + { + []interface{}{[]byte("a1"), []byte("b1")}, + nil, + false, + []struct{ A, B, C string }{{"a1", "b1", ""}}, + }, + { + []interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")}, + nil, + true, + []*struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}}, + }, + { + []interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")}, + []string{"A", "B"}, + true, + []struct{ A, C, B string }{{"a1", "", "b1"}, {"a2", "", "b2"}}, + }, + { + []interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")}, + nil, + false, + []struct{}{}, + }, +} + +func TestScanSlice(t *testing.T) { + for _, tt := range scanSliceTests { + + typ := reflect.ValueOf(tt.dest).Type() + dest := reflect.New(typ) + + err := ScanSlice(tt.src, dest.Interface(), tt.fieldNames...) + if tt.ok != (err == nil) { + t.Errorf("ScanSlice(%v, []%s, %v) returned error %v", tt.src, typ, tt.fieldNames, err) + continue + } + if tt.ok && !reflect.DeepEqual(dest.Elem().Interface(), tt.dest) { + t.Errorf("ScanSlice(src, []%s) returned %#v, want %#v", typ, dest.Elem().Interface(), tt.dest) + } + } +} + +func ExampleScanSlice() { + c, err := dial() + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + + c.Send("HMSET", "album:1", "title", "Red", "rating", 5) + c.Send("HMSET", "album:2", "title", "Earthbound", "rating", 1) + c.Send("HMSET", "album:3", "title", "Beat", "rating", 4) + c.Send("LPUSH", "albums", "1") + c.Send("LPUSH", "albums", "2") + c.Send("LPUSH", "albums", "3") + values, err := Values(c.Do("SORT", "albums", + "BY", "album:*->rating", + "GET", "album:*->title", + "GET", "album:*->rating")) + if err != nil { + fmt.Println(err) + return + } + + var albums []struct { + Title string + Rating int + } + if err := ScanSlice(values, &albums); err != nil { + fmt.Println(err) + return + } + fmt.Printf("%v\n", albums) + // Output: + // [{Earthbound 1} {Beat 4} {Red 5}] +} + +var argsTests = []struct { + title string + actual Args + expected Args +}{ + {"struct ptr", + Args{}.AddFlat(&struct { + I int `redis:"i"` + U uint `redis:"u"` + S string `redis:"s"` + P []byte `redis:"p"` + M map[string]string `redis:"m"` + Bt bool + Bf bool + }{ + -1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, + }), + Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false}, + }, + {"struct", + Args{}.AddFlat(struct{ I int }{123}), + Args{"I", 123}, + }, + {"slice", + Args{}.Add(1).AddFlat([]string{"a", "b", "c"}).Add(2), + Args{1, "a", "b", "c", 2}, + }, + {"struct omitempty", + Args{}.AddFlat(&struct { + I int `redis:"i,omitempty"` + U uint `redis:"u,omitempty"` + S string `redis:"s,omitempty"` + P []byte `redis:"p,omitempty"` + M map[string]string `redis:"m,omitempty"` + Bt bool `redis:"Bt,omitempty"` + Bf bool `redis:"Bf,omitempty"` + }{ + 0, 0, "", []byte{}, map[string]string{}, true, false, + }), + Args{"Bt", true}, + }, +} + +func TestArgs(t *testing.T) { + for _, tt := range argsTests { + if !reflect.DeepEqual(tt.actual, tt.expected) { + t.Fatalf("%s is %v, want %v", tt.title, tt.actual, tt.expected) + } + } +} + +func ExampleArgs() { + c, err := dial() + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + + var p1, p2 struct { + Title string `redis:"title"` + Author string `redis:"author"` + Body string `redis:"body"` + } + + p1.Title = "Example" + p1.Author = "Gary" + p1.Body = "Hello" + + if _, err := c.Do("HMSET", Args{}.Add("id1").AddFlat(&p1)...); err != nil { + fmt.Println(err) + return + } + + m := map[string]string{ + "title": "Example2", + "author": "Steve", + "body": "Map", + } + + if _, err := c.Do("HMSET", Args{}.Add("id2").AddFlat(m)...); err != nil { + fmt.Println(err) + return + } + + for _, id := range []string{"id1", "id2"} { + + v, err := Values(c.Do("HGETALL", id)) + if err != nil { + fmt.Println(err) + return + } + + if err := ScanStruct(v, &p2); err != nil { + fmt.Println(err) + return + } + + fmt.Printf("%+v\n", p2) + } + + // Output: + // {Title:Example Author:Gary Body:Hello} + // {Title:Example2 Author:Steve Body:Map} +} diff --git a/pkg/cache/redis/script_test.go b/pkg/cache/redis/script_test.go new file mode 100644 index 000000000..405a33128 --- /dev/null +++ b/pkg/cache/redis/script_test.go @@ -0,0 +1,103 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package redis + +import ( + "fmt" + "reflect" + "testing" + "time" +) + +func ExampleScript() { + c, err := Dial("tcp", ":6379") + if err != nil { + // handle error + } + defer c.Close() + // Initialize a package-level variable with a script. + var getScript = NewScript(1, `return call('get', KEYS[1])`) + + // In a function, use the script Do method to evaluate the script. The Do + // method optimistically uses the EVALSHA command. If the script is not + // loaded, then the Do method falls back to the EVAL command. + if _, err = getScript.Do(c, "foo"); err != nil { + // handle error + } +} + +func TestScript(t *testing.T) { + c, err := DialDefaultServer() + if err != nil { + t.Fatalf("error connection to database, %v", err) + } + defer c.Close() + + // To test fall back in Do, we make script unique by adding comment with current time. + script := fmt.Sprintf("--%d\nreturn {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", time.Now().UnixNano()) + s := NewScript(2, script) + reply := []interface{}{[]byte("key1"), []byte("key2"), []byte("arg1"), []byte("arg2")} + + v, err := s.Do(c, "key1", "key2", "arg1", "arg2") + if err != nil { + t.Errorf("s.Do(c, ...) returned %v", err) + } + + if !reflect.DeepEqual(v, reply) { + t.Errorf("s.Do(c, ..); = %v, want %v", v, reply) + } + + err = s.Load(c) + if err != nil { + t.Errorf("s.Load(c) returned %v", err) + } + + err = s.SendHash(c, "key1", "key2", "arg1", "arg2") + if err != nil { + t.Errorf("s.SendHash(c, ...) returned %v", err) + } + + err = c.Flush() + if err != nil { + t.Errorf("c.Flush() returned %v", err) + } + + v, err = c.Receive() + if err != nil { + t.Errorf("c.Receive() returned %v", err) + } + if !reflect.DeepEqual(v, reply) { + t.Errorf("s.SendHash(c, ..); c.Receive() = %v, want %v", v, reply) + } + + err = s.Send(c, "key1", "key2", "arg1", "arg2") + if err != nil { + t.Errorf("s.Send(c, ...) returned %v", err) + } + + err = c.Flush() + if err != nil { + t.Errorf("c.Flush() returned %v", err) + } + + v, err = c.Receive() + if err != nil { + t.Errorf("c.Receive() returned %v", err) + } + if !reflect.DeepEqual(v, reply) { + t.Errorf("s.Send(c, ..); c.Receive() = %v, want %v", v, reply) + } + +} diff --git a/pkg/cache/redis/test/docker-compose.yaml b/pkg/cache/redis/test/docker-compose.yaml new file mode 100644 index 000000000..4bb1f4552 --- /dev/null +++ b/pkg/cache/redis/test/docker-compose.yaml @@ -0,0 +1,12 @@ +version: "3.7" + +services: + redis: + image: redis + ports: + - 6379:6379 + healthcheck: + test: ["CMD", "redis-cli","ping"] + interval: 20s + timeout: 1s + retries: 20 \ No newline at end of file diff --git a/pkg/cache/redis/trace.go b/pkg/cache/redis/trace.go index 3804f937d..b782d309d 100644 --- a/pkg/cache/redis/trace.go +++ b/pkg/cache/redis/trace.go @@ -10,10 +10,9 @@ import ( ) const ( - _traceComponentName = "pkg/cache/redis" + _traceComponentName = "library/cache/redis" _tracePeerService = "redis" _traceSpanKind = "client" - _slowLogDuration = time.Millisecond * 250 ) var _internalTags = []trace.Tag{ @@ -24,26 +23,28 @@ var _internalTags = []trace.Tag{ type traceConn struct { // tr for pipeline, if tr != nil meaning on pipeline - tr trace.Trace - ctx context.Context + tr trace.Trace // connTag include e.g. ip,port connTags []trace.Tag + ctx context.Context + // origin redis conn Conn pending int + // TODO: split slow log from trace. + slowLogThreshold time.Duration } func (t *traceConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) { statement := getStatement(commandName, args...) - defer slowLog(statement, time.Now()) - root, ok := trace.FromContext(t.ctx) + defer t.slowLog(statement, time.Now()) // NOTE: ignored empty commandName // current sdk will Do empty command after pipeline finished - if !ok || commandName == "" { + if t.tr == nil || commandName == "" { return t.Conn.Do(commandName, args...) } - tr := root.Fork("", "Redis:"+commandName) + tr := t.tr.Fork("", "Redis:"+commandName) tr.SetTag(_internalTags...) tr.SetTag(t.connTags...) tr.SetTag(trace.TagString(trace.TagDBStatement, statement)) @@ -52,16 +53,15 @@ func (t *traceConn) Do(commandName string, args ...interface{}) (reply interface return } -func (t *traceConn) Send(commandName string, args ...interface{}) error { +func (t *traceConn) Send(commandName string, args ...interface{}) (err error) { statement := getStatement(commandName, args...) - defer slowLog(statement, time.Now()) + defer t.slowLog(statement, time.Now()) t.pending++ - root, ok := trace.FromContext(t.ctx) - if !ok { + if t.tr == nil { return t.Conn.Send(commandName, args...) } - if t.tr == nil { - t.tr = root.Fork("", "Redis:Pipeline") + if t.pending == 1 { + t.tr = t.tr.Fork("", "Redis:Pipeline") t.tr.SetTag(_internalTags...) t.tr.SetTag(t.connTags...) } @@ -69,8 +69,7 @@ func (t *traceConn) Send(commandName string, args ...interface{}) error { trace.Log(trace.LogEvent, "Send"), trace.Log("db.statement", statement), ) - err := t.Conn.Send(commandName, args...) - if err != nil { + if err = t.Conn.Send(commandName, args...); err != nil { t.tr.SetTag(trace.TagBool(trace.TagError, true)) t.tr.SetLog( trace.Log(trace.LogEvent, "Send Fail"), @@ -81,7 +80,7 @@ func (t *traceConn) Send(commandName string, args ...interface{}) error { } func (t *traceConn) Flush() error { - defer slowLog("Flush", time.Now()) + defer t.slowLog("Flush", time.Now()) if t.tr == nil { return t.Conn.Flush() } @@ -98,7 +97,7 @@ func (t *traceConn) Flush() error { } func (t *traceConn) Receive() (reply interface{}, err error) { - defer slowLog("Receive", time.Now()) + defer t.slowLog("Receive", time.Now()) if t.tr == nil { return t.Conn.Receive() } @@ -122,13 +121,16 @@ func (t *traceConn) Receive() (reply interface{}, err error) { } func (t *traceConn) WithContext(ctx context.Context) Conn { - t.ctx = ctx + t.Conn = t.Conn.WithContext(ctx) + if root, ok := trace.FromContext(ctx); ok { + t.tr = root + } return t } -func slowLog(statement string, now time.Time) { +func (t *traceConn) slowLog(statement string, now time.Time) { du := time.Since(now) - if du > _slowLogDuration { + if du > t.slowLogThreshold { log.Warn("%s slow log statement: %s time: %v", _tracePeerService, statement, du) } } diff --git a/pkg/cache/redis/trace_test.go b/pkg/cache/redis/trace_test.go new file mode 100644 index 000000000..181910342 --- /dev/null +++ b/pkg/cache/redis/trace_test.go @@ -0,0 +1,192 @@ +package redis + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/bilibili/kratos/pkg/net/trace" + "github.com/stretchr/testify/assert" +) + +const testTraceSlowLogThreshold = time.Duration(250 * time.Millisecond) + +type mockTrace struct { + tags []trace.Tag + logs []trace.LogField + perr *error + operationName string + finished bool +} + +func (m *mockTrace) Fork(serviceName string, operationName string) trace.Trace { + m.operationName = operationName + return m +} +func (m *mockTrace) Follow(serviceName string, operationName string) trace.Trace { + panic("not implemented") +} +func (m *mockTrace) Finish(err *error) { + m.perr = err + m.finished = true +} +func (m *mockTrace) SetTag(tags ...trace.Tag) trace.Trace { + m.tags = append(m.tags, tags...) + return m +} +func (m *mockTrace) SetLog(logs ...trace.LogField) trace.Trace { + m.logs = append(m.logs, logs...) + return m +} +func (m *mockTrace) Visit(fn func(k, v string)) {} +func (m *mockTrace) SetTitle(title string) {} +func (m *mockTrace) TraceID() string { return "" } + +type mockConn struct{} + +func (c *mockConn) Close() error { return nil } +func (c *mockConn) Err() error { return nil } +func (c *mockConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) { + return nil, nil +} +func (c *mockConn) Send(commandName string, args ...interface{}) error { return nil } +func (c *mockConn) Flush() error { return nil } +func (c *mockConn) Receive() (reply interface{}, err error) { return nil, nil } +func (c *mockConn) WithContext(context.Context) Conn { return c } + +func TestTraceDo(t *testing.T) { + tr := &mockTrace{} + ctx := trace.NewContext(context.Background(), tr) + tc := &traceConn{Conn: &mockConn{}, slowLogThreshold: testTraceSlowLogThreshold} + conn := tc.WithContext(ctx) + + conn.Do("GET", "test") + + assert.Equal(t, "Redis:GET", tr.operationName) + assert.NotEmpty(t, tr.tags) + assert.True(t, tr.finished) +} + +func TestTraceDoErr(t *testing.T) { + tr := &mockTrace{} + ctx := trace.NewContext(context.Background(), tr) + tc := &traceConn{Conn: MockErr{Error: fmt.Errorf("hhhhhhh")}, + slowLogThreshold: testTraceSlowLogThreshold} + conn := tc.WithContext(ctx) + + conn.Do("GET", "test") + + assert.Equal(t, "Redis:GET", tr.operationName) + assert.True(t, tr.finished) + assert.NotNil(t, *tr.perr) +} + +func TestTracePipeline(t *testing.T) { + tr := &mockTrace{} + ctx := trace.NewContext(context.Background(), tr) + tc := &traceConn{Conn: &mockConn{}, slowLogThreshold: testTraceSlowLogThreshold} + conn := tc.WithContext(ctx) + + N := 2 + for i := 0; i < N; i++ { + conn.Send("GET", "hello, world") + } + conn.Flush() + for i := 0; i < N; i++ { + conn.Receive() + } + + assert.Equal(t, "Redis:Pipeline", tr.operationName) + assert.NotEmpty(t, tr.tags) + assert.NotEmpty(t, tr.logs) + assert.True(t, tr.finished) +} + +func TestTracePipelineErr(t *testing.T) { + tr := &mockTrace{} + ctx := trace.NewContext(context.Background(), tr) + tc := &traceConn{Conn: MockErr{Error: fmt.Errorf("hahah")}, + slowLogThreshold: testTraceSlowLogThreshold} + conn := tc.WithContext(ctx) + + N := 2 + for i := 0; i < N; i++ { + conn.Send("GET", "hello, world") + } + conn.Flush() + for i := 0; i < N; i++ { + conn.Receive() + } + + assert.Equal(t, "Redis:Pipeline", tr.operationName) + assert.NotEmpty(t, tr.tags) + assert.NotEmpty(t, tr.logs) + assert.True(t, tr.finished) + var isError bool + for _, tag := range tr.tags { + if tag.Key == "error" { + isError = true + } + } + assert.True(t, isError) +} + +func TestSendStatement(t *testing.T) { + tr := &mockTrace{} + ctx := trace.NewContext(context.Background(), tr) + tc := &traceConn{Conn: MockErr{Error: fmt.Errorf("hahah")}, + slowLogThreshold: testTraceSlowLogThreshold} + conn := tc.WithContext(ctx) + conn.Send("SET", "hello", "test") + conn.Flush() + conn.Receive() + + assert.Equal(t, "Redis:Pipeline", tr.operationName) + assert.NotEmpty(t, tr.tags) + assert.NotEmpty(t, tr.logs) + assert.Equal(t, "event", tr.logs[0].Key) + assert.Equal(t, "Send", tr.logs[0].Value) + assert.Equal(t, "db.statement", tr.logs[1].Key) + assert.Equal(t, "SET hello", tr.logs[1].Value) + assert.True(t, tr.finished) + var isError bool + for _, tag := range tr.tags { + if tag.Key == "error" { + isError = true + } + } + assert.True(t, isError) +} + +func TestDoStatement(t *testing.T) { + tr := &mockTrace{} + ctx := trace.NewContext(context.Background(), tr) + tc := &traceConn{Conn: MockErr{Error: fmt.Errorf("hahah")}, + slowLogThreshold: testTraceSlowLogThreshold} + conn := tc.WithContext(ctx) + conn.Do("SET", "hello", "test") + + assert.Equal(t, "Redis:SET", tr.operationName) + assert.Equal(t, "SET hello", tr.tags[len(tr.tags)-1].Value) + assert.True(t, tr.finished) +} + +func BenchmarkTraceConn(b *testing.B) { + for i := 0; i < b.N; i++ { + c, err := DialDefaultServer() + if err != nil { + b.Fatal(err) + } + t := &traceConn{ + Conn: c, + connTags: []trace.Tag{trace.TagString(trace.TagPeerAddress, "abc")}, + slowLogThreshold: time.Duration(1 * time.Second), + } + c2 := t.WithContext(context.TODO()) + if _, err := c2.Do("PING"); err != nil { + b.Fatal(err) + } + c2.Close() + } +} diff --git a/pkg/cache/redis/util.go b/pkg/cache/redis/util.go new file mode 100644 index 000000000..aa52597bb --- /dev/null +++ b/pkg/cache/redis/util.go @@ -0,0 +1,17 @@ +package redis + +import ( + "context" + "time" +) + +func shrinkDeadline(ctx context.Context, timeout time.Duration) time.Time { + var timeoutTime = time.Now().Add(timeout) + if ctx == nil { + return timeoutTime + } + if deadline, ok := ctx.Deadline(); ok && timeoutTime.After(deadline) { + return deadline + } + return timeoutTime +} diff --git a/pkg/cache/redis/util_test.go b/pkg/cache/redis/util_test.go new file mode 100644 index 000000000..748b8423e --- /dev/null +++ b/pkg/cache/redis/util_test.go @@ -0,0 +1,37 @@ +package redis + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestShrinkDeadline(t *testing.T) { + t.Run("test not deadline", func(t *testing.T) { + timeout := time.Second + timeoutTime := time.Now().Add(timeout) + tm := shrinkDeadline(context.Background(), timeout) + assert.True(t, tm.After(timeoutTime)) + }) + t.Run("test big deadline", func(t *testing.T) { + timeout := time.Second + timeoutTime := time.Now().Add(timeout) + deadlineTime := time.Now().Add(2 * time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + tm := shrinkDeadline(ctx, timeout) + assert.True(t, tm.After(timeoutTime) && tm.Before(deadlineTime)) + }) + t.Run("test small deadline", func(t *testing.T) { + timeout := time.Second + deadlineTime := time.Now().Add(500 * time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + tm := shrinkDeadline(ctx, timeout) + assert.True(t, tm.After(deadlineTime) && tm.Before(time.Now().Add(timeout))) + }) +}