// Copyright 2012 Gary Burd // // Licensed under the Apache License, Version 2.0 (the "License"): you may // not use this file except in compliance with the License. You may obtain // a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the // License for the specific language governing permissions and limitations // under the License. package redis import ( "bytes" "context" "crypto/rand" "crypto/sha1" "errors" "io" "strconv" "sync" "time" "github.com/bilibili/kratos/pkg/container/pool" "github.com/bilibili/kratos/pkg/net/trace" xtime "github.com/bilibili/kratos/pkg/time" ) var beginTime, _ = time.Parse("2006-01-02 15:04:05", "2006-01-02 15:04:05") var ( errConnClosed = errors.New("redigo: connection closed") ) // Pool . type Pool struct { *pool.Slice // config c *Config // statfunc statfunc func(name, addr, cmd string, t time.Time, err error) func() } // Config client settings. type Config struct { *pool.Config Name string // redis name, for trace Proto string Addr string Auth string DialTimeout xtime.Duration ReadTimeout xtime.Duration WriteTimeout xtime.Duration } // NewPool creates a new pool. func NewPool(c *Config, options ...DialOption) (p *Pool) { if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 { panic("must config redis timeout") } ops := []DialOption{ DialConnectTimeout(time.Duration(c.DialTimeout)), DialReadTimeout(time.Duration(c.ReadTimeout)), DialWriteTimeout(time.Duration(c.WriteTimeout)), DialPassword(c.Auth), } ops = append(ops, options...) p1 := pool.NewSlice(c.Config) p1.New = func(ctx context.Context) (io.Closer, error) { conn, err := Dial(c.Proto, c.Addr, ops...) if err != nil { return nil, err } return &traceConn{Conn: conn, connTags: []trace.Tag{trace.TagString(trace.TagPeerAddress, c.Addr)}}, nil } p = &Pool{Slice: p1, c: c, statfunc: pstat} return } // Get gets a connection. The application must close the returned connection. // This method always returns a valid connection so that applications can defer // error handling to the first use of the connection. If there is an error // getting an underlying connection, then the connection Err, Do, Send, Flush // and Receive methods return that error. func (p *Pool) Get(ctx context.Context) Conn { c, err := p.Slice.Get(ctx) if err != nil { return errorConnection{err} } c1, _ := c.(Conn) return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx, now: beginTime} } // Close releases the resources used by the pool. func (p *Pool) Close() error { return p.Slice.Close() } type pooledConnection struct { p *Pool c Conn state int now time.Time cmds []string ctx context.Context } var ( sentinel []byte sentinelOnce sync.Once ) func initSentinel() { p := make([]byte, 64) if _, err := rand.Read(p); err == nil { sentinel = p } else { h := sha1.New() io.WriteString(h, "Oops, rand failed. Use time instead.") io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10)) sentinel = h.Sum(nil) } } // SetStatFunc set stat func. func (p *Pool) SetStatFunc(fn func(name, addr, cmd string, t time.Time, err error) func()) { p.statfunc = fn } func pstat(name, addr, cmd string, t time.Time, err error) func() { return func() { _metricReqDur.Observe(int64(time.Since(t)/time.Millisecond), name, addr, cmd) if err != nil { if msg := formatErr(err, name, addr); msg != "" { _metricReqErr.Inc(name, addr, cmd, msg) } return } _metricHits.Inc(name, addr) } } func (pc *pooledConnection) Close() error { c := pc.c if _, ok := c.(errorConnection); ok { return nil } pc.c = errorConnection{errConnClosed} if pc.state&MultiState != 0 { c.Send("DISCARD") pc.state &^= (MultiState | WatchState) } else if pc.state&WatchState != 0 { c.Send("UNWATCH") pc.state &^= WatchState } if pc.state&SubscribeState != 0 { c.Send("UNSUBSCRIBE") c.Send("PUNSUBSCRIBE") // To detect the end of the message stream, ask the server to echo // a sentinel value and read until we see that value. sentinelOnce.Do(initSentinel) c.Send("ECHO", sentinel) c.Flush() for { p, err := c.Receive() if err != nil { break } if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) { pc.state &^= SubscribeState break } } } _, err := c.Do("") pc.p.Slice.Put(context.Background(), c, pc.state != 0 || c.Err() != nil) return err } func (pc *pooledConnection) Err() error { return pc.c.Err() } func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) { now := time.Now() ci := LookupCommandInfo(commandName) pc.state = (pc.state | ci.Set) &^ ci.Clear reply, err = pc.c.Do(commandName, args...) pc.p.statfunc(pc.p.c.Name, pc.p.c.Addr, commandName, now, err)() return } func (pc *pooledConnection) Send(commandName string, args ...interface{}) (err error) { ci := LookupCommandInfo(commandName) pc.state = (pc.state | ci.Set) &^ ci.Clear if pc.now.Equal(beginTime) { // mark first send time pc.now = time.Now() } pc.cmds = append(pc.cmds, commandName) return pc.c.Send(commandName, args...) } func (pc *pooledConnection) Flush() error { return pc.c.Flush() } func (pc *pooledConnection) Receive() (reply interface{}, err error) { reply, err = pc.c.Receive() if len(pc.cmds) > 0 { cmd := pc.cmds[0] pc.cmds = pc.cmds[1:] pc.p.statfunc(pc.p.c.Name, pc.p.c.Addr, cmd, pc.now, err)() } return } func (pc *pooledConnection) WithContext(ctx context.Context) Conn { pc.ctx = ctx return pc } type errorConnection struct{ err error } func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err } func (ec errorConnection) Send(string, ...interface{}) error { return ec.err } func (ec errorConnection) Err() error { return ec.err } func (ec errorConnection) Close() error { return ec.err } func (ec errorConnection) Flush() error { return ec.err } func (ec errorConnection) Receive() (interface{}, error) { return nil, ec.err } func (ec errorConnection) WithContext(context.Context) Conn { return ec }