Merge pull request #325 from Windfarer/redis_pipeline

Redis API优化
pull/378/head
Terry.Mao 5 years ago committed by GitHub
commit 5949c6debf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 27
      pkg/cache/redis/commandinfo_test.go
  2. 62
      pkg/cache/redis/conn.go
  3. 670
      pkg/cache/redis/conn_test.go
  4. 18
      pkg/cache/redis/log.go
  5. 67
      pkg/cache/redis/main_test.go
  6. 4
      pkg/cache/redis/metrics.go
  7. 4
      pkg/cache/redis/mock.go
  8. 85
      pkg/cache/redis/pipeline.go
  9. 96
      pkg/cache/redis/pipeline_test.go
  10. 31
      pkg/cache/redis/pool.go
  11. 540
      pkg/cache/redis/pool_test.go
  12. 146
      pkg/cache/redis/pubsub_test.go
  13. 65
      pkg/cache/redis/redis.go
  14. 324
      pkg/cache/redis/redis_test.go
  15. 179
      pkg/cache/redis/reply_test.go
  16. 438
      pkg/cache/redis/scan_test.go
  17. 103
      pkg/cache/redis/script_test.go
  18. 12
      pkg/cache/redis/test/docker-compose.yaml
  19. 42
      pkg/cache/redis/trace.go
  20. 192
      pkg/cache/redis/trace_test.go
  21. 17
      pkg/cache/redis/util.go
  22. 37
      pkg/cache/redis/util_test.go

@ -0,0 +1,27 @@
package redis
import "testing"
func TestLookupCommandInfo(t *testing.T) {
for _, n := range []string{"watch", "WATCH", "wAtch"} {
if LookupCommandInfo(n) == (CommandInfo{}) {
t.Errorf("LookupCommandInfo(%q) = CommandInfo{}, expected non-zero value", n)
}
}
}
func benchmarkLookupCommandInfo(b *testing.B, names ...string) {
for i := 0; i < b.N; i++ {
for _, c := range names {
LookupCommandInfo(c)
}
}
}
func BenchmarkLookupCommandInfoCorrectCase(b *testing.B) {
benchmarkLookupCommandInfo(b, "watch", "WATCH", "monitor", "MONITOR")
}
func BenchmarkLookupCommandInfoMixedCase(b *testing.B) {
benchmarkLookupCommandInfo(b, "wAtch", "WeTCH", "monItor", "MONiTOR")
}

@ -30,6 +30,33 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// Conn represents a connection to a Redis server.
type Conn interface {
// Close closes the connection.
Close() error
// Err returns a non-nil value if the connection is broken. The returned
// value is either the first non-nil value returned from the underlying
// network connection or a protocol parsing error. Applications should
// close broken connections.
Err() error
// Do sends a command to the server and returns the received reply.
Do(commandName string, args ...interface{}) (reply interface{}, err error)
// Send writes the command to the client's output buffer.
Send(commandName string, args ...interface{}) error
// Flush flushes the output buffer to the Redis server.
Flush() error
// Receive receives a single reply from the Redis server
Receive() (reply interface{}, err error)
// WithContext returns Conn with the input ctx.
WithContext(ctx context.Context) Conn
}
// conn is the low-level implementation of Conn // conn is the low-level implementation of Conn
type conn struct { type conn struct {
// Shared // Shared
@ -38,6 +65,8 @@ type conn struct {
err error err error
conn net.Conn conn net.Conn
ctx context.Context
// Read // Read
readTimeout time.Duration readTimeout time.Duration
br *bufio.Reader br *bufio.Reader
@ -226,6 +255,7 @@ func NewConn(c *Config) (cn Conn, err error) {
func (c *conn) Close() error { func (c *conn) Close() error {
c.mu.Lock() c.mu.Lock()
c.ctx = nil
err := c.err err := c.err
if c.err == nil { if c.err == nil {
c.err = errors.New("redigo: closed") c.err = errors.New("redigo: closed")
@ -295,7 +325,7 @@ func (c *conn) writeFloat64(n float64) error {
func (c *conn) writeCommand(cmd string, args []interface{}) (err error) { func (c *conn) writeCommand(cmd string, args []interface{}) (err error) {
if c.writeTimeout != 0 { if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) c.conn.SetWriteDeadline(shrinkDeadline(c.ctx, c.writeTimeout))
} }
c.writeLen('*', 1+len(args)) c.writeLen('*', 1+len(args))
err = c.writeString(cmd) err = c.writeString(cmd)
@ -478,7 +508,7 @@ func (c *conn) Send(cmd string, args ...interface{}) (err error) {
func (c *conn) Flush() (err error) { func (c *conn) Flush() (err error) {
if c.writeTimeout != 0 { if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) c.conn.SetWriteDeadline(shrinkDeadline(c.ctx, c.writeTimeout))
} }
if err = c.bw.Flush(); err != nil { if err = c.bw.Flush(); err != nil {
c.fatal(err) c.fatal(err)
@ -488,7 +518,7 @@ func (c *conn) Flush() (err error) {
func (c *conn) Receive() (reply interface{}, err error) { func (c *conn) Receive() (reply interface{}, err error) {
if c.readTimeout != 0 { if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) c.conn.SetReadDeadline(shrinkDeadline(c.ctx, c.readTimeout))
} }
if reply, err = c.readReply(); err != nil { if reply, err = c.readReply(); err != nil {
return nil, c.fatal(err) return nil, c.fatal(err)
@ -511,7 +541,7 @@ func (c *conn) Receive() (reply interface{}, err error) {
return return
} }
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { func (c *conn) Do(cmd string, args ...interface{}) (reply interface{}, err error) {
c.mu.Lock() c.mu.Lock()
pending := c.pending pending := c.pending
c.pending = 0 c.pending = 0
@ -519,7 +549,7 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
if cmd == "" && pending == 0 { if cmd == "" && pending == 0 {
return nil, nil return nil, nil
} }
var err error
if cmd != "" { if cmd != "" {
err = c.writeCommand(cmd, args) err = c.writeCommand(cmd, args)
} }
@ -530,7 +560,7 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
return nil, c.fatal(err) return nil, c.fatal(err)
} }
if c.readTimeout != 0 { if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) c.conn.SetReadDeadline(shrinkDeadline(c.ctx, c.readTimeout))
} }
if cmd == "" { if cmd == "" {
reply := make([]interface{}, pending) reply := make([]interface{}, pending)
@ -548,7 +578,6 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
return reply, nil return reply, nil
} }
var reply interface{}
for i := 0; i <= pending; i++ { for i := 0; i <= pending; i++ {
var e error var e error
if reply, e = c.readReply(); e != nil { if reply, e = c.readReply(); e != nil {
@ -561,5 +590,20 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
return reply, err return reply, err
} }
// WithContext FIXME: implement WithContext func (c *conn) copy() *conn {
func (c *conn) WithContext(ctx context.Context) Conn { return c } return &conn{
pending: c.pending,
err: c.err,
conn: c.conn,
bw: c.bw,
br: c.br,
readTimeout: c.readTimeout,
writeTimeout: c.writeTimeout,
}
}
func (c *conn) WithContext(ctx context.Context) Conn {
c2 := c.copy()
c2.ctx = ctx
return c2
}

@ -0,0 +1,670 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bytes"
"context"
"io"
"math"
"net"
"os"
"reflect"
"strings"
"testing"
"time"
)
type tConn struct {
io.Reader
io.Writer
}
func (*tConn) Close() error { return nil }
func (*tConn) LocalAddr() net.Addr { return nil }
func (*tConn) RemoteAddr() net.Addr { return nil }
func (*tConn) SetDeadline(t time.Time) error { return nil }
func (*tConn) SetReadDeadline(t time.Time) error { return nil }
func (*tConn) SetWriteDeadline(t time.Time) error { return nil }
func dialTestConn(r io.Reader, w io.Writer) DialOption {
return DialNetDial(func(net, addr string) (net.Conn, error) {
return &tConn{Reader: r, Writer: w}, nil
})
}
var writeTests = []struct {
args []interface{}
expected string
}{
{
[]interface{}{"SET", "key", "value"},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
},
{
[]interface{}{"SET", "key", "value"},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
},
{
[]interface{}{"SET", "key", byte(100)},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$3\r\n100\r\n",
},
{
[]interface{}{"SET", "key", 100},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$3\r\n100\r\n",
},
{
[]interface{}{"SET", "key", int64(math.MinInt64)},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$20\r\n-9223372036854775808\r\n",
},
{
[]interface{}{"SET", "key", float64(1349673917.939762)},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$21\r\n1.349673917939762e+09\r\n",
},
{
[]interface{}{"SET", "key", ""},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
},
{
[]interface{}{"SET", "key", nil},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
},
{
[]interface{}{"ECHO", true, false},
"*3\r\n$4\r\nECHO\r\n$1\r\n1\r\n$1\r\n0\r\n",
},
}
func TestWrite(t *testing.T) {
for _, tt := range writeTests {
var buf bytes.Buffer
c, _ := Dial("", "", dialTestConn(nil, &buf))
err := c.Send(tt.args[0].(string), tt.args[1:]...)
if err != nil {
t.Errorf("Send(%v) returned error %v", tt.args, err)
continue
}
c.Flush()
actual := buf.String()
if actual != tt.expected {
t.Errorf("Send(%v) = %q, want %q", tt.args, actual, tt.expected)
}
}
}
var errorSentinel = &struct{}{}
var readTests = []struct {
reply string
expected interface{}
}{
{
"+OK\r\n",
"OK",
},
{
"+PONG\r\n",
"PONG",
},
{
"@OK\r\n",
errorSentinel,
},
{
"$6\r\nfoobar\r\n",
[]byte("foobar"),
},
{
"$-1\r\n",
nil,
},
{
":1\r\n",
int64(1),
},
{
":-2\r\n",
int64(-2),
},
{
"*0\r\n",
[]interface{}{},
},
{
"*-1\r\n",
nil,
},
{
"*4\r\n$3\r\nfoo\r\n$3\r\nbar\r\n$5\r\nHello\r\n$5\r\nWorld\r\n",
[]interface{}{[]byte("foo"), []byte("bar"), []byte("Hello"), []byte("World")},
},
{
"*3\r\n$3\r\nfoo\r\n$-1\r\n$3\r\nbar\r\n",
[]interface{}{[]byte("foo"), nil, []byte("bar")},
},
{
// "x" is not a valid length
"$x\r\nfoobar\r\n",
errorSentinel,
},
{
// -2 is not a valid length
"$-2\r\n",
errorSentinel,
},
{
// "x" is not a valid integer
":x\r\n",
errorSentinel,
},
{
// missing \r\n following value
"$6\r\nfoobar",
errorSentinel,
},
{
// short value
"$6\r\nxx",
errorSentinel,
},
{
// long value
"$6\r\nfoobarx\r\n",
errorSentinel,
},
}
func TestRead(t *testing.T) {
for _, tt := range readTests {
c, _ := Dial("", "", dialTestConn(strings.NewReader(tt.reply), nil))
actual, err := c.Receive()
if tt.expected == errorSentinel {
if err == nil {
t.Errorf("Receive(%q) did not return expected error", tt.reply)
}
} else {
if err != nil {
t.Errorf("Receive(%q) returned error %v", tt.reply, err)
continue
}
if !reflect.DeepEqual(actual, tt.expected) {
t.Errorf("Receive(%q) = %v, want %v", tt.reply, actual, tt.expected)
}
}
}
}
var testCommands = []struct {
args []interface{}
expected interface{}
}{
{
[]interface{}{"PING"},
"PONG",
},
{
[]interface{}{"SET", "foo", "bar"},
"OK",
},
{
[]interface{}{"GET", "foo"},
[]byte("bar"),
},
{
[]interface{}{"GET", "nokey"},
nil,
},
{
[]interface{}{"MGET", "nokey", "foo"},
[]interface{}{nil, []byte("bar")},
},
{
[]interface{}{"INCR", "mycounter"},
int64(1),
},
{
[]interface{}{"LPUSH", "mylist", "foo"},
int64(1),
},
{
[]interface{}{"LPUSH", "mylist", "bar"},
int64(2),
},
{
[]interface{}{"LRANGE", "mylist", 0, -1},
[]interface{}{[]byte("bar"), []byte("foo")},
},
{
[]interface{}{"MULTI"},
"OK",
},
{
[]interface{}{"LRANGE", "mylist", 0, -1},
"QUEUED",
},
{
[]interface{}{"PING"},
"QUEUED",
},
{
[]interface{}{"EXEC"},
[]interface{}{
[]interface{}{[]byte("bar"), []byte("foo")},
"PONG",
},
},
}
func TestDoCommands(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
for _, cmd := range testCommands {
actual, err := c.Do(cmd.args[0].(string), cmd.args[1:]...)
if err != nil {
t.Errorf("Do(%v) returned error %v", cmd.args, err)
continue
}
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Do(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
}
func TestPipelineCommands(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
for _, cmd := range testCommands {
if err := c.Send(cmd.args[0].(string), cmd.args[1:]...); err != nil {
t.Fatalf("Send(%v) returned error %v", cmd.args, err)
}
}
if err := c.Flush(); err != nil {
t.Errorf("Flush() returned error %v", err)
}
for _, cmd := range testCommands {
actual, err := c.Receive()
if err != nil {
t.Fatalf("Receive(%v) returned error %v", cmd.args, err)
}
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
}
func TestBlankCommmand(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
for _, cmd := range testCommands {
if err = c.Send(cmd.args[0].(string), cmd.args[1:]...); err != nil {
t.Fatalf("Send(%v) returned error %v", cmd.args, err)
}
}
reply, err := Values(c.Do(""))
if err != nil {
t.Fatalf("Do() returned error %v", err)
}
if len(reply) != len(testCommands) {
t.Fatalf("len(reply)=%d, want %d", len(reply), len(testCommands))
}
for i, cmd := range testCommands {
actual := reply[i]
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
}
func TestRecvBeforeSend(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
done := make(chan struct{})
go func() {
c.Receive()
close(done)
}()
time.Sleep(time.Millisecond)
c.Send("PING")
c.Flush()
<-done
_, err = c.Do("")
if err != nil {
t.Fatalf("error=%v", err)
}
}
func TestError(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
c.Do("SET", "key", "val")
_, err = c.Do("HSET", "key", "fld", "val")
if err == nil {
t.Errorf("Expected err for HSET on string key.")
}
if c.Err() != nil {
t.Errorf("Conn has Err()=%v, expect nil", c.Err())
}
_, err = c.Do("SET", "key", "val")
if err != nil {
t.Errorf("Do(SET, key, val) returned error %v, expected nil.", err)
}
}
func TestReadTimeout(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen returned %v", err)
}
defer l.Close()
go func() {
for {
c, err1 := l.Accept()
if err1 != nil {
return
}
go func() {
time.Sleep(time.Second)
c.Write([]byte("+OK\r\n"))
c.Close()
}()
}
}()
// Do
c1, err := Dial(l.Addr().Network(), l.Addr().String(), DialReadTimeout(time.Millisecond))
if err != nil {
t.Fatalf("Dial returned %v", err)
}
defer c1.Close()
_, err = c1.Do("PING")
if err == nil {
t.Fatalf("c1.Do() returned nil, expect error")
}
if c1.Err() == nil {
t.Fatalf("c1.Err() = nil, expect error")
}
// Send/Flush/Receive
c2, err := Dial(l.Addr().Network(), l.Addr().String(), DialReadTimeout(time.Millisecond))
if err != nil {
t.Fatalf("Dial returned %v", err)
}
defer c2.Close()
c2.Send("PING")
c2.Flush()
_, err = c2.Receive()
if err == nil {
t.Fatalf("c2.Receive() returned nil, expect error")
}
if c2.Err() == nil {
t.Fatalf("c2.Err() = nil, expect error")
}
}
var dialErrors = []struct {
rawurl string
expectedError string
}{
{
"localhost",
"invalid redis URL scheme",
},
// The error message for invalid hosts is diffferent in different
// versions of Go, so just check that there is an error message.
{
"redis://weird url",
"",
},
{
"redis://foo:bar:baz",
"",
},
{
"http://www.google.com",
"invalid redis URL scheme: http",
},
{
"redis://localhost:6379/abc123",
"invalid database: abc123",
},
}
func TestDialURLErrors(t *testing.T) {
for _, d := range dialErrors {
_, err := DialURL(d.rawurl)
if err == nil || !strings.Contains(err.Error(), d.expectedError) {
t.Errorf("DialURL did not return expected error (expected %v to contain %s)", err, d.expectedError)
}
}
}
func TestDialURLPort(t *testing.T) {
checkPort := func(network, address string) (net.Conn, error) {
if address != "localhost:6379" {
t.Errorf("DialURL did not set port to 6379 by default (got %v)", address)
}
return nil, nil
}
_, err := DialURL("redis://localhost", DialNetDial(checkPort))
if err != nil {
t.Error("dial error:", err)
}
}
func TestDialURLHost(t *testing.T) {
checkHost := func(network, address string) (net.Conn, error) {
if address != "localhost:6379" {
t.Errorf("DialURL did not set host to localhost by default (got %v)", address)
}
return nil, nil
}
_, err := DialURL("redis://:6379", DialNetDial(checkHost))
if err != nil {
t.Error("dial error:", err)
}
}
func TestDialURLPassword(t *testing.T) {
var buf bytes.Buffer
_, err := DialURL("redis://x:abc123@localhost", dialTestConn(strings.NewReader("+OK\r\n"), &buf))
if err != nil {
t.Error("dial error:", err)
}
expected := "*2\r\n$4\r\nAUTH\r\n$6\r\nabc123\r\n"
actual := buf.String()
if actual != expected {
t.Errorf("commands = %q, want %q", actual, expected)
}
}
func TestDialURLDatabase(t *testing.T) {
var buf bytes.Buffer
_, err := DialURL("redis://localhost/3", dialTestConn(strings.NewReader("+OK\r\n"), &buf))
if err != nil {
t.Error("dial error:", err)
}
expected := "*2\r\n$6\r\nSELECT\r\n$1\r\n3\r\n"
actual := buf.String()
if actual != expected {
t.Errorf("commands = %q, want %q", actual, expected)
}
}
// Connect to local instance of Redis running on the default port.
func ExampleDial() {
c, err := Dial("tcp", ":6379")
if err != nil {
// handle error
}
defer c.Close()
}
// Connect to remote instance of Redis using a URL.
func ExampleDialURL() {
c, err := DialURL(os.Getenv("REDIS_URL"))
if err != nil {
// handle connection error
}
defer c.Close()
}
// TextExecError tests handling of errors in a transaction. See
// http://io/topics/transactions for information on how Redis handles
// errors in a transaction.
func TestExecError(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
// Execute commands that fail before EXEC is called.
c.Do("DEL", "k0")
c.Do("ZADD", "k0", 0, 0)
c.Send("MULTI")
c.Send("NOTACOMMAND", "k0", 0, 0)
c.Send("ZINCRBY", "k0", 0, 0)
v, err := c.Do("EXEC")
if err == nil {
t.Fatalf("EXEC returned values %v, expected error", v)
}
// Execute commands that fail after EXEC is called. The first command
// returns an error.
c.Do("DEL", "k1")
c.Do("ZADD", "k1", 0, 0)
c.Send("MULTI")
c.Send("HSET", "k1", 0, 0)
c.Send("ZINCRBY", "k1", 0, 0)
v, err = c.Do("EXEC")
if err != nil {
t.Fatalf("EXEC returned error %v", err)
}
vs, err := Values(v, nil)
if err != nil {
t.Fatalf("Values(v) returned error %v", err)
}
if len(vs) != 2 {
t.Fatalf("len(vs) == %d, want 2", len(vs))
}
if _, ok := vs[0].(error); !ok {
t.Fatalf("first result is type %T, expected error", vs[0])
}
if _, ok := vs[1].([]byte); !ok {
t.Fatalf("second result is type %T, expected []byte", vs[1])
}
// Execute commands that fail after EXEC is called. The second command
// returns an error.
c.Do("ZADD", "k2", 0, 0)
c.Send("MULTI")
c.Send("ZINCRBY", "k2", 0, 0)
c.Send("HSET", "k2", 0, 0)
v, err = c.Do("EXEC")
if err != nil {
t.Fatalf("EXEC returned error %v", err)
}
vs, err = Values(v, nil)
if err != nil {
t.Fatalf("Values(v) returned error %v", err)
}
if len(vs) != 2 {
t.Fatalf("len(vs) == %d, want 2", len(vs))
}
if _, ok := vs[0].([]byte); !ok {
t.Fatalf("first result is type %T, expected []byte", vs[0])
}
if _, ok := vs[1].(error); !ok {
t.Fatalf("second result is type %T, expected error", vs[2])
}
}
func BenchmarkDoEmpty(b *testing.B) {
c, err := DialDefaultServer()
if err != nil {
b.Fatal(err)
}
defer c.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := c.Do(""); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkDoPing(b *testing.B) {
c, err := DialDefaultServer()
if err != nil {
b.Fatal(err)
}
defer c.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := c.Do("PING"); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkConn(b *testing.B) {
for i := 0; i < b.N; i++ {
c, err := DialDefaultServer()
if err != nil {
b.Fatal(err)
}
c2 := c.WithContext(context.TODO())
if _, err := c2.Do("PING"); err != nil {
b.Fatal(err)
}
c2.Close()
}
}

@ -16,16 +16,18 @@ package redis
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"log" "log"
) )
// NewLoggingConn returns a logging wrapper around a connection. // NewLoggingConn returns a logging wrapper around a connection.
// ATTENTION: ONLY use loggingConn in developing, DO NOT use this in production.
func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn { func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn {
if prefix != "" { if prefix != "" {
prefix = prefix + "." prefix = prefix + "."
} }
return &loggingConn{conn, logger, prefix} return &loggingConn{Conn: conn, logger: logger, prefix: prefix}
} }
type loggingConn struct { type loggingConn struct {
@ -98,16 +100,16 @@ func (c *loggingConn) print(method, commandName string, args []interface{}, repl
c.logger.Output(3, buf.String()) c.logger.Output(3, buf.String())
} }
func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, error) { func (c *loggingConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
reply, err := c.Conn.Do(commandName, args...) reply, err = c.Conn.Do(commandName, args...)
c.print("Do", commandName, args, reply, err) c.print("Do", commandName, args, reply, err)
return reply, err return reply, err
} }
func (c *loggingConn) Send(commandName string, args ...interface{}) error { func (c *loggingConn) Send(commandName string, args ...interface{}) (err error) {
err := c.Conn.Send(commandName, args...) err = c.Conn.Send(commandName, args...)
c.print("Send", commandName, args, nil, err) c.print("Send", commandName, args, nil, err)
return err return
} }
func (c *loggingConn) Receive() (interface{}, error) { func (c *loggingConn) Receive() (interface{}, error) {
@ -115,3 +117,7 @@ func (c *loggingConn) Receive() (interface{}, error) {
c.print("Receive", "", nil, reply, err) c.print("Receive", "", nil, reply, err)
return reply, err return reply, err
} }
func (c *loggingConn) WithContext(ctx context.Context) Conn {
return c
}

@ -0,0 +1,67 @@
package redis
import (
"flag"
"os"
"testing"
"time"
"github.com/bilibili/kratos/pkg/container/pool"
"github.com/bilibili/kratos/pkg/testing/lich"
xtime "github.com/bilibili/kratos/pkg/time"
)
var (
testRedisAddr string
testPool *Pool
testConfig *Config
)
func setupTestConfig(addr string) {
c := getTestConfig(addr)
c.Config = &pool.Config{
Active: 20,
Idle: 2,
IdleTimeout: xtime.Duration(90 * time.Second),
}
testConfig = c
}
func getTestConfig(addr string) *Config {
return &Config{
Name: "test",
Proto: "tcp",
Addr: addr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}
}
func setupTestPool() {
testPool = NewPool(testConfig)
}
// DialDefaultServer starts the test server if not already started and dials a
// connection to the server.
func DialDefaultServer() (Conn, error) {
c, err := Dial("tcp", testRedisAddr, DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second))
if err != nil {
return nil, err
}
c.Do("FLUSHDB")
return c, nil
}
func TestMain(m *testing.M) {
flag.Set("f", "./test/docker-compose.yaml")
if err := lich.Setup(); err != nil {
panic(err)
}
defer lich.Teardown()
testRedisAddr = "localhost:6379"
setupTestConfig(testRedisAddr)
setupTestPool()
ret := m.Run()
os.Exit(ret)
}

@ -1,6 +1,8 @@
package redis package redis
import "github.com/bilibili/kratos/pkg/stat/metric" import (
"github.com/bilibili/kratos/pkg/stat/metric"
)
const namespace = "redis_client" const namespace = "redis_client"

@ -1,8 +1,6 @@
package redis package redis
import ( import "context"
"context"
)
// MockErr for unit test. // MockErr for unit test.
type MockErr struct { type MockErr struct {

@ -0,0 +1,85 @@
package redis
import (
"context"
"errors"
)
type Pipeliner interface {
// Send writes the command to the client's output buffer.
Send(commandName string, args ...interface{})
// Exec executes all commands and get replies.
Exec(ctx context.Context) (rs *Replies, err error)
}
var (
ErrNoReply = errors.New("redis: no reply in result set")
)
type pipeliner struct {
pool *Pool
cmds []*cmd
}
type Replies struct {
replies []*reply
}
type reply struct {
reply interface{}
err error
}
func (rs *Replies) Next() bool {
return len(rs.replies) > 0
}
func (rs *Replies) Scan() (reply interface{}, err error) {
if !rs.Next() {
return nil, ErrNoReply
}
reply, err = rs.replies[0].reply, rs.replies[0].err
rs.replies = rs.replies[1:]
return
}
type cmd struct {
commandName string
args []interface{}
}
func (p *pipeliner) Send(commandName string, args ...interface{}) {
p.cmds = append(p.cmds, &cmd{commandName: commandName, args: args})
return
}
func (p *pipeliner) Exec(ctx context.Context) (rs *Replies, err error) {
n := len(p.cmds)
if n == 0 {
return &Replies{}, nil
}
c := p.pool.Get(ctx)
defer c.Close()
for len(p.cmds) > 0 {
cmd := p.cmds[0]
p.cmds = p.cmds[1:]
if err := c.Send(cmd.commandName, cmd.args...); err != nil {
p.cmds = p.cmds[:0]
return nil, err
}
}
if err = c.Flush(); err != nil {
p.cmds = p.cmds[:0]
return nil, err
}
rps := make([]*reply, 0, n)
for i := 0; i < n; i++ {
rp, err := c.Receive()
rps = append(rps, &reply{reply: rp, err: err})
}
rs = &Replies{
replies: rps,
}
return
}

@ -0,0 +1,96 @@
package redis
import (
"context"
"fmt"
"reflect"
"testing"
"time"
"github.com/bilibili/kratos/pkg/container/pool"
xtime "github.com/bilibili/kratos/pkg/time"
)
func TestRedis_Pipeline(t *testing.T) {
conf := &Config{
Name: "test",
Proto: "tcp",
Addr: testRedisAddr,
DialTimeout: xtime.Duration(1 * time.Second),
ReadTimeout: xtime.Duration(1 * time.Second),
WriteTimeout: xtime.Duration(1 * time.Second),
}
conf.Config = &pool.Config{
Active: 10,
Idle: 2,
IdleTimeout: xtime.Duration(90 * time.Second),
}
r := NewRedis(conf)
r.Do(context.TODO(), "FLUSHDB")
p := r.Pipeline()
for _, cmd := range testCommands {
p.Send(cmd.args[0].(string), cmd.args[1:]...)
}
replies, err := p.Exec(context.TODO())
i := 0
for replies.Next() {
cmd := testCommands[i]
actual, err := replies.Scan()
if err != nil {
t.Fatalf("Receive(%v) returned error %v", cmd.args, err)
}
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
i++
}
err = r.Close()
if err != nil {
t.Errorf("Close() error %v", err)
}
}
func ExamplePipeliner() {
r := NewRedis(testConfig)
defer r.Close()
pip := r.Pipeline()
pip.Send("SET", "hello", "world")
pip.Send("GET", "hello")
replies, err := pip.Exec(context.TODO())
if err != nil {
fmt.Printf("%#v\n", err)
}
for replies.Next() {
s, err := String(replies.Scan())
if err != nil {
fmt.Printf("err %#v\n", err)
}
fmt.Printf("%#v\n", s)
}
// Output:
// "OK"
// "world"
}
func BenchmarkRedisPipelineExec(b *testing.B) {
r := NewRedis(testConfig)
defer r.Close()
r.Do(context.TODO(), "SET", "abcde", "fghiasdfasdf")
b.ResetTimer()
for i := 0; i < b.N; i++ {
p := r.Pipeline()
p.Send("GET", "abcde")
_, err := p.Exec(context.TODO())
if err != nil {
b.Fatal(err)
}
}
}

@ -45,24 +45,14 @@ type Pool struct {
statfunc func(name, addr, cmd string, t time.Time, err error) func() statfunc func(name, addr, cmd string, t time.Time, err error) func()
} }
// Config client settings.
type Config struct {
*pool.Config
Name string // redis name, for trace
Proto string
Addr string
Auth string
DialTimeout xtime.Duration
ReadTimeout xtime.Duration
WriteTimeout xtime.Duration
}
// NewPool creates a new pool. // NewPool creates a new pool.
func NewPool(c *Config, options ...DialOption) (p *Pool) { func NewPool(c *Config, options ...DialOption) (p *Pool) {
if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 { if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 {
panic("must config redis timeout") panic("must config redis timeout")
} }
if c.SlowLog <= 0 {
c.SlowLog = xtime.Duration(250 * time.Millisecond)
}
ops := []DialOption{ ops := []DialOption{
DialConnectTimeout(time.Duration(c.DialTimeout)), DialConnectTimeout(time.Duration(c.DialTimeout)),
DialReadTimeout(time.Duration(c.ReadTimeout)), DialReadTimeout(time.Duration(c.ReadTimeout)),
@ -71,12 +61,18 @@ func NewPool(c *Config, options ...DialOption) (p *Pool) {
} }
ops = append(ops, options...) ops = append(ops, options...)
p1 := pool.NewSlice(c.Config) p1 := pool.NewSlice(c.Config)
// new pool
p1.New = func(ctx context.Context) (io.Closer, error) { p1.New = func(ctx context.Context) (io.Closer, error) {
conn, err := Dial(c.Proto, c.Addr, ops...) conn, err := Dial(c.Proto, c.Addr, ops...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &traceConn{Conn: conn, connTags: []trace.Tag{trace.TagString(trace.TagPeerAddress, c.Addr)}}, nil return &traceConn{
Conn: conn,
connTags: []trace.Tag{trace.TagString(trace.TagPeerAddress, c.Addr)},
slowLogThreshold: time.Duration(c.SlowLog),
}, nil
} }
p = &Pool{Slice: p1, c: c, statfunc: pstat} p = &Pool{Slice: p1, c: c, statfunc: pstat}
return return
@ -93,7 +89,7 @@ func (p *Pool) Get(ctx context.Context) Conn {
return errorConnection{err} return errorConnection{err}
} }
c1, _ := c.(Conn) c1, _ := c.(Conn)
return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx, now: beginTime} return &pooledConnection{p: p, c: c1.WithContext(ctx), rc: c1, now: beginTime}
} }
// Close releases the resources used by the pool. // Close releases the resources used by the pool.
@ -103,12 +99,12 @@ func (p *Pool) Close() error {
type pooledConnection struct { type pooledConnection struct {
p *Pool p *Pool
rc Conn
c Conn c Conn
state int state int
now time.Time now time.Time
cmds []string cmds []string
ctx context.Context
} }
var ( var (
@ -180,7 +176,7 @@ func (pc *pooledConnection) Close() error {
} }
} }
_, err := c.Do("") _, err := c.Do("")
pc.p.Slice.Put(context.Background(), c, pc.state != 0 || c.Err() != nil) pc.p.Slice.Put(context.Background(), pc.rc, pc.state != 0 || c.Err() != nil)
return err return err
} }
@ -227,7 +223,6 @@ func (pc *pooledConnection) Receive() (reply interface{}, err error) {
} }
func (pc *pooledConnection) WithContext(ctx context.Context) Conn { func (pc *pooledConnection) WithContext(ctx context.Context) Conn {
pc.ctx = ctx
return pc return pc
} }

@ -0,0 +1,540 @@
// Copyright 2011 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"context"
"errors"
"io"
"reflect"
"sync"
"testing"
"time"
"github.com/bilibili/kratos/pkg/container/pool"
)
type poolTestConn struct {
d *poolDialer
err error
c Conn
ctx context.Context
}
func (c *poolTestConn) Flush() error {
return c.c.Flush()
}
func (c *poolTestConn) Receive() (reply interface{}, err error) {
return c.c.Receive()
}
func (c *poolTestConn) WithContext(ctx context.Context) Conn {
c.c.WithContext(ctx)
c.ctx = ctx
return c
}
func (c *poolTestConn) Close() error {
c.d.mu.Lock()
c.d.open--
c.d.mu.Unlock()
return c.c.Close()
}
func (c *poolTestConn) Err() error { return c.err }
func (c *poolTestConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
if commandName == "ERR" {
c.err = args[0].(error)
commandName = "PING"
}
if commandName != "" {
c.d.commands = append(c.d.commands, commandName)
}
return c.c.Do(commandName, args...)
}
func (c *poolTestConn) Send(commandName string, args ...interface{}) error {
c.d.commands = append(c.d.commands, commandName)
return c.c.Send(commandName, args...)
}
type poolDialer struct {
mu sync.Mutex
t *testing.T
dialed int
open int
commands []string
dialErr error
}
func (d *poolDialer) dial() (Conn, error) {
d.mu.Lock()
d.dialed += 1
dialErr := d.dialErr
d.mu.Unlock()
if dialErr != nil {
return nil, d.dialErr
}
c, err := DialDefaultServer()
if err != nil {
return nil, err
}
d.mu.Lock()
d.open += 1
d.mu.Unlock()
return &poolTestConn{d: d, c: c}, nil
}
func (d *poolDialer) check(message string, p *Pool, dialed, open int) {
d.mu.Lock()
if d.dialed != dialed {
d.t.Errorf("%s: dialed=%d, want %d", message, d.dialed, dialed)
}
if d.open != open {
d.t.Errorf("%s: open=%d, want %d", message, d.open, open)
}
// if active := p.ActiveCount(); active != open {
// d.t.Errorf("%s: active=%d, want %d", message, active, open)
// }
d.mu.Unlock()
}
func TestPoolReuse(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(testConfig)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
var err error
for i := 0; i < 10; i++ {
c1 := p.Get(context.TODO())
c1.Do("PING")
c2 := p.Get(context.TODO())
c2.Do("PING")
c1.Close()
c2.Close()
}
d.check("before close", p, 2, 2)
err = p.Close()
if err != nil {
t.Fatal(err)
}
d.check("after close", p, 2, 0)
}
func TestPoolMaxIdle(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(testConfig)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
for i := 0; i < 10; i++ {
c1 := p.Get(context.TODO())
c1.Do("PING")
c2 := p.Get(context.TODO())
c2.Do("PING")
c3 := p.Get(context.TODO())
c3.Do("PING")
c1.Close()
c2.Close()
c3.Close()
}
d.check("before close", p, 12, 2)
p.Close()
d.check("after close", p, 12, 0)
}
func TestPoolError(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(testConfig)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c := p.Get(context.TODO())
c.Do("ERR", io.EOF)
if c.Err() == nil {
t.Errorf("expected c.Err() != nil")
}
c.Close()
c = p.Get(context.TODO())
c.Do("ERR", io.EOF)
c.Close()
d.check(".", p, 2, 0)
}
func TestPoolClose(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(testConfig)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c1 := p.Get(context.TODO())
c1.Do("PING")
c2 := p.Get(context.TODO())
c2.Do("PING")
c3 := p.Get(context.TODO())
c3.Do("PING")
c1.Close()
if _, err := c1.Do("PING"); err == nil {
t.Errorf("expected error after connection closed")
}
c2.Close()
c2.Close()
p.Close()
d.check("after pool close", p, 3, 1)
if _, err := c1.Do("PING"); err == nil {
t.Errorf("expected error after connection and pool closed")
}
c3.Close()
d.check("after conn close", p, 3, 0)
c1 = p.Get(context.TODO())
if _, err := c1.Do("PING"); err == nil {
t.Errorf("expected error after pool closed")
}
}
func TestPoolConcurrenSendReceive(t *testing.T) {
p := NewPool(testConfig)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return DialDefaultServer()
}
defer p.Close()
c := p.Get(context.TODO())
done := make(chan error, 1)
go func() {
_, err := c.Receive()
done <- err
}()
c.Send("PING")
c.Flush()
err := <-done
if err != nil {
t.Fatalf("Receive() returned error %v", err)
}
_, err = c.Do("")
if err != nil {
t.Fatalf("Do() returned error %v", err)
}
c.Close()
}
func TestPoolMaxActive(t *testing.T) {
d := poolDialer{t: t}
conf := getTestConfig(testRedisAddr)
conf.Config = &pool.Config{
Active: 2,
Idle: 2,
}
p := NewPool(conf)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c1 := p.Get(context.TODO())
c1.Do("PING")
c2 := p.Get(context.TODO())
c2.Do("PING")
d.check("1", p, 2, 2)
c3 := p.Get(context.TODO())
if _, err := c3.Do("PING"); err != pool.ErrPoolExhausted {
t.Errorf("expected pool exhausted")
}
c3.Close()
d.check("2", p, 2, 2)
c2.Close()
d.check("3", p, 2, 2)
c3 = p.Get(context.TODO())
if _, err := c3.Do("PING"); err != nil {
t.Errorf("expected good channel, err=%v", err)
}
c3.Close()
d.check("4", p, 2, 2)
}
func TestPoolMonitorCleanup(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(testConfig)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c := p.Get(context.TODO())
c.Send("MONITOR")
c.Close()
d.check("", p, 1, 0)
}
func TestPoolPubSubCleanup(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(testConfig)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c := p.Get(context.TODO())
c.Send("SUBSCRIBE", "x")
c.Close()
want := []string{"SUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get(context.TODO())
c.Send("PSUBSCRIBE", "x*")
c.Close()
want = []string{"PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
}
func TestPoolTransactionCleanup(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(testConfig)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c := p.Get(context.TODO())
c.Do("WATCH", "key")
c.Do("PING")
c.Close()
want := []string{"WATCH", "PING", "UNWATCH"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get(context.TODO())
c.Do("WATCH", "key")
c.Do("UNWATCH")
c.Do("PING")
c.Close()
want = []string{"WATCH", "UNWATCH", "PING"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get(context.TODO())
c.Do("WATCH", "key")
c.Do("MULTI")
c.Do("PING")
c.Close()
want = []string{"WATCH", "MULTI", "PING", "DISCARD"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get(context.TODO())
c.Do("WATCH", "key")
c.Do("MULTI")
c.Do("DISCARD")
c.Do("PING")
c.Close()
want = []string{"WATCH", "MULTI", "DISCARD", "PING"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get(context.TODO())
c.Do("WATCH", "key")
c.Do("MULTI")
c.Do("EXEC")
c.Do("PING")
c.Close()
want = []string{"WATCH", "MULTI", "EXEC", "PING"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
}
func startGoroutines(p *Pool, cmd string, args ...interface{}) chan error {
errs := make(chan error, 10)
for i := 0; i < cap(errs); i++ {
go func() {
c := p.Get(context.TODO())
_, err := c.Do(cmd, args...)
errs <- err
c.Close()
}()
}
// Wait for goroutines to block.
time.Sleep(time.Second / 4)
return errs
}
func TestWaitPoolDialError(t *testing.T) {
testErr := errors.New("test")
d := poolDialer{t: t}
config1 := testConfig
config1.Config = &pool.Config{
Active: 1,
Idle: 1,
Wait: true,
}
p := NewPool(config1)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c := p.Get(context.TODO())
errs := startGoroutines(p, "ERR", testErr)
d.check("before close", p, 1, 1)
d.dialErr = errors.New("dial")
c.Close()
nilCount := 0
errCount := 0
timeout := time.After(2 * time.Second)
for i := 0; i < cap(errs); i++ {
select {
case err := <-errs:
switch err {
case nil:
nilCount++
case d.dialErr:
errCount++
default:
t.Fatalf("expected dial error or nil, got %v", err)
}
case <-timeout:
t.Logf("Wait all the time and timeout %d", i)
return
}
}
if nilCount != 1 {
t.Errorf("expected one nil error, got %d", nilCount)
}
if errCount != cap(errs)-1 {
t.Errorf("expected %d dial erors, got %d", cap(errs)-1, errCount)
}
d.check("done", p, cap(errs), 0)
}
func BenchmarkPoolGet(b *testing.B) {
b.StopTimer()
p := NewPool(testConfig)
c := p.Get(context.Background())
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
defer p.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
c := p.Get(context.Background())
c.Close()
}
}
func BenchmarkPoolGetErr(b *testing.B) {
b.StopTimer()
p := NewPool(testConfig)
c := p.Get(context.Background())
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
defer p.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
c = p.Get(context.Background())
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
}
}
func BenchmarkPoolGetPing(b *testing.B) {
b.StopTimer()
p := NewPool(testConfig)
c := p.Get(context.Background())
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
defer p.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
c := p.Get(context.Background())
if _, err := c.Do("PING"); err != nil {
b.Fatal(err)
}
c.Close()
}
}
func BenchmarkPooledConn(b *testing.B) {
p := NewPool(testConfig)
defer p.Close()
for i := 0; i < b.N; i++ {
ctx := context.TODO()
c := p.Get(ctx)
c2 := c.WithContext(context.TODO())
if _, err := c2.Do("PING"); err != nil {
b.Fatal(err)
}
c2.Close()
}
}

@ -0,0 +1,146 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"fmt"
"reflect"
"sync"
"testing"
)
func publish(channel, value interface{}) {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("PUBLISH", channel, value)
}
// Applications can receive pushed messages from one goroutine and manage subscriptions from another goroutine.
func ExamplePubSubConn() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
var wg sync.WaitGroup
wg.Add(2)
psc := PubSubConn{Conn: c}
// This goroutine receives and prints pushed notifications from the server.
// The goroutine exits when the connection is unsubscribed from all
// channels or there is an error.
go func() {
defer wg.Done()
for {
switch n := psc.Receive().(type) {
case Message:
fmt.Printf("Message: %s %s\n", n.Channel, n.Data)
case PMessage:
fmt.Printf("PMessage: %s %s %s\n", n.Pattern, n.Channel, n.Data)
case Subscription:
fmt.Printf("Subscription: %s %s %d\n", n.Kind, n.Channel, n.Count)
if n.Count == 0 {
return
}
case error:
fmt.Printf("error: %v\n", n)
return
}
}
}()
// This goroutine manages subscriptions for the connection.
go func() {
defer wg.Done()
psc.Subscribe("example")
psc.PSubscribe("p*")
// The following function calls publish a message using another
// connection to the Redis server.
publish("example", "hello")
publish("example", "world")
publish("pexample", "foo")
publish("pexample", "bar")
// Unsubscribe from all connections. This will cause the receiving
// goroutine to exit.
psc.Unsubscribe()
psc.PUnsubscribe()
}()
wg.Wait()
// Output:
// Subscription: subscribe example 1
// Subscription: psubscribe p* 2
// Message: example hello
// Message: example world
// PMessage: p* pexample foo
// PMessage: p* pexample bar
// Subscription: unsubscribe example 1
// Subscription: punsubscribe p* 0
}
func expectPushed(t *testing.T, c PubSubConn, message string, expected interface{}) {
actual := c.Receive()
if !reflect.DeepEqual(actual, expected) {
t.Errorf("%s = %v, want %v", message, actual, expected)
}
}
func TestPushed(t *testing.T) {
pc, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer pc.Close()
sc, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer sc.Close()
c := PubSubConn{Conn: sc}
c.Subscribe("c1")
expectPushed(t, c, "Subscribe(c1)", Subscription{Kind: "subscribe", Channel: "c1", Count: 1})
c.Subscribe("c2")
expectPushed(t, c, "Subscribe(c2)", Subscription{Kind: "subscribe", Channel: "c2", Count: 2})
c.PSubscribe("p1")
expectPushed(t, c, "PSubscribe(p1)", Subscription{Kind: "psubscribe", Channel: "p1", Count: 3})
c.PSubscribe("p2")
expectPushed(t, c, "PSubscribe(p2)", Subscription{Kind: "psubscribe", Channel: "p2", Count: 4})
c.PUnsubscribe()
expectPushed(t, c, "Punsubscribe(p1)", Subscription{Kind: "punsubscribe", Channel: "p1", Count: 3})
expectPushed(t, c, "Punsubscribe()", Subscription{Kind: "punsubscribe", Channel: "p2", Count: 2})
pc.Do("PUBLISH", "c1", "hello")
expectPushed(t, c, "PUBLISH c1 hello", Message{Channel: "c1", Data: []byte("hello")})
c.Ping("hello")
expectPushed(t, c, `Ping("hello")`, Pong{"hello"})
c.Conn.Send("PING")
c.Conn.Flush()
expectPushed(t, c, `Send("PING")`, Pong{})
}

@ -16,6 +16,9 @@ package redis
import ( import (
"context" "context"
"github.com/bilibili/kratos/pkg/container/pool"
xtime "github.com/bilibili/kratos/pkg/time"
) )
// Error represents an error returned in a command reply. // Error represents an error returned in a command reply.
@ -23,29 +26,53 @@ type Error string
func (err Error) Error() string { return string(err) } func (err Error) Error() string { return string(err) }
// Conn represents a connection to a Redis server. // Config client settings.
type Conn interface { type Config struct {
// Close closes the connection. *pool.Config
Close() error
// Err returns a non-nil value if the connection is broken. The returned Name string // redis name, for trace
// value is either the first non-nil value returned from the underlying Proto string
// network connection or a protocol parsing error. Applications should Addr string
// close broken connections. Auth string
Err() error DialTimeout xtime.Duration
ReadTimeout xtime.Duration
WriteTimeout xtime.Duration
SlowLog xtime.Duration
}
// Do sends a command to the server and returns the received reply. type Redis struct {
Do(commandName string, args ...interface{}) (reply interface{}, err error) pool *Pool
conf *Config
}
func NewRedis(c *Config, options ...DialOption) *Redis {
return &Redis{
pool: NewPool(c, options...),
conf: c,
}
}
// Send writes the command to the client's output buffer. // Do gets a new conn from pool, then execute Do with this conn, finally close this conn.
Send(commandName string, args ...interface{}) error // ATTENTION: Don't use this method with transaction command like MULTI etc. Because every Do will close conn automatically, use r.Conn to get a raw conn for this situation.
func (r *Redis) Do(ctx context.Context, commandName string, args ...interface{}) (reply interface{}, err error) {
conn := r.pool.Get(ctx)
defer conn.Close()
reply, err = conn.Do(commandName, args...)
return
}
// Flush flushes the output buffer to the Redis server. // Close closes connection pool
Flush() error func (r *Redis) Close() error {
return r.pool.Close()
}
// Receive receives a single reply from the Redis server // Conn direct gets a connection
Receive() (reply interface{}, err error) func (r *Redis) Conn(ctx context.Context) Conn {
return r.pool.Get(ctx)
}
// WithContext func (r *Redis) Pipeline() (p Pipeliner) {
WithContext(ctx context.Context) Conn return &pipeliner{
pool: r.pool,
}
} }

@ -0,0 +1,324 @@
package redis
import (
"context"
"reflect"
"testing"
"time"
"github.com/bilibili/kratos/pkg/container/pool"
xtime "github.com/bilibili/kratos/pkg/time"
)
func TestRedis(t *testing.T) {
testSet(t, testPool)
testSend(t, testPool)
testGet(t, testPool)
testErr(t, testPool)
if err := testPool.Close(); err != nil {
t.Errorf("redis: close error(%v)", err)
}
conn, err := NewConn(testConfig)
if err != nil {
t.Errorf("redis: new conn error(%v)", err)
}
if err := conn.Close(); err != nil {
t.Errorf("redis: close error(%v)", err)
}
}
func testSet(t *testing.T, p *Pool) {
var (
key = "test"
value = "test"
conn = p.Get(context.TODO())
)
defer conn.Close()
if reply, err := conn.Do("set", key, value); err != nil {
t.Errorf("redis: conn.Do(SET, %s, %s) error(%v)", key, value, err)
} else {
t.Logf("redis: set status: %s", reply)
}
}
func testSend(t *testing.T, p *Pool) {
var (
key = "test"
value = "test"
expire = 1000
conn = p.Get(context.TODO())
)
defer conn.Close()
if err := conn.Send("SET", key, value); err != nil {
t.Errorf("redis: conn.Send(SET, %s, %s) error(%v)", key, value, err)
}
if err := conn.Send("EXPIRE", key, expire); err != nil {
t.Errorf("redis: conn.Send(EXPIRE key(%s) expire(%d)) error(%v)", key, expire, err)
}
if err := conn.Flush(); err != nil {
t.Errorf("redis: conn.Flush error(%v)", err)
}
for i := 0; i < 2; i++ {
if _, err := conn.Receive(); err != nil {
t.Errorf("redis: conn.Receive error(%v)", err)
return
}
}
t.Logf("redis: set value: %s", value)
}
func testGet(t *testing.T, p *Pool) {
var (
key = "test"
conn = p.Get(context.TODO())
)
defer conn.Close()
if reply, err := conn.Do("GET", key); err != nil {
t.Errorf("redis: conn.Do(GET, %s) error(%v)", key, err)
} else {
t.Logf("redis: get value: %s", reply)
}
}
func testErr(t *testing.T, p *Pool) {
conn := p.Get(context.TODO())
if err := conn.Close(); err != nil {
t.Errorf("redis: close error(%v)", err)
}
if err := conn.Err(); err == nil {
t.Errorf("redis: err not nil")
} else {
t.Logf("redis: err: %v", err)
}
}
func BenchmarkRedis(b *testing.B) {
conf := &Config{
Name: "test",
Proto: "tcp",
Addr: testRedisAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}
conf.Config = &pool.Config{
Active: 10,
Idle: 5,
IdleTimeout: xtime.Duration(90 * time.Second),
}
benchmarkPool := NewPool(conf)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
conn := benchmarkPool.Get(context.TODO())
if err := conn.Close(); err != nil {
b.Errorf("redis: close error(%v)", err)
}
}
})
if err := benchmarkPool.Close(); err != nil {
b.Errorf("redis: close error(%v)", err)
}
}
var testRedisCommands = []struct {
args []interface{}
expected interface{}
}{
{
[]interface{}{"PING"},
"PONG",
},
{
[]interface{}{"SET", "foo", "bar"},
"OK",
},
{
[]interface{}{"GET", "foo"},
[]byte("bar"),
},
{
[]interface{}{"GET", "nokey"},
nil,
},
{
[]interface{}{"MGET", "nokey", "foo"},
[]interface{}{nil, []byte("bar")},
},
{
[]interface{}{"INCR", "mycounter"},
int64(1),
},
{
[]interface{}{"LPUSH", "mylist", "foo"},
int64(1),
},
{
[]interface{}{"LPUSH", "mylist", "bar"},
int64(2),
},
{
[]interface{}{"LRANGE", "mylist", 0, -1},
[]interface{}{[]byte("bar"), []byte("foo")},
},
}
func TestNewRedis(t *testing.T) {
type args struct {
c *Config
options []DialOption
}
tests := []struct {
name string
args args
wantErr bool
}{
{
"new_redis",
args{
testConfig,
make([]DialOption, 0),
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := NewRedis(tt.args.c, tt.args.options...)
if r == nil {
t.Errorf("NewRedis() error, got nil")
return
}
err := r.Close()
if err != nil {
t.Errorf("Close() error %v", err)
}
})
}
}
func TestRedis_Do(t *testing.T) {
r := NewRedis(testConfig)
r.Do(context.TODO(), "FLUSHDB")
for _, cmd := range testRedisCommands {
actual, err := r.Do(context.TODO(), cmd.args[0].(string), cmd.args[1:]...)
if err != nil {
t.Errorf("Do(%v) returned error %v", cmd.args, err)
continue
}
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Do(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
err := r.Close()
if err != nil {
t.Errorf("Close() error %v", err)
}
}
func TestRedis_Conn(t *testing.T) {
type args struct {
ctx context.Context
}
tests := []struct {
name string
p *Redis
args args
wantErr bool
g int
c int
}{
{
"Close",
NewRedis(&Config{
Config: &pool.Config{
Active: 1,
Idle: 1,
},
Name: "test_get",
Proto: "tcp",
Addr: testRedisAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}),
args{context.TODO()},
false,
3,
3,
},
{
"CloseExceededPoolSize",
NewRedis(&Config{
Config: &pool.Config{
Active: 1,
Idle: 1,
},
Name: "test_get_out",
Proto: "tcp",
Addr: testRedisAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}),
args{context.TODO()},
true,
5,
3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for i := 1; i <= tt.g; i++ {
got := tt.p.Conn(tt.args.ctx)
if err := got.Close(); err != nil {
if !tt.wantErr {
t.Error(err)
}
}
if i <= tt.c {
if err := got.Close(); err != nil {
t.Error(err)
}
}
}
})
}
}
func BenchmarkRedisDoPing(b *testing.B) {
r := NewRedis(testConfig)
defer r.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := r.Do(context.Background(), "PING"); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkRedisDoSET(b *testing.B) {
r := NewRedis(testConfig)
defer r.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := r.Do(context.Background(), "SET", "a", "b"); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkRedisDoGET(b *testing.B) {
r := NewRedis(testConfig)
defer r.Close()
r.Do(context.Background(), "SET", "a", "b")
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := r.Do(context.Background(), "GET", "b"); err != nil {
b.Fatal(err)
}
}
}

@ -0,0 +1,179 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"fmt"
"reflect"
"testing"
"github.com/pkg/errors"
)
type valueError struct {
v interface{}
err error
}
func ve(v interface{}, err error) valueError {
return valueError{v, err}
}
var replyTests = []struct {
name interface{}
actual valueError
expected valueError
}{
{
"ints([v1, v2])",
ve(Ints([]interface{}{[]byte("4"), []byte("5")}, nil)),
ve([]int{4, 5}, nil),
},
{
"ints(nil)",
ve(Ints(nil, nil)),
ve([]int(nil), ErrNil),
},
{
"strings([v1, v2])",
ve(Strings([]interface{}{[]byte("v1"), []byte("v2")}, nil)),
ve([]string{"v1", "v2"}, nil),
},
{
"strings(nil)",
ve(Strings(nil, nil)),
ve([]string(nil), ErrNil),
},
{
"byteslices([v1, v2])",
ve(ByteSlices([]interface{}{[]byte("v1"), []byte("v2")}, nil)),
ve([][]byte{[]byte("v1"), []byte("v2")}, nil),
},
{
"byteslices(nil)",
ve(ByteSlices(nil, nil)),
ve([][]byte(nil), ErrNil),
},
{
"values([v1, v2])",
ve(Values([]interface{}{[]byte("v1"), []byte("v2")}, nil)),
ve([]interface{}{[]byte("v1"), []byte("v2")}, nil),
},
{
"values(nil)",
ve(Values(nil, nil)),
ve([]interface{}(nil), ErrNil),
},
{
"float64(1.0)",
ve(Float64([]byte("1.0"), nil)),
ve(float64(1.0), nil),
},
{
"float64(nil)",
ve(Float64(nil, nil)),
ve(float64(0.0), ErrNil),
},
{
"uint64(1)",
ve(Uint64(int64(1), nil)),
ve(uint64(1), nil),
},
{
"uint64(-1)",
ve(Uint64(int64(-1), nil)),
ve(uint64(0), errNegativeInt),
},
}
func TestReply(t *testing.T) {
for _, rt := range replyTests {
if errors.Cause(rt.actual.err) != rt.expected.err {
t.Errorf("%s returned err %v, want %v", rt.name, rt.actual.err, rt.expected.err)
continue
}
if !reflect.DeepEqual(rt.actual.v, rt.expected.v) {
t.Errorf("%s=%+v, want %+v", rt.name, rt.actual.v, rt.expected.v)
}
}
}
// dial wraps DialDefaultServer() with a more suitable function name for examples.
func dial() (Conn, error) {
return DialDefaultServer()
}
func ExampleBool() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("SET", "foo", 1)
exists, _ := Bool(c.Do("EXISTS", "foo"))
fmt.Printf("%#v\n", exists)
// Output:
// true
}
func ExampleInt() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("SET", "k1", 1)
n, _ := Int(c.Do("GET", "k1"))
fmt.Printf("%#v\n", n)
n, _ = Int(c.Do("INCR", "k1"))
fmt.Printf("%#v\n", n)
// Output:
// 1
// 2
}
func ExampleInts() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("SADD", "set_with_integers", 4, 5, 6)
ints, _ := Ints(c.Do("SMEMBERS", "set_with_integers"))
fmt.Printf("%#v\n", ints)
// Output:
// []int{4, 5, 6}
}
func ExampleString() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("SET", "hello", "world")
s, _ := String(c.Do("GET", "hello"))
fmt.Printf("%#v", s)
// Output:
// "world"
}

@ -0,0 +1,438 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"fmt"
"math"
"reflect"
"testing"
)
var scanConversionTests = []struct {
src interface{}
dest interface{}
}{
{[]byte("-inf"), math.Inf(-1)},
{[]byte("+inf"), math.Inf(1)},
{[]byte("0"), float64(0)},
{[]byte("3.14159"), float64(3.14159)},
{[]byte("3.14"), float32(3.14)},
{[]byte("-100"), int(-100)},
{[]byte("101"), int(101)},
{int64(102), int(102)},
{[]byte("103"), uint(103)},
{int64(104), uint(104)},
{[]byte("105"), int8(105)},
{int64(106), int8(106)},
{[]byte("107"), uint8(107)},
{int64(108), uint8(108)},
{[]byte("0"), false},
{int64(0), false},
{[]byte("f"), false},
{[]byte("1"), true},
{int64(1), true},
{[]byte("t"), true},
{"hello", "hello"},
{[]byte("hello"), "hello"},
{[]byte("world"), []byte("world")},
{[]interface{}{[]byte("foo")}, []interface{}{[]byte("foo")}},
{[]interface{}{[]byte("foo")}, []string{"foo"}},
{[]interface{}{[]byte("hello"), []byte("world")}, []string{"hello", "world"}},
{[]interface{}{[]byte("bar")}, [][]byte{[]byte("bar")}},
{[]interface{}{[]byte("1")}, []int{1}},
{[]interface{}{[]byte("1"), []byte("2")}, []int{1, 2}},
{[]interface{}{[]byte("1"), []byte("2")}, []float64{1, 2}},
{[]interface{}{[]byte("1")}, []byte{1}},
{[]interface{}{[]byte("1")}, []bool{true}},
}
func TestScanConversion(t *testing.T) {
for _, tt := range scanConversionTests {
values := []interface{}{tt.src}
dest := reflect.New(reflect.TypeOf(tt.dest))
values, err := Scan(values, dest.Interface())
if err != nil {
t.Errorf("Scan(%v) returned error %v", tt, err)
continue
}
if !reflect.DeepEqual(tt.dest, dest.Elem().Interface()) {
t.Errorf("Scan(%v) returned %v values: %v, want %v", tt, dest.Elem().Interface(), values, tt.dest)
}
}
}
var scanConversionErrorTests = []struct {
src interface{}
dest interface{}
}{
{[]byte("1234"), byte(0)},
{int64(1234), byte(0)},
{[]byte("-1"), byte(0)},
{int64(-1), byte(0)},
{[]byte("junk"), false},
{Error("blah"), false},
}
func TestScanConversionError(t *testing.T) {
for _, tt := range scanConversionErrorTests {
values := []interface{}{tt.src}
dest := reflect.New(reflect.TypeOf(tt.dest))
values, err := Scan(values, dest.Interface())
if err == nil {
t.Errorf("Scan(%v) did not return error values: %v", tt, values)
}
}
}
func ExampleScan() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Send("HMSET", "album:1", "title", "Red", "rating", 5)
c.Send("HMSET", "album:2", "title", "Earthbound", "rating", 1)
c.Send("HMSET", "album:3", "title", "Beat")
c.Send("LPUSH", "albums", "1")
c.Send("LPUSH", "albums", "2")
c.Send("LPUSH", "albums", "3")
values, err := Values(c.Do("SORT", "albums",
"BY", "album:*->rating",
"GET", "album:*->title",
"GET", "album:*->rating"))
if err != nil {
fmt.Println(err)
return
}
for len(values) > 0 {
var title string
rating := -1 // initialize to illegal value to detect nil.
values, err = Scan(values, &title, &rating)
if err != nil {
fmt.Println(err)
return
}
if rating == -1 {
fmt.Println(title, "not-rated")
} else {
fmt.Println(title, rating)
}
}
// Output:
// Beat not-rated
// Earthbound 1
// Red 5
}
type s0 struct {
X int
Y int `redis:"y"`
Bt bool
}
type s1 struct {
X int `redis:"-"`
I int `redis:"i"`
U uint `redis:"u"`
S string `redis:"s"`
P []byte `redis:"p"`
B bool `redis:"b"`
Bt bool
Bf bool
s0
}
var scanStructTests = []struct {
title string
reply []string
value interface{}
}{
{"basic",
[]string{"i", "-1234", "u", "5678", "s", "hello", "p", "world", "b", "t", "Bt", "1", "Bf", "0", "X", "123", "y", "456"},
&s1{I: -1234, U: 5678, S: "hello", P: []byte("world"), B: true, Bt: true, Bf: false, s0: s0{X: 123, Y: 456}},
},
}
func TestScanStruct(t *testing.T) {
for _, tt := range scanStructTests {
var reply []interface{}
for _, v := range tt.reply {
reply = append(reply, []byte(v))
}
value := reflect.New(reflect.ValueOf(tt.value).Type().Elem())
if err := ScanStruct(reply, value.Interface()); err != nil {
t.Fatalf("ScanStruct(%s) returned error %v", tt.title, err)
}
if !reflect.DeepEqual(value.Interface(), tt.value) {
t.Fatalf("ScanStruct(%s) returned %v, want %v", tt.title, value.Interface(), tt.value)
}
}
}
func TestBadScanStructArgs(t *testing.T) {
x := []interface{}{"A", "b"}
test := func(v interface{}) {
if err := ScanStruct(x, v); err == nil {
t.Errorf("Expect error for ScanStruct(%T, %T)", x, v)
}
}
test(nil)
var v0 *struct{}
test(v0)
var v1 int
test(&v1)
x = x[:1]
v2 := struct{ A string }{}
test(&v2)
}
var scanSliceTests = []struct {
src []interface{}
fieldNames []string
ok bool
dest interface{}
}{
{
[]interface{}{[]byte("1"), nil, []byte("-1")},
nil,
true,
[]int{1, 0, -1},
},
{
[]interface{}{[]byte("1"), nil, []byte("2")},
nil,
true,
[]uint{1, 0, 2},
},
{
[]interface{}{[]byte("-1")},
nil,
false,
[]uint{1},
},
{
[]interface{}{[]byte("hello"), nil, []byte("world")},
nil,
true,
[][]byte{[]byte("hello"), nil, []byte("world")},
},
{
[]interface{}{[]byte("hello"), nil, []byte("world")},
nil,
true,
[]string{"hello", "", "world"},
},
{
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
nil,
true,
[]struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}},
},
{
[]interface{}{[]byte("a1"), []byte("b1")},
nil,
false,
[]struct{ A, B, C string }{{"a1", "b1", ""}},
},
{
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
nil,
true,
[]*struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}},
},
{
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
[]string{"A", "B"},
true,
[]struct{ A, C, B string }{{"a1", "", "b1"}, {"a2", "", "b2"}},
},
{
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
nil,
false,
[]struct{}{},
},
}
func TestScanSlice(t *testing.T) {
for _, tt := range scanSliceTests {
typ := reflect.ValueOf(tt.dest).Type()
dest := reflect.New(typ)
err := ScanSlice(tt.src, dest.Interface(), tt.fieldNames...)
if tt.ok != (err == nil) {
t.Errorf("ScanSlice(%v, []%s, %v) returned error %v", tt.src, typ, tt.fieldNames, err)
continue
}
if tt.ok && !reflect.DeepEqual(dest.Elem().Interface(), tt.dest) {
t.Errorf("ScanSlice(src, []%s) returned %#v, want %#v", typ, dest.Elem().Interface(), tt.dest)
}
}
}
func ExampleScanSlice() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Send("HMSET", "album:1", "title", "Red", "rating", 5)
c.Send("HMSET", "album:2", "title", "Earthbound", "rating", 1)
c.Send("HMSET", "album:3", "title", "Beat", "rating", 4)
c.Send("LPUSH", "albums", "1")
c.Send("LPUSH", "albums", "2")
c.Send("LPUSH", "albums", "3")
values, err := Values(c.Do("SORT", "albums",
"BY", "album:*->rating",
"GET", "album:*->title",
"GET", "album:*->rating"))
if err != nil {
fmt.Println(err)
return
}
var albums []struct {
Title string
Rating int
}
if err := ScanSlice(values, &albums); err != nil {
fmt.Println(err)
return
}
fmt.Printf("%v\n", albums)
// Output:
// [{Earthbound 1} {Beat 4} {Red 5}]
}
var argsTests = []struct {
title string
actual Args
expected Args
}{
{"struct ptr",
Args{}.AddFlat(&struct {
I int `redis:"i"`
U uint `redis:"u"`
S string `redis:"s"`
P []byte `redis:"p"`
M map[string]string `redis:"m"`
Bt bool
Bf bool
}{
-1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false,
}),
Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false},
},
{"struct",
Args{}.AddFlat(struct{ I int }{123}),
Args{"I", 123},
},
{"slice",
Args{}.Add(1).AddFlat([]string{"a", "b", "c"}).Add(2),
Args{1, "a", "b", "c", 2},
},
{"struct omitempty",
Args{}.AddFlat(&struct {
I int `redis:"i,omitempty"`
U uint `redis:"u,omitempty"`
S string `redis:"s,omitempty"`
P []byte `redis:"p,omitempty"`
M map[string]string `redis:"m,omitempty"`
Bt bool `redis:"Bt,omitempty"`
Bf bool `redis:"Bf,omitempty"`
}{
0, 0, "", []byte{}, map[string]string{}, true, false,
}),
Args{"Bt", true},
},
}
func TestArgs(t *testing.T) {
for _, tt := range argsTests {
if !reflect.DeepEqual(tt.actual, tt.expected) {
t.Fatalf("%s is %v, want %v", tt.title, tt.actual, tt.expected)
}
}
}
func ExampleArgs() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
var p1, p2 struct {
Title string `redis:"title"`
Author string `redis:"author"`
Body string `redis:"body"`
}
p1.Title = "Example"
p1.Author = "Gary"
p1.Body = "Hello"
if _, err := c.Do("HMSET", Args{}.Add("id1").AddFlat(&p1)...); err != nil {
fmt.Println(err)
return
}
m := map[string]string{
"title": "Example2",
"author": "Steve",
"body": "Map",
}
if _, err := c.Do("HMSET", Args{}.Add("id2").AddFlat(m)...); err != nil {
fmt.Println(err)
return
}
for _, id := range []string{"id1", "id2"} {
v, err := Values(c.Do("HGETALL", id))
if err != nil {
fmt.Println(err)
return
}
if err := ScanStruct(v, &p2); err != nil {
fmt.Println(err)
return
}
fmt.Printf("%+v\n", p2)
}
// Output:
// {Title:Example Author:Gary Body:Hello}
// {Title:Example2 Author:Steve Body:Map}
}

@ -0,0 +1,103 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"fmt"
"reflect"
"testing"
"time"
)
func ExampleScript() {
c, err := Dial("tcp", ":6379")
if err != nil {
// handle error
}
defer c.Close()
// Initialize a package-level variable with a script.
var getScript = NewScript(1, `return call('get', KEYS[1])`)
// In a function, use the script Do method to evaluate the script. The Do
// method optimistically uses the EVALSHA command. If the script is not
// loaded, then the Do method falls back to the EVAL command.
if _, err = getScript.Do(c, "foo"); err != nil {
// handle error
}
}
func TestScript(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
// To test fall back in Do, we make script unique by adding comment with current time.
script := fmt.Sprintf("--%d\nreturn {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", time.Now().UnixNano())
s := NewScript(2, script)
reply := []interface{}{[]byte("key1"), []byte("key2"), []byte("arg1"), []byte("arg2")}
v, err := s.Do(c, "key1", "key2", "arg1", "arg2")
if err != nil {
t.Errorf("s.Do(c, ...) returned %v", err)
}
if !reflect.DeepEqual(v, reply) {
t.Errorf("s.Do(c, ..); = %v, want %v", v, reply)
}
err = s.Load(c)
if err != nil {
t.Errorf("s.Load(c) returned %v", err)
}
err = s.SendHash(c, "key1", "key2", "arg1", "arg2")
if err != nil {
t.Errorf("s.SendHash(c, ...) returned %v", err)
}
err = c.Flush()
if err != nil {
t.Errorf("c.Flush() returned %v", err)
}
v, err = c.Receive()
if err != nil {
t.Errorf("c.Receive() returned %v", err)
}
if !reflect.DeepEqual(v, reply) {
t.Errorf("s.SendHash(c, ..); c.Receive() = %v, want %v", v, reply)
}
err = s.Send(c, "key1", "key2", "arg1", "arg2")
if err != nil {
t.Errorf("s.Send(c, ...) returned %v", err)
}
err = c.Flush()
if err != nil {
t.Errorf("c.Flush() returned %v", err)
}
v, err = c.Receive()
if err != nil {
t.Errorf("c.Receive() returned %v", err)
}
if !reflect.DeepEqual(v, reply) {
t.Errorf("s.Send(c, ..); c.Receive() = %v, want %v", v, reply)
}
}

@ -0,0 +1,12 @@
version: "3.7"
services:
redis:
image: redis
ports:
- 6379:6379
healthcheck:
test: ["CMD", "redis-cli","ping"]
interval: 20s
timeout: 1s
retries: 20

@ -10,10 +10,9 @@ import (
) )
const ( const (
_traceComponentName = "pkg/cache/redis" _traceComponentName = "library/cache/redis"
_tracePeerService = "redis" _tracePeerService = "redis"
_traceSpanKind = "client" _traceSpanKind = "client"
_slowLogDuration = time.Millisecond * 250
) )
var _internalTags = []trace.Tag{ var _internalTags = []trace.Tag{
@ -25,25 +24,27 @@ var _internalTags = []trace.Tag{
type traceConn struct { type traceConn struct {
// tr for pipeline, if tr != nil meaning on pipeline // tr for pipeline, if tr != nil meaning on pipeline
tr trace.Trace tr trace.Trace
ctx context.Context
// connTag include e.g. ip,port // connTag include e.g. ip,port
connTags []trace.Tag connTags []trace.Tag
ctx context.Context
// origin redis conn // origin redis conn
Conn Conn
pending int pending int
// TODO: split slow log from trace.
slowLogThreshold time.Duration
} }
func (t *traceConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) { func (t *traceConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
statement := getStatement(commandName, args...) statement := getStatement(commandName, args...)
defer slowLog(statement, time.Now()) defer t.slowLog(statement, time.Now())
root, ok := trace.FromContext(t.ctx)
// NOTE: ignored empty commandName // NOTE: ignored empty commandName
// current sdk will Do empty command after pipeline finished // current sdk will Do empty command after pipeline finished
if !ok || commandName == "" { if t.tr == nil || commandName == "" {
return t.Conn.Do(commandName, args...) return t.Conn.Do(commandName, args...)
} }
tr := root.Fork("", "Redis:"+commandName) tr := t.tr.Fork("", "Redis:"+commandName)
tr.SetTag(_internalTags...) tr.SetTag(_internalTags...)
tr.SetTag(t.connTags...) tr.SetTag(t.connTags...)
tr.SetTag(trace.TagString(trace.TagDBStatement, statement)) tr.SetTag(trace.TagString(trace.TagDBStatement, statement))
@ -52,16 +53,15 @@ func (t *traceConn) Do(commandName string, args ...interface{}) (reply interface
return return
} }
func (t *traceConn) Send(commandName string, args ...interface{}) error { func (t *traceConn) Send(commandName string, args ...interface{}) (err error) {
statement := getStatement(commandName, args...) statement := getStatement(commandName, args...)
defer slowLog(statement, time.Now()) defer t.slowLog(statement, time.Now())
t.pending++ t.pending++
root, ok := trace.FromContext(t.ctx) if t.tr == nil {
if !ok {
return t.Conn.Send(commandName, args...) return t.Conn.Send(commandName, args...)
} }
if t.tr == nil { if t.pending == 1 {
t.tr = root.Fork("", "Redis:Pipeline") t.tr = t.tr.Fork("", "Redis:Pipeline")
t.tr.SetTag(_internalTags...) t.tr.SetTag(_internalTags...)
t.tr.SetTag(t.connTags...) t.tr.SetTag(t.connTags...)
} }
@ -69,8 +69,7 @@ func (t *traceConn) Send(commandName string, args ...interface{}) error {
trace.Log(trace.LogEvent, "Send"), trace.Log(trace.LogEvent, "Send"),
trace.Log("db.statement", statement), trace.Log("db.statement", statement),
) )
err := t.Conn.Send(commandName, args...) if err = t.Conn.Send(commandName, args...); err != nil {
if err != nil {
t.tr.SetTag(trace.TagBool(trace.TagError, true)) t.tr.SetTag(trace.TagBool(trace.TagError, true))
t.tr.SetLog( t.tr.SetLog(
trace.Log(trace.LogEvent, "Send Fail"), trace.Log(trace.LogEvent, "Send Fail"),
@ -81,7 +80,7 @@ func (t *traceConn) Send(commandName string, args ...interface{}) error {
} }
func (t *traceConn) Flush() error { func (t *traceConn) Flush() error {
defer slowLog("Flush", time.Now()) defer t.slowLog("Flush", time.Now())
if t.tr == nil { if t.tr == nil {
return t.Conn.Flush() return t.Conn.Flush()
} }
@ -98,7 +97,7 @@ func (t *traceConn) Flush() error {
} }
func (t *traceConn) Receive() (reply interface{}, err error) { func (t *traceConn) Receive() (reply interface{}, err error) {
defer slowLog("Receive", time.Now()) defer t.slowLog("Receive", time.Now())
if t.tr == nil { if t.tr == nil {
return t.Conn.Receive() return t.Conn.Receive()
} }
@ -122,13 +121,16 @@ func (t *traceConn) Receive() (reply interface{}, err error) {
} }
func (t *traceConn) WithContext(ctx context.Context) Conn { func (t *traceConn) WithContext(ctx context.Context) Conn {
t.ctx = ctx t.Conn = t.Conn.WithContext(ctx)
if root, ok := trace.FromContext(ctx); ok {
t.tr = root
}
return t return t
} }
func slowLog(statement string, now time.Time) { func (t *traceConn) slowLog(statement string, now time.Time) {
du := time.Since(now) du := time.Since(now)
if du > _slowLogDuration { if du > t.slowLogThreshold {
log.Warn("%s slow log statement: %s time: %v", _tracePeerService, statement, du) log.Warn("%s slow log statement: %s time: %v", _tracePeerService, statement, du)
} }
} }

@ -0,0 +1,192 @@
package redis
import (
"context"
"fmt"
"testing"
"time"
"github.com/bilibili/kratos/pkg/net/trace"
"github.com/stretchr/testify/assert"
)
const testTraceSlowLogThreshold = time.Duration(250 * time.Millisecond)
type mockTrace struct {
tags []trace.Tag
logs []trace.LogField
perr *error
operationName string
finished bool
}
func (m *mockTrace) Fork(serviceName string, operationName string) trace.Trace {
m.operationName = operationName
return m
}
func (m *mockTrace) Follow(serviceName string, operationName string) trace.Trace {
panic("not implemented")
}
func (m *mockTrace) Finish(err *error) {
m.perr = err
m.finished = true
}
func (m *mockTrace) SetTag(tags ...trace.Tag) trace.Trace {
m.tags = append(m.tags, tags...)
return m
}
func (m *mockTrace) SetLog(logs ...trace.LogField) trace.Trace {
m.logs = append(m.logs, logs...)
return m
}
func (m *mockTrace) Visit(fn func(k, v string)) {}
func (m *mockTrace) SetTitle(title string) {}
func (m *mockTrace) TraceID() string { return "" }
type mockConn struct{}
func (c *mockConn) Close() error { return nil }
func (c *mockConn) Err() error { return nil }
func (c *mockConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
return nil, nil
}
func (c *mockConn) Send(commandName string, args ...interface{}) error { return nil }
func (c *mockConn) Flush() error { return nil }
func (c *mockConn) Receive() (reply interface{}, err error) { return nil, nil }
func (c *mockConn) WithContext(context.Context) Conn { return c }
func TestTraceDo(t *testing.T) {
tr := &mockTrace{}
ctx := trace.NewContext(context.Background(), tr)
tc := &traceConn{Conn: &mockConn{}, slowLogThreshold: testTraceSlowLogThreshold}
conn := tc.WithContext(ctx)
conn.Do("GET", "test")
assert.Equal(t, "Redis:GET", tr.operationName)
assert.NotEmpty(t, tr.tags)
assert.True(t, tr.finished)
}
func TestTraceDoErr(t *testing.T) {
tr := &mockTrace{}
ctx := trace.NewContext(context.Background(), tr)
tc := &traceConn{Conn: MockErr{Error: fmt.Errorf("hhhhhhh")},
slowLogThreshold: testTraceSlowLogThreshold}
conn := tc.WithContext(ctx)
conn.Do("GET", "test")
assert.Equal(t, "Redis:GET", tr.operationName)
assert.True(t, tr.finished)
assert.NotNil(t, *tr.perr)
}
func TestTracePipeline(t *testing.T) {
tr := &mockTrace{}
ctx := trace.NewContext(context.Background(), tr)
tc := &traceConn{Conn: &mockConn{}, slowLogThreshold: testTraceSlowLogThreshold}
conn := tc.WithContext(ctx)
N := 2
for i := 0; i < N; i++ {
conn.Send("GET", "hello, world")
}
conn.Flush()
for i := 0; i < N; i++ {
conn.Receive()
}
assert.Equal(t, "Redis:Pipeline", tr.operationName)
assert.NotEmpty(t, tr.tags)
assert.NotEmpty(t, tr.logs)
assert.True(t, tr.finished)
}
func TestTracePipelineErr(t *testing.T) {
tr := &mockTrace{}
ctx := trace.NewContext(context.Background(), tr)
tc := &traceConn{Conn: MockErr{Error: fmt.Errorf("hahah")},
slowLogThreshold: testTraceSlowLogThreshold}
conn := tc.WithContext(ctx)
N := 2
for i := 0; i < N; i++ {
conn.Send("GET", "hello, world")
}
conn.Flush()
for i := 0; i < N; i++ {
conn.Receive()
}
assert.Equal(t, "Redis:Pipeline", tr.operationName)
assert.NotEmpty(t, tr.tags)
assert.NotEmpty(t, tr.logs)
assert.True(t, tr.finished)
var isError bool
for _, tag := range tr.tags {
if tag.Key == "error" {
isError = true
}
}
assert.True(t, isError)
}
func TestSendStatement(t *testing.T) {
tr := &mockTrace{}
ctx := trace.NewContext(context.Background(), tr)
tc := &traceConn{Conn: MockErr{Error: fmt.Errorf("hahah")},
slowLogThreshold: testTraceSlowLogThreshold}
conn := tc.WithContext(ctx)
conn.Send("SET", "hello", "test")
conn.Flush()
conn.Receive()
assert.Equal(t, "Redis:Pipeline", tr.operationName)
assert.NotEmpty(t, tr.tags)
assert.NotEmpty(t, tr.logs)
assert.Equal(t, "event", tr.logs[0].Key)
assert.Equal(t, "Send", tr.logs[0].Value)
assert.Equal(t, "db.statement", tr.logs[1].Key)
assert.Equal(t, "SET hello", tr.logs[1].Value)
assert.True(t, tr.finished)
var isError bool
for _, tag := range tr.tags {
if tag.Key == "error" {
isError = true
}
}
assert.True(t, isError)
}
func TestDoStatement(t *testing.T) {
tr := &mockTrace{}
ctx := trace.NewContext(context.Background(), tr)
tc := &traceConn{Conn: MockErr{Error: fmt.Errorf("hahah")},
slowLogThreshold: testTraceSlowLogThreshold}
conn := tc.WithContext(ctx)
conn.Do("SET", "hello", "test")
assert.Equal(t, "Redis:SET", tr.operationName)
assert.Equal(t, "SET hello", tr.tags[len(tr.tags)-1].Value)
assert.True(t, tr.finished)
}
func BenchmarkTraceConn(b *testing.B) {
for i := 0; i < b.N; i++ {
c, err := DialDefaultServer()
if err != nil {
b.Fatal(err)
}
t := &traceConn{
Conn: c,
connTags: []trace.Tag{trace.TagString(trace.TagPeerAddress, "abc")},
slowLogThreshold: time.Duration(1 * time.Second),
}
c2 := t.WithContext(context.TODO())
if _, err := c2.Do("PING"); err != nil {
b.Fatal(err)
}
c2.Close()
}
}

@ -0,0 +1,17 @@
package redis
import (
"context"
"time"
)
func shrinkDeadline(ctx context.Context, timeout time.Duration) time.Time {
var timeoutTime = time.Now().Add(timeout)
if ctx == nil {
return timeoutTime
}
if deadline, ok := ctx.Deadline(); ok && timeoutTime.After(deadline) {
return deadline
}
return timeoutTime
}

@ -0,0 +1,37 @@
package redis
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestShrinkDeadline(t *testing.T) {
t.Run("test not deadline", func(t *testing.T) {
timeout := time.Second
timeoutTime := time.Now().Add(timeout)
tm := shrinkDeadline(context.Background(), timeout)
assert.True(t, tm.After(timeoutTime))
})
t.Run("test big deadline", func(t *testing.T) {
timeout := time.Second
timeoutTime := time.Now().Add(timeout)
deadlineTime := time.Now().Add(2 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
tm := shrinkDeadline(ctx, timeout)
assert.True(t, tm.After(timeoutTime) && tm.Before(deadlineTime))
})
t.Run("test small deadline", func(t *testing.T) {
timeout := time.Second
deadlineTime := time.Now().Add(500 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
tm := shrinkDeadline(ctx, timeout)
assert.True(t, tm.After(deadlineTime) && tm.Before(time.Now().Add(timeout)))
})
}
Loading…
Cancel
Save