commit
dbdfe47e5b
@ -0,0 +1,25 @@ |
|||||||
|
# cache/memcache |
||||||
|
|
||||||
|
##### 项目简介 |
||||||
|
1. 提供protobuf,gob,json序列化方式,gzip的memcache接口 |
||||||
|
|
||||||
|
#### 使用方式 |
||||||
|
```golang |
||||||
|
// 初始化 注意这里只是示例 展示用法 不能每次都New 只需要初始化一次 |
||||||
|
mc := memcache.New(&memcache.Config{}) |
||||||
|
// 程序关闭的时候调用close方法 |
||||||
|
defer mc.Close() |
||||||
|
// 增加 key |
||||||
|
err = mc.Set(c, &memcache.Item{}) |
||||||
|
// 删除key |
||||||
|
err := mc.Delete(c,key) |
||||||
|
// 获得某个key的内容 |
||||||
|
err := mc.Get(c,key).Scan(&v) |
||||||
|
// 获取多个key的内容 |
||||||
|
replies, err := mc.GetMulti(c, keys) |
||||||
|
for _, key := range replies.Keys() { |
||||||
|
if err = replies.Scan(key, &v); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
``` |
@ -0,0 +1,187 @@ |
|||||||
|
package memcache |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
) |
||||||
|
|
||||||
|
// Memcache memcache client
|
||||||
|
type Memcache struct { |
||||||
|
pool *Pool |
||||||
|
} |
||||||
|
|
||||||
|
// Reply is the result of Get
|
||||||
|
type Reply struct { |
||||||
|
err error |
||||||
|
item *Item |
||||||
|
conn Conn |
||||||
|
closed bool |
||||||
|
} |
||||||
|
|
||||||
|
// Replies is the result of GetMulti
|
||||||
|
type Replies struct { |
||||||
|
err error |
||||||
|
items map[string]*Item |
||||||
|
usedItems map[string]struct{} |
||||||
|
conn Conn |
||||||
|
closed bool |
||||||
|
} |
||||||
|
|
||||||
|
// New get a memcache client
|
||||||
|
func New(c *Config) *Memcache { |
||||||
|
return &Memcache{pool: NewPool(c)} |
||||||
|
} |
||||||
|
|
||||||
|
// Close close connection pool
|
||||||
|
func (mc *Memcache) Close() error { |
||||||
|
return mc.pool.Close() |
||||||
|
} |
||||||
|
|
||||||
|
// Conn direct get a connection
|
||||||
|
func (mc *Memcache) Conn(c context.Context) Conn { |
||||||
|
return mc.pool.Get(c) |
||||||
|
} |
||||||
|
|
||||||
|
// Set writes the given item, unconditionally.
|
||||||
|
func (mc *Memcache) Set(c context.Context, item *Item) (err error) { |
||||||
|
conn := mc.pool.Get(c) |
||||||
|
err = conn.Set(item) |
||||||
|
conn.Close() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Add writes the given item, if no value already exists for its key.
|
||||||
|
// ErrNotStored is returned if that condition is not met.
|
||||||
|
func (mc *Memcache) Add(c context.Context, item *Item) (err error) { |
||||||
|
conn := mc.pool.Get(c) |
||||||
|
err = conn.Add(item) |
||||||
|
conn.Close() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Replace writes the given item, but only if the server *does* already hold data for this key.
|
||||||
|
func (mc *Memcache) Replace(c context.Context, item *Item) (err error) { |
||||||
|
conn := mc.pool.Get(c) |
||||||
|
err = conn.Replace(item) |
||||||
|
conn.Close() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// CompareAndSwap writes the given item that was previously returned by Get
|
||||||
|
func (mc *Memcache) CompareAndSwap(c context.Context, item *Item) (err error) { |
||||||
|
conn := mc.pool.Get(c) |
||||||
|
err = conn.CompareAndSwap(item) |
||||||
|
conn.Close() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Get sends a command to the server for gets data.
|
||||||
|
func (mc *Memcache) Get(c context.Context, key string) *Reply { |
||||||
|
conn := mc.pool.Get(c) |
||||||
|
item, err := conn.Get(key) |
||||||
|
if err != nil { |
||||||
|
conn.Close() |
||||||
|
} |
||||||
|
return &Reply{err: err, item: item, conn: conn} |
||||||
|
} |
||||||
|
|
||||||
|
// Item get raw Item
|
||||||
|
func (r *Reply) Item() *Item { |
||||||
|
return r.item |
||||||
|
} |
||||||
|
|
||||||
|
// Scan converts value, read from the memcache
|
||||||
|
func (r *Reply) Scan(v interface{}) (err error) { |
||||||
|
if r.err != nil { |
||||||
|
return r.err |
||||||
|
} |
||||||
|
err = r.conn.Scan(r.item, v) |
||||||
|
if !r.closed { |
||||||
|
r.conn.Close() |
||||||
|
r.closed = true |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// GetMulti is a batch version of Get
|
||||||
|
func (mc *Memcache) GetMulti(c context.Context, keys []string) (*Replies, error) { |
||||||
|
conn := mc.pool.Get(c) |
||||||
|
items, err := conn.GetMulti(keys) |
||||||
|
rs := &Replies{err: err, items: items, conn: conn, usedItems: make(map[string]struct{}, len(keys))} |
||||||
|
if (err != nil) || (len(items) == 0) { |
||||||
|
rs.Close() |
||||||
|
} |
||||||
|
return rs, err |
||||||
|
} |
||||||
|
|
||||||
|
// Close close rows.
|
||||||
|
func (rs *Replies) Close() (err error) { |
||||||
|
if !rs.closed { |
||||||
|
err = rs.conn.Close() |
||||||
|
rs.closed = true |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Item get Item from rows
|
||||||
|
func (rs *Replies) Item(key string) *Item { |
||||||
|
return rs.items[key] |
||||||
|
} |
||||||
|
|
||||||
|
// Scan converts value, read from key in rows
|
||||||
|
func (rs *Replies) Scan(key string, v interface{}) (err error) { |
||||||
|
if rs.err != nil { |
||||||
|
return rs.err |
||||||
|
} |
||||||
|
item, ok := rs.items[key] |
||||||
|
if !ok { |
||||||
|
rs.Close() |
||||||
|
return ErrNotFound |
||||||
|
} |
||||||
|
rs.usedItems[key] = struct{}{} |
||||||
|
err = rs.conn.Scan(item, v) |
||||||
|
if (err != nil) || (len(rs.items) == len(rs.usedItems)) { |
||||||
|
rs.Close() |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Keys keys of result
|
||||||
|
func (rs *Replies) Keys() (keys []string) { |
||||||
|
keys = make([]string, 0, len(rs.items)) |
||||||
|
for key := range rs.items { |
||||||
|
keys = append(keys, key) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Touch updates the expiry for the given key.
|
||||||
|
func (mc *Memcache) Touch(c context.Context, key string, timeout int32) (err error) { |
||||||
|
conn := mc.pool.Get(c) |
||||||
|
err = conn.Touch(key, timeout) |
||||||
|
conn.Close() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Delete deletes the item with the provided key.
|
||||||
|
func (mc *Memcache) Delete(c context.Context, key string) (err error) { |
||||||
|
conn := mc.pool.Get(c) |
||||||
|
err = conn.Delete(key) |
||||||
|
conn.Close() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Increment atomically increments key by delta.
|
||||||
|
func (mc *Memcache) Increment(c context.Context, key string, delta uint64) (newValue uint64, err error) { |
||||||
|
conn := mc.pool.Get(c) |
||||||
|
newValue, err = conn.Increment(key, delta) |
||||||
|
conn.Close() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Decrement atomically decrements key by delta.
|
||||||
|
func (mc *Memcache) Decrement(c context.Context, key string, delta uint64) (newValue uint64, err error) { |
||||||
|
conn := mc.pool.Get(c) |
||||||
|
newValue, err = conn.Decrement(key, delta) |
||||||
|
conn.Close() |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,685 @@ |
|||||||
|
package memcache |
||||||
|
|
||||||
|
import ( |
||||||
|
"bufio" |
||||||
|
"bytes" |
||||||
|
"compress/gzip" |
||||||
|
"context" |
||||||
|
"encoding/gob" |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"net" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/gogo/protobuf/proto" |
||||||
|
pkgerr "github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
crlf = []byte("\r\n") |
||||||
|
spaceStr = string(" ") |
||||||
|
replyOK = []byte("OK\r\n") |
||||||
|
replyStored = []byte("STORED\r\n") |
||||||
|
replyNotStored = []byte("NOT_STORED\r\n") |
||||||
|
replyExists = []byte("EXISTS\r\n") |
||||||
|
replyNotFound = []byte("NOT_FOUND\r\n") |
||||||
|
replyDeleted = []byte("DELETED\r\n") |
||||||
|
replyEnd = []byte("END\r\n") |
||||||
|
replyTouched = []byte("TOUCHED\r\n") |
||||||
|
replyValueStr = "VALUE" |
||||||
|
replyClientErrorPrefix = []byte("CLIENT_ERROR ") |
||||||
|
replyServerErrorPrefix = []byte("SERVER_ERROR ") |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
_encodeBuf = 4096 // 4kb
|
||||||
|
// 1024*1024 - 1, set error???
|
||||||
|
_largeValue = 1000 * 1000 // 1MB
|
||||||
|
) |
||||||
|
|
||||||
|
type reader struct { |
||||||
|
io.Reader |
||||||
|
} |
||||||
|
|
||||||
|
func (r *reader) Reset(rd io.Reader) { |
||||||
|
r.Reader = rd |
||||||
|
} |
||||||
|
|
||||||
|
// conn is the low-level implementation of Conn
|
||||||
|
type conn struct { |
||||||
|
// Shared
|
||||||
|
mu sync.Mutex |
||||||
|
err error |
||||||
|
conn net.Conn |
||||||
|
// Read & Write
|
||||||
|
readTimeout time.Duration |
||||||
|
writeTimeout time.Duration |
||||||
|
rw *bufio.ReadWriter |
||||||
|
// Item Reader
|
||||||
|
ir bytes.Reader |
||||||
|
// Compress
|
||||||
|
gr gzip.Reader |
||||||
|
gw *gzip.Writer |
||||||
|
cb bytes.Buffer |
||||||
|
// Encoding
|
||||||
|
edb bytes.Buffer |
||||||
|
// json
|
||||||
|
jr reader |
||||||
|
jd *json.Decoder |
||||||
|
je *json.Encoder |
||||||
|
// protobuffer
|
||||||
|
ped *proto.Buffer |
||||||
|
} |
||||||
|
|
||||||
|
// DialOption specifies an option for dialing a Memcache server.
|
||||||
|
type DialOption struct { |
||||||
|
f func(*dialOptions) |
||||||
|
} |
||||||
|
|
||||||
|
type dialOptions struct { |
||||||
|
readTimeout time.Duration |
||||||
|
writeTimeout time.Duration |
||||||
|
dial func(network, addr string) (net.Conn, error) |
||||||
|
} |
||||||
|
|
||||||
|
// DialReadTimeout specifies the timeout for reading a single command reply.
|
||||||
|
func DialReadTimeout(d time.Duration) DialOption { |
||||||
|
return DialOption{func(do *dialOptions) { |
||||||
|
do.readTimeout = d |
||||||
|
}} |
||||||
|
} |
||||||
|
|
||||||
|
// DialWriteTimeout specifies the timeout for writing a single command.
|
||||||
|
func DialWriteTimeout(d time.Duration) DialOption { |
||||||
|
return DialOption{func(do *dialOptions) { |
||||||
|
do.writeTimeout = d |
||||||
|
}} |
||||||
|
} |
||||||
|
|
||||||
|
// DialConnectTimeout specifies the timeout for connecting to the Memcache server.
|
||||||
|
func DialConnectTimeout(d time.Duration) DialOption { |
||||||
|
return DialOption{func(do *dialOptions) { |
||||||
|
dialer := net.Dialer{Timeout: d} |
||||||
|
do.dial = dialer.Dial |
||||||
|
}} |
||||||
|
} |
||||||
|
|
||||||
|
// DialNetDial specifies a custom dial function for creating TCP
|
||||||
|
// connections. If this option is left out, then net.Dial is
|
||||||
|
// used. DialNetDial overrides DialConnectTimeout.
|
||||||
|
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption { |
||||||
|
return DialOption{func(do *dialOptions) { |
||||||
|
do.dial = dial |
||||||
|
}} |
||||||
|
} |
||||||
|
|
||||||
|
// Dial connects to the Memcache server at the given network and
|
||||||
|
// address using the specified options.
|
||||||
|
func Dial(network, address string, options ...DialOption) (Conn, error) { |
||||||
|
do := dialOptions{ |
||||||
|
dial: net.Dial, |
||||||
|
} |
||||||
|
for _, option := range options { |
||||||
|
option.f(&do) |
||||||
|
} |
||||||
|
netConn, err := do.dial(network, address) |
||||||
|
if err != nil { |
||||||
|
return nil, pkgerr.WithStack(err) |
||||||
|
} |
||||||
|
return NewConn(netConn, do.readTimeout, do.writeTimeout), nil |
||||||
|
} |
||||||
|
|
||||||
|
// NewConn returns a new memcache connection for the given net connection.
|
||||||
|
func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn { |
||||||
|
if writeTimeout <= 0 || readTimeout <= 0 { |
||||||
|
panic("must config memcache timeout") |
||||||
|
} |
||||||
|
c := &conn{ |
||||||
|
conn: netConn, |
||||||
|
rw: bufio.NewReadWriter(bufio.NewReader(netConn), |
||||||
|
bufio.NewWriter(netConn)), |
||||||
|
readTimeout: readTimeout, |
||||||
|
writeTimeout: writeTimeout, |
||||||
|
} |
||||||
|
c.jd = json.NewDecoder(&c.jr) |
||||||
|
c.je = json.NewEncoder(&c.edb) |
||||||
|
c.gw = gzip.NewWriter(&c.cb) |
||||||
|
c.edb.Grow(_encodeBuf) |
||||||
|
// NOTE reuse bytes.Buffer internal buf
|
||||||
|
// DON'T concurrency call Scan
|
||||||
|
c.ped = proto.NewBuffer(c.edb.Bytes()) |
||||||
|
return c |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Close() error { |
||||||
|
c.mu.Lock() |
||||||
|
err := c.err |
||||||
|
if c.err == nil { |
||||||
|
c.err = pkgerr.New("memcache: closed") |
||||||
|
err = c.conn.Close() |
||||||
|
} |
||||||
|
c.mu.Unlock() |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) fatal(err error) error { |
||||||
|
c.mu.Lock() |
||||||
|
if c.err == nil { |
||||||
|
c.err = pkgerr.WithStack(err) |
||||||
|
// Close connection to force errors on subsequent calls and to unblock
|
||||||
|
// other reader or writer.
|
||||||
|
c.conn.Close() |
||||||
|
} |
||||||
|
c.mu.Unlock() |
||||||
|
return c.err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Err() error { |
||||||
|
c.mu.Lock() |
||||||
|
err := c.err |
||||||
|
c.mu.Unlock() |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Add(item *Item) error { |
||||||
|
return c.populate("add", item) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Set(item *Item) error { |
||||||
|
return c.populate("set", item) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Replace(item *Item) error { |
||||||
|
return c.populate("replace", item) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) CompareAndSwap(item *Item) error { |
||||||
|
return c.populate("cas", item) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) populate(cmd string, item *Item) (err error) { |
||||||
|
if !legalKey(item.Key) { |
||||||
|
return pkgerr.WithStack(ErrMalformedKey) |
||||||
|
} |
||||||
|
var res []byte |
||||||
|
if res, err = c.encode(item); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
l := len(res) |
||||||
|
count := l/(_largeValue) + 1 |
||||||
|
if count == 1 { |
||||||
|
item.Value = res |
||||||
|
return c.populateOne(cmd, item) |
||||||
|
} |
||||||
|
nItem := &Item{ |
||||||
|
Key: item.Key, |
||||||
|
Value: []byte(strconv.Itoa(l)), |
||||||
|
Expiration: item.Expiration, |
||||||
|
Flags: item.Flags | flagLargeValue, |
||||||
|
} |
||||||
|
err = c.populateOne(cmd, nItem) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
k := item.Key |
||||||
|
nItem.Flags = item.Flags |
||||||
|
for i := 1; i <= count; i++ { |
||||||
|
if i == count { |
||||||
|
nItem.Value = res[_largeValue*(count-1):] |
||||||
|
} else { |
||||||
|
nItem.Value = res[_largeValue*(i-1) : _largeValue*i] |
||||||
|
} |
||||||
|
nItem.Key = fmt.Sprintf("%s%d", k, i) |
||||||
|
if err = c.populateOne(cmd, nItem); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) populateOne(cmd string, item *Item) (err error) { |
||||||
|
if c.writeTimeout != 0 { |
||||||
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||||
|
} |
||||||
|
|
||||||
|
// <command name> <key> <flags> <exptime> <bytes> [noreply]\r\n
|
||||||
|
if cmd == "cas" { |
||||||
|
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d %d\r\n", |
||||||
|
cmd, item.Key, item.Flags, item.Expiration, len(item.Value), item.cas) |
||||||
|
} else { |
||||||
|
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d\r\n", |
||||||
|
cmd, item.Key, item.Flags, item.Expiration, len(item.Value)) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
return c.fatal(err) |
||||||
|
} |
||||||
|
c.rw.Write(item.Value) |
||||||
|
c.rw.Write(crlf) |
||||||
|
if err = c.rw.Flush(); err != nil { |
||||||
|
return c.fatal(err) |
||||||
|
} |
||||||
|
if c.readTimeout != 0 { |
||||||
|
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) |
||||||
|
} |
||||||
|
line, err := c.rw.ReadSlice('\n') |
||||||
|
if err != nil { |
||||||
|
return c.fatal(err) |
||||||
|
} |
||||||
|
switch { |
||||||
|
case bytes.Equal(line, replyStored): |
||||||
|
return nil |
||||||
|
case bytes.Equal(line, replyNotStored): |
||||||
|
return ErrNotStored |
||||||
|
case bytes.Equal(line, replyExists): |
||||||
|
return ErrCASConflict |
||||||
|
case bytes.Equal(line, replyNotFound): |
||||||
|
return ErrNotFound |
||||||
|
} |
||||||
|
return pkgerr.WithStack(protocolError(string(line))) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Get(key string) (r *Item, err error) { |
||||||
|
if !legalKey(key) { |
||||||
|
return nil, pkgerr.WithStack(ErrMalformedKey) |
||||||
|
} |
||||||
|
if c.writeTimeout != 0 { |
||||||
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||||
|
} |
||||||
|
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", key); err != nil { |
||||||
|
return nil, c.fatal(err) |
||||||
|
} |
||||||
|
if err = c.rw.Flush(); err != nil { |
||||||
|
return nil, c.fatal(err) |
||||||
|
} |
||||||
|
if err = c.parseGetReply(func(it *Item) { |
||||||
|
r = it |
||||||
|
}); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
if r == nil { |
||||||
|
err = ErrNotFound |
||||||
|
return |
||||||
|
} |
||||||
|
if r.Flags&flagLargeValue != flagLargeValue { |
||||||
|
return |
||||||
|
} |
||||||
|
if r, err = c.getLargeValue(r); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) GetMulti(keys []string) (res map[string]*Item, err error) { |
||||||
|
for _, key := range keys { |
||||||
|
if !legalKey(key) { |
||||||
|
return nil, pkgerr.WithStack(ErrMalformedKey) |
||||||
|
} |
||||||
|
} |
||||||
|
if c.writeTimeout != 0 { |
||||||
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||||
|
} |
||||||
|
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil { |
||||||
|
return nil, c.fatal(err) |
||||||
|
} |
||||||
|
if err = c.rw.Flush(); err != nil { |
||||||
|
return nil, c.fatal(err) |
||||||
|
} |
||||||
|
res = make(map[string]*Item, len(keys)) |
||||||
|
if err = c.parseGetReply(func(it *Item) { |
||||||
|
res[it.Key] = it |
||||||
|
}); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
for k, v := range res { |
||||||
|
if v.Flags&flagLargeValue != flagLargeValue { |
||||||
|
continue |
||||||
|
} |
||||||
|
r, err := c.getLargeValue(v) |
||||||
|
if err != nil { |
||||||
|
return res, err |
||||||
|
} |
||||||
|
res[k] = r |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) getMulti(keys []string) (res map[string]*Item, err error) { |
||||||
|
for _, key := range keys { |
||||||
|
if !legalKey(key) { |
||||||
|
return nil, pkgerr.WithStack(ErrMalformedKey) |
||||||
|
} |
||||||
|
} |
||||||
|
if c.writeTimeout != 0 { |
||||||
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||||
|
} |
||||||
|
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil { |
||||||
|
return nil, c.fatal(err) |
||||||
|
} |
||||||
|
if err = c.rw.Flush(); err != nil { |
||||||
|
return nil, c.fatal(err) |
||||||
|
} |
||||||
|
res = make(map[string]*Item, len(keys)) |
||||||
|
err = c.parseGetReply(func(it *Item) { |
||||||
|
res[it.Key] = it |
||||||
|
}) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) getLargeValue(it *Item) (r *Item, err error) { |
||||||
|
l, err := strconv.Atoi(string(it.Value)) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
count := l/_largeValue + 1 |
||||||
|
keys := make([]string, 0, count) |
||||||
|
for i := 1; i <= count; i++ { |
||||||
|
keys = append(keys, fmt.Sprintf("%s%d", it.Key, i)) |
||||||
|
} |
||||||
|
items, err := c.getMulti(keys) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
if len(items) < count { |
||||||
|
err = ErrNotFound |
||||||
|
return |
||||||
|
} |
||||||
|
v := make([]byte, 0, l) |
||||||
|
for _, k := range keys { |
||||||
|
if items[k] == nil || items[k].Value == nil { |
||||||
|
err = ErrNotFound |
||||||
|
return |
||||||
|
} |
||||||
|
v = append(v, items[k].Value...) |
||||||
|
} |
||||||
|
it.Value = v |
||||||
|
it.Flags = it.Flags ^ flagLargeValue |
||||||
|
r = it |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) parseGetReply(f func(*Item)) error { |
||||||
|
if c.readTimeout != 0 { |
||||||
|
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) |
||||||
|
} |
||||||
|
for { |
||||||
|
line, err := c.rw.ReadSlice('\n') |
||||||
|
if err != nil { |
||||||
|
return c.fatal(err) |
||||||
|
} |
||||||
|
if bytes.Equal(line, replyEnd) { |
||||||
|
return nil |
||||||
|
} |
||||||
|
if bytes.HasPrefix(line, replyServerErrorPrefix) { |
||||||
|
errMsg := line[len(replyServerErrorPrefix):] |
||||||
|
return c.fatal(protocolError(errMsg)) |
||||||
|
} |
||||||
|
it := new(Item) |
||||||
|
size, err := scanGetReply(line, it) |
||||||
|
if err != nil { |
||||||
|
return c.fatal(err) |
||||||
|
} |
||||||
|
it.Value = make([]byte, size+2) |
||||||
|
if _, err = io.ReadFull(c.rw, it.Value); err != nil { |
||||||
|
return c.fatal(err) |
||||||
|
} |
||||||
|
if !bytes.HasSuffix(it.Value, crlf) { |
||||||
|
return c.fatal(protocolError("corrupt get reply, no except CRLF")) |
||||||
|
} |
||||||
|
it.Value = it.Value[:size] |
||||||
|
f(it) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func scanGetReply(line []byte, item *Item) (size int, err error) { |
||||||
|
if !bytes.HasSuffix(line, crlf) { |
||||||
|
return 0, protocolError("corrupt get reply, no except CRLF") |
||||||
|
} |
||||||
|
// VALUE <key> <flags> <bytes> [<cas unique>]
|
||||||
|
chunks := strings.Split(string(line[:len(line)-2]), spaceStr) |
||||||
|
if len(chunks) < 4 { |
||||||
|
return 0, protocolError("corrupt get reply") |
||||||
|
} |
||||||
|
if chunks[0] != replyValueStr { |
||||||
|
return 0, protocolError("corrupt get reply, no except VALUE") |
||||||
|
} |
||||||
|
item.Key = chunks[1] |
||||||
|
flags64, err := strconv.ParseUint(chunks[2], 10, 32) |
||||||
|
if err != nil { |
||||||
|
return 0, err |
||||||
|
} |
||||||
|
item.Flags = uint32(flags64) |
||||||
|
if size, err = strconv.Atoi(chunks[3]); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
if len(chunks) > 4 { |
||||||
|
item.cas, err = strconv.ParseUint(chunks[4], 10, 64) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Touch(key string, expire int32) (err error) { |
||||||
|
if !legalKey(key) { |
||||||
|
return pkgerr.WithStack(ErrMalformedKey) |
||||||
|
} |
||||||
|
line, err := c.writeReadLine("touch %s %d\r\n", key, expire) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
switch { |
||||||
|
case bytes.Equal(line, replyTouched): |
||||||
|
return nil |
||||||
|
case bytes.Equal(line, replyNotFound): |
||||||
|
return ErrNotFound |
||||||
|
default: |
||||||
|
return pkgerr.WithStack(protocolError(string(line))) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Increment(key string, delta uint64) (uint64, error) { |
||||||
|
return c.incrDecr("incr", key, delta) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Decrement(key string, delta uint64) (newValue uint64, err error) { |
||||||
|
return c.incrDecr("decr", key, delta) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) incrDecr(cmd, key string, delta uint64) (uint64, error) { |
||||||
|
if !legalKey(key) { |
||||||
|
return 0, pkgerr.WithStack(ErrMalformedKey) |
||||||
|
} |
||||||
|
line, err := c.writeReadLine("%s %s %d\r\n", cmd, key, delta) |
||||||
|
if err != nil { |
||||||
|
return 0, err |
||||||
|
} |
||||||
|
switch { |
||||||
|
case bytes.Equal(line, replyNotFound): |
||||||
|
return 0, ErrNotFound |
||||||
|
case bytes.HasPrefix(line, replyClientErrorPrefix): |
||||||
|
errMsg := line[len(replyClientErrorPrefix):] |
||||||
|
return 0, pkgerr.WithStack(protocolError(errMsg)) |
||||||
|
} |
||||||
|
val, err := strconv.ParseUint(string(line[:len(line)-2]), 10, 64) |
||||||
|
if err != nil { |
||||||
|
return 0, err |
||||||
|
} |
||||||
|
return val, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Delete(key string) (err error) { |
||||||
|
if !legalKey(key) { |
||||||
|
return pkgerr.WithStack(ErrMalformedKey) |
||||||
|
} |
||||||
|
line, err := c.writeReadLine("delete %s\r\n", key) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
switch { |
||||||
|
case bytes.Equal(line, replyOK): |
||||||
|
return nil |
||||||
|
case bytes.Equal(line, replyDeleted): |
||||||
|
return nil |
||||||
|
case bytes.Equal(line, replyNotStored): |
||||||
|
return ErrNotStored |
||||||
|
case bytes.Equal(line, replyExists): |
||||||
|
return ErrCASConflict |
||||||
|
case bytes.Equal(line, replyNotFound): |
||||||
|
return ErrNotFound |
||||||
|
} |
||||||
|
return pkgerr.WithStack(protocolError(string(line))) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) writeReadLine(format string, args ...interface{}) ([]byte, error) { |
||||||
|
if c.writeTimeout != 0 { |
||||||
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||||
|
} |
||||||
|
_, err := fmt.Fprintf(c.rw, format, args...) |
||||||
|
if err != nil { |
||||||
|
return nil, c.fatal(pkgerr.WithStack(err)) |
||||||
|
} |
||||||
|
if err = c.rw.Flush(); err != nil { |
||||||
|
return nil, c.fatal(pkgerr.WithStack(err)) |
||||||
|
} |
||||||
|
if c.readTimeout != 0 { |
||||||
|
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) |
||||||
|
} |
||||||
|
line, err := c.rw.ReadSlice('\n') |
||||||
|
if err != nil { |
||||||
|
return line, c.fatal(pkgerr.WithStack(err)) |
||||||
|
} |
||||||
|
return line, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Scan(item *Item, v interface{}) (err error) { |
||||||
|
c.ir.Reset(item.Value) |
||||||
|
if item.Flags&FlagGzip == FlagGzip { |
||||||
|
if err = c.gr.Reset(&c.ir); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
if err = c.decode(&c.gr, item, v); err != nil { |
||||||
|
err = pkgerr.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
err = c.gr.Close() |
||||||
|
} else { |
||||||
|
err = c.decode(&c.ir, item, v) |
||||||
|
} |
||||||
|
err = pkgerr.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) WithContext(ctx context.Context) Conn { |
||||||
|
// FIXME: implement WithContext
|
||||||
|
return c |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) encode(item *Item) (data []byte, err error) { |
||||||
|
if (item.Flags | _flagEncoding) == _flagEncoding { |
||||||
|
if item.Value == nil { |
||||||
|
return nil, ErrItem |
||||||
|
} |
||||||
|
} else if item.Object == nil { |
||||||
|
return nil, ErrItem |
||||||
|
} |
||||||
|
// encoding
|
||||||
|
switch { |
||||||
|
case item.Flags&FlagGOB == FlagGOB: |
||||||
|
c.edb.Reset() |
||||||
|
if err = gob.NewEncoder(&c.edb).Encode(item.Object); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
data = c.edb.Bytes() |
||||||
|
case item.Flags&FlagProtobuf == FlagProtobuf: |
||||||
|
c.edb.Reset() |
||||||
|
c.ped.SetBuf(c.edb.Bytes()) |
||||||
|
pb, ok := item.Object.(proto.Message) |
||||||
|
if !ok { |
||||||
|
err = ErrItemObject |
||||||
|
return |
||||||
|
} |
||||||
|
if err = c.ped.Marshal(pb); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
data = c.ped.Bytes() |
||||||
|
case item.Flags&FlagJSON == FlagJSON: |
||||||
|
c.edb.Reset() |
||||||
|
if err = c.je.Encode(item.Object); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
data = c.edb.Bytes() |
||||||
|
default: |
||||||
|
data = item.Value |
||||||
|
} |
||||||
|
// compress
|
||||||
|
if item.Flags&FlagGzip == FlagGzip { |
||||||
|
c.cb.Reset() |
||||||
|
c.gw.Reset(&c.cb) |
||||||
|
if _, err = c.gw.Write(data); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
if err = c.gw.Close(); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
data = c.cb.Bytes() |
||||||
|
} |
||||||
|
if len(data) > 8000000 { |
||||||
|
err = ErrValueSize |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) decode(rd io.Reader, item *Item, v interface{}) (err error) { |
||||||
|
var data []byte |
||||||
|
switch { |
||||||
|
case item.Flags&FlagGOB == FlagGOB: |
||||||
|
err = gob.NewDecoder(rd).Decode(v) |
||||||
|
case item.Flags&FlagJSON == FlagJSON: |
||||||
|
c.jr.Reset(rd) |
||||||
|
err = c.jd.Decode(v) |
||||||
|
default: |
||||||
|
data = item.Value |
||||||
|
if item.Flags&FlagGzip == FlagGzip { |
||||||
|
c.edb.Reset() |
||||||
|
if _, err = io.Copy(&c.edb, rd); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
data = c.edb.Bytes() |
||||||
|
} |
||||||
|
if item.Flags&FlagProtobuf == FlagProtobuf { |
||||||
|
m, ok := v.(proto.Message) |
||||||
|
if !ok { |
||||||
|
err = ErrItemObject |
||||||
|
return |
||||||
|
} |
||||||
|
c.ped.SetBuf(data) |
||||||
|
err = c.ped.Unmarshal(m) |
||||||
|
} else { |
||||||
|
switch v.(type) { |
||||||
|
case *[]byte: |
||||||
|
d := v.(*[]byte) |
||||||
|
*d = data |
||||||
|
case *string: |
||||||
|
d := v.(*string) |
||||||
|
*d = string(data) |
||||||
|
case interface{}: |
||||||
|
err = json.Unmarshal(data, v) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func legalKey(key string) bool { |
||||||
|
if len(key) > 250 || len(key) == 0 { |
||||||
|
return false |
||||||
|
} |
||||||
|
for i := 0; i < len(key); i++ { |
||||||
|
if key[i] <= ' ' || key[i] == 0x7f { |
||||||
|
return false |
||||||
|
} |
||||||
|
} |
||||||
|
return true |
||||||
|
} |
@ -0,0 +1,76 @@ |
|||||||
|
package memcache |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"strings" |
||||||
|
|
||||||
|
pkgerr "github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
// ErrNotFound not found
|
||||||
|
ErrNotFound = errors.New("memcache: key not found") |
||||||
|
// ErrExists exists
|
||||||
|
ErrExists = errors.New("memcache: key exists") |
||||||
|
// ErrNotStored not stored
|
||||||
|
ErrNotStored = errors.New("memcache: key not stored") |
||||||
|
// ErrCASConflict means that a CompareAndSwap call failed due to the
|
||||||
|
// cached value being modified between the Get and the CompareAndSwap.
|
||||||
|
// If the cached value was simply evicted rather than replaced,
|
||||||
|
// ErrNotStored will be returned instead.
|
||||||
|
ErrCASConflict = errors.New("memcache: compare-and-swap conflict") |
||||||
|
|
||||||
|
// ErrPoolExhausted is returned from a pool connection method (Store, Get,
|
||||||
|
// Delete, IncrDecr, Err) when the maximum number of database connections
|
||||||
|
// in the pool has been reached.
|
||||||
|
ErrPoolExhausted = errors.New("memcache: connection pool exhausted") |
||||||
|
// ErrPoolClosed pool closed
|
||||||
|
ErrPoolClosed = errors.New("memcache: connection pool closed") |
||||||
|
// ErrConnClosed conn closed
|
||||||
|
ErrConnClosed = errors.New("memcache: connection closed") |
||||||
|
// ErrMalformedKey is returned when an invalid key is used.
|
||||||
|
// Keys must be at maximum 250 bytes long and not
|
||||||
|
// contain whitespace or control characters.
|
||||||
|
ErrMalformedKey = errors.New("memcache: malformed key is too long or contains invalid characters") |
||||||
|
// ErrValueSize item value size must less than 1mb
|
||||||
|
ErrValueSize = errors.New("memcache: item value size must not greater than 1mb") |
||||||
|
// ErrStat stat error for monitor
|
||||||
|
ErrStat = errors.New("memcache unexpected errors") |
||||||
|
// ErrItem item nil.
|
||||||
|
ErrItem = errors.New("memcache: item object nil") |
||||||
|
// ErrItemObject object type Assertion failed
|
||||||
|
ErrItemObject = errors.New("memcache: item object protobuf type assertion failed") |
||||||
|
) |
||||||
|
|
||||||
|
type protocolError string |
||||||
|
|
||||||
|
func (pe protocolError) Error() string { |
||||||
|
return fmt.Sprintf("memcache: %s (possible server error or unsupported concurrent read by application)", string(pe)) |
||||||
|
} |
||||||
|
|
||||||
|
func formatErr(err error) string { |
||||||
|
e := pkgerr.Cause(err) |
||||||
|
switch e { |
||||||
|
case ErrNotFound, ErrExists, ErrNotStored, nil: |
||||||
|
return "" |
||||||
|
default: |
||||||
|
es := e.Error() |
||||||
|
switch { |
||||||
|
case strings.HasPrefix(es, "read"): |
||||||
|
return "read timeout" |
||||||
|
case strings.HasPrefix(es, "dial"): |
||||||
|
return "dial timeout" |
||||||
|
case strings.HasPrefix(es, "write"): |
||||||
|
return "write timeout" |
||||||
|
case strings.Contains(es, "EOF"): |
||||||
|
return "eof" |
||||||
|
case strings.Contains(es, "reset"): |
||||||
|
return "reset" |
||||||
|
case strings.Contains(es, "broken"): |
||||||
|
return "broken pipe" |
||||||
|
default: |
||||||
|
return "unexpected err" |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,136 @@ |
|||||||
|
package memcache |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
) |
||||||
|
|
||||||
|
// Error represents an error returned in a command reply.
|
||||||
|
type Error string |
||||||
|
|
||||||
|
func (err Error) Error() string { return string(err) } |
||||||
|
|
||||||
|
const ( |
||||||
|
// Flag, 15(encoding) bit+ 17(compress) bit
|
||||||
|
|
||||||
|
// FlagRAW default flag.
|
||||||
|
FlagRAW = uint32(0) |
||||||
|
// FlagGOB gob encoding.
|
||||||
|
FlagGOB = uint32(1) << 0 |
||||||
|
// FlagJSON json encoding.
|
||||||
|
FlagJSON = uint32(1) << 1 |
||||||
|
// FlagProtobuf protobuf
|
||||||
|
FlagProtobuf = uint32(1) << 2 |
||||||
|
|
||||||
|
_flagEncoding = uint32(0xFFFF8000) |
||||||
|
|
||||||
|
// FlagGzip gzip compress.
|
||||||
|
FlagGzip = uint32(1) << 15 |
||||||
|
|
||||||
|
// left mv 31??? not work!!!
|
||||||
|
flagLargeValue = uint32(1) << 30 |
||||||
|
) |
||||||
|
|
||||||
|
// Item is an reply to be got or stored in a memcached server.
|
||||||
|
type Item struct { |
||||||
|
// Key is the Item's key (250 bytes maximum).
|
||||||
|
Key string |
||||||
|
|
||||||
|
// Value is the Item's value.
|
||||||
|
Value []byte |
||||||
|
|
||||||
|
// Object is the Item's object for use codec.
|
||||||
|
Object interface{} |
||||||
|
|
||||||
|
// Flags are server-opaque flags whose semantics are entirely
|
||||||
|
// up to the app.
|
||||||
|
Flags uint32 |
||||||
|
|
||||||
|
// Expiration is the cache expiration time, in seconds: either a relative
|
||||||
|
// time from now (up to 1 month), or an absolute Unix epoch time.
|
||||||
|
// Zero means the Item has no expiration time.
|
||||||
|
Expiration int32 |
||||||
|
|
||||||
|
// Compare and swap ID.
|
||||||
|
cas uint64 |
||||||
|
} |
||||||
|
|
||||||
|
// Conn represents a connection to a Memcache server.
|
||||||
|
// Command Reference: https://github.com/memcached/memcached/wiki/Commands
|
||||||
|
type Conn interface { |
||||||
|
// Close closes the connection.
|
||||||
|
Close() error |
||||||
|
|
||||||
|
// Err returns a non-nil value if the connection is broken. The returned
|
||||||
|
// value is either the first non-nil value returned from the underlying
|
||||||
|
// network connection or a protocol parsing error. Applications should
|
||||||
|
// close broken connections.
|
||||||
|
Err() error |
||||||
|
|
||||||
|
// Add writes the given item, if no value already exists for its key.
|
||||||
|
// ErrNotStored is returned if that condition is not met.
|
||||||
|
Add(item *Item) error |
||||||
|
|
||||||
|
// Set writes the given item, unconditionally.
|
||||||
|
Set(item *Item) error |
||||||
|
|
||||||
|
// Replace writes the given item, but only if the server *does* already
|
||||||
|
// hold data for this key.
|
||||||
|
Replace(item *Item) error |
||||||
|
|
||||||
|
// Get sends a command to the server for gets data.
|
||||||
|
Get(key string) (*Item, error) |
||||||
|
|
||||||
|
// GetMulti is a batch version of Get. The returned map from keys to items
|
||||||
|
// may have fewer elements than the input slice, due to memcache cache
|
||||||
|
// misses. Each key must be at most 250 bytes in length.
|
||||||
|
// If no error is returned, the returned map will also be non-nil.
|
||||||
|
GetMulti(keys []string) (map[string]*Item, error) |
||||||
|
|
||||||
|
// Delete deletes the item with the provided key.
|
||||||
|
// The error ErrCacheMiss is returned if the item didn't already exist in
|
||||||
|
// the cache.
|
||||||
|
Delete(key string) error |
||||||
|
|
||||||
|
// Increment atomically increments key by delta. The return value is the
|
||||||
|
// new value after being incremented or an error. If the value didn't exist
|
||||||
|
// in memcached the error is ErrCacheMiss. The value in memcached must be
|
||||||
|
// an decimal number, or an error will be returned.
|
||||||
|
// On 64-bit overflow, the new value wraps around.
|
||||||
|
Increment(key string, delta uint64) (newValue uint64, err error) |
||||||
|
|
||||||
|
// Decrement atomically decrements key by delta. The return value is the
|
||||||
|
// new value after being decremented or an error. If the value didn't exist
|
||||||
|
// in memcached the error is ErrCacheMiss. The value in memcached must be
|
||||||
|
// an decimal number, or an error will be returned. On underflow, the new
|
||||||
|
// value is capped at zero and does not wrap around.
|
||||||
|
Decrement(key string, delta uint64) (newValue uint64, err error) |
||||||
|
|
||||||
|
// CompareAndSwap writes the given item that was previously returned by
|
||||||
|
// Get, if the value was neither modified or evicted between the Get and
|
||||||
|
// the CompareAndSwap calls. The item's Key should not change between calls
|
||||||
|
// but all other item fields may differ. ErrCASConflict is returned if the
|
||||||
|
// value was modified in between the calls.
|
||||||
|
// ErrNotStored is returned if the value was evicted in between the calls.
|
||||||
|
CompareAndSwap(item *Item) error |
||||||
|
|
||||||
|
// Touch updates the expiry for the given key. The seconds parameter is
|
||||||
|
// either a Unix timestamp or, if seconds is less than 1 month, the number
|
||||||
|
// of seconds into the future at which time the item will expire.
|
||||||
|
//ErrCacheMiss is returned if the key is not in the cache. The key must be
|
||||||
|
// at most 250 bytes in length.
|
||||||
|
Touch(key string, seconds int32) (err error) |
||||||
|
|
||||||
|
// Scan converts value read from the memcache into the following
|
||||||
|
// common Go types and special types:
|
||||||
|
//
|
||||||
|
// *string
|
||||||
|
// *[]byte
|
||||||
|
// *interface{}
|
||||||
|
//
|
||||||
|
Scan(item *Item, v interface{}) (err error) |
||||||
|
|
||||||
|
// WithContext return a Conn with its context changed to ctx
|
||||||
|
// the context controls the entire lifetime of Conn before you change it
|
||||||
|
// NOTE: this method is not thread-safe
|
||||||
|
WithContext(ctx context.Context) Conn |
||||||
|
} |
@ -0,0 +1,59 @@ |
|||||||
|
package memcache |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
) |
||||||
|
|
||||||
|
// MockErr for unit test.
|
||||||
|
type MockErr struct { |
||||||
|
Error error |
||||||
|
} |
||||||
|
|
||||||
|
var _ Conn = MockErr{} |
||||||
|
|
||||||
|
// MockWith return a mock conn.
|
||||||
|
func MockWith(err error) MockErr { |
||||||
|
return MockErr{Error: err} |
||||||
|
} |
||||||
|
|
||||||
|
// Err .
|
||||||
|
func (m MockErr) Err() error { return m.Error } |
||||||
|
|
||||||
|
// Close .
|
||||||
|
func (m MockErr) Close() error { return m.Error } |
||||||
|
|
||||||
|
// Add .
|
||||||
|
func (m MockErr) Add(item *Item) error { return m.Error } |
||||||
|
|
||||||
|
// Set .
|
||||||
|
func (m MockErr) Set(item *Item) error { return m.Error } |
||||||
|
|
||||||
|
// Replace .
|
||||||
|
func (m MockErr) Replace(item *Item) error { return m.Error } |
||||||
|
|
||||||
|
// CompareAndSwap .
|
||||||
|
func (m MockErr) CompareAndSwap(item *Item) error { return m.Error } |
||||||
|
|
||||||
|
// Get .
|
||||||
|
func (m MockErr) Get(key string) (*Item, error) { return nil, m.Error } |
||||||
|
|
||||||
|
// GetMulti .
|
||||||
|
func (m MockErr) GetMulti(keys []string) (map[string]*Item, error) { return nil, m.Error } |
||||||
|
|
||||||
|
// Touch .
|
||||||
|
func (m MockErr) Touch(key string, timeout int32) error { return m.Error } |
||||||
|
|
||||||
|
// Delete .
|
||||||
|
func (m MockErr) Delete(key string) error { return m.Error } |
||||||
|
|
||||||
|
// Increment .
|
||||||
|
func (m MockErr) Increment(key string, delta uint64) (uint64, error) { return 0, m.Error } |
||||||
|
|
||||||
|
// Decrement .
|
||||||
|
func (m MockErr) Decrement(key string, delta uint64) (uint64, error) { return 0, m.Error } |
||||||
|
|
||||||
|
// Scan .
|
||||||
|
func (m MockErr) Scan(item *Item, v interface{}) error { return m.Error } |
||||||
|
|
||||||
|
// WithContext .
|
||||||
|
func (m MockErr) WithContext(ctx context.Context) Conn { return m } |
@ -0,0 +1,197 @@ |
|||||||
|
package memcache |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"io" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/container/pool" |
||||||
|
"github.com/bilibili/kratos/pkg/stat" |
||||||
|
xtime "github.com/bilibili/kratos/pkg/time" |
||||||
|
) |
||||||
|
|
||||||
|
var stats = stat.Cache |
||||||
|
|
||||||
|
// Config memcache config.
|
||||||
|
type Config struct { |
||||||
|
*pool.Config |
||||||
|
|
||||||
|
Name string // memcache name, for trace
|
||||||
|
Proto string |
||||||
|
Addr string |
||||||
|
DialTimeout xtime.Duration |
||||||
|
ReadTimeout xtime.Duration |
||||||
|
WriteTimeout xtime.Duration |
||||||
|
} |
||||||
|
|
||||||
|
// Pool memcache connection pool struct.
|
||||||
|
type Pool struct { |
||||||
|
p pool.Pool |
||||||
|
c *Config |
||||||
|
} |
||||||
|
|
||||||
|
// NewPool new a memcache conn pool.
|
||||||
|
func NewPool(c *Config) (p *Pool) { |
||||||
|
if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 { |
||||||
|
panic("must config memcache timeout") |
||||||
|
} |
||||||
|
p1 := pool.NewList(c.Config) |
||||||
|
cnop := DialConnectTimeout(time.Duration(c.DialTimeout)) |
||||||
|
rdop := DialReadTimeout(time.Duration(c.ReadTimeout)) |
||||||
|
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout)) |
||||||
|
p1.New = func(ctx context.Context) (io.Closer, error) { |
||||||
|
conn, err := Dial(c.Proto, c.Addr, cnop, rdop, wrop) |
||||||
|
return &traceConn{Conn: conn, address: c.Addr}, err |
||||||
|
} |
||||||
|
p = &Pool{p: p1, c: c} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Get gets a connection. The application must close the returned connection.
|
||||||
|
// This method always returns a valid connection so that applications can defer
|
||||||
|
// error handling to the first use of the connection. If there is an error
|
||||||
|
// getting an underlying connection, then the connection Err, Do, Send, Flush
|
||||||
|
// and Receive methods return that error.
|
||||||
|
func (p *Pool) Get(ctx context.Context) Conn { |
||||||
|
c, err := p.p.Get(ctx) |
||||||
|
if err != nil { |
||||||
|
return errorConnection{err} |
||||||
|
} |
||||||
|
c1, _ := c.(Conn) |
||||||
|
return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx} |
||||||
|
} |
||||||
|
|
||||||
|
// Close release the resources used by the pool.
|
||||||
|
func (p *Pool) Close() error { |
||||||
|
return p.p.Close() |
||||||
|
} |
||||||
|
|
||||||
|
type pooledConnection struct { |
||||||
|
p *Pool |
||||||
|
c Conn |
||||||
|
ctx context.Context |
||||||
|
} |
||||||
|
|
||||||
|
func pstat(key string, t time.Time, err error) { |
||||||
|
stats.Timing(key, int64(time.Since(t)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
if msg := formatErr(err); msg != "" { |
||||||
|
stats.Incr("memcache", msg) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Close() error { |
||||||
|
c := pc.c |
||||||
|
if _, ok := c.(errorConnection); ok { |
||||||
|
return nil |
||||||
|
} |
||||||
|
pc.c = errorConnection{ErrConnClosed} |
||||||
|
pc.p.p.Put(context.Background(), c, c.Err() != nil) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Err() error { |
||||||
|
return pc.c.Err() |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Set(item *Item) (err error) { |
||||||
|
now := time.Now() |
||||||
|
err = pc.c.Set(item) |
||||||
|
pstat("memcache:set", now, err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Add(item *Item) (err error) { |
||||||
|
now := time.Now() |
||||||
|
err = pc.c.Add(item) |
||||||
|
pstat("memcache:add", now, err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Replace(item *Item) (err error) { |
||||||
|
now := time.Now() |
||||||
|
err = pc.c.Replace(item) |
||||||
|
pstat("memcache:replace", now, err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) CompareAndSwap(item *Item) (err error) { |
||||||
|
now := time.Now() |
||||||
|
err = pc.c.CompareAndSwap(item) |
||||||
|
pstat("memcache:cas", now, err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Get(key string) (r *Item, err error) { |
||||||
|
now := time.Now() |
||||||
|
r, err = pc.c.Get(key) |
||||||
|
pstat("memcache:get", now, err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) GetMulti(keys []string) (res map[string]*Item, err error) { |
||||||
|
// if keys is empty slice returns empty map direct
|
||||||
|
if len(keys) == 0 { |
||||||
|
return make(map[string]*Item), nil |
||||||
|
} |
||||||
|
now := time.Now() |
||||||
|
res, err = pc.c.GetMulti(keys) |
||||||
|
pstat("memcache:gets", now, err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Touch(key string, timeout int32) (err error) { |
||||||
|
now := time.Now() |
||||||
|
err = pc.c.Touch(key, timeout) |
||||||
|
pstat("memcache:touch", now, err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Scan(item *Item, v interface{}) error { |
||||||
|
return pc.c.Scan(item, v) |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) WithContext(ctx context.Context) Conn { |
||||||
|
// TODO: set context
|
||||||
|
pc.ctx = ctx |
||||||
|
return pc |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Delete(key string) (err error) { |
||||||
|
now := time.Now() |
||||||
|
err = pc.c.Delete(key) |
||||||
|
pstat("memcache:delete", now, err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Increment(key string, delta uint64) (newValue uint64, err error) { |
||||||
|
now := time.Now() |
||||||
|
newValue, err = pc.c.Increment(key, delta) |
||||||
|
pstat("memcache:increment", now, err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Decrement(key string, delta uint64) (newValue uint64, err error) { |
||||||
|
now := time.Now() |
||||||
|
newValue, err = pc.c.Decrement(key, delta) |
||||||
|
pstat("memcache:decrement", now, err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
type errorConnection struct{ err error } |
||||||
|
|
||||||
|
func (ec errorConnection) Err() error { return ec.err } |
||||||
|
func (ec errorConnection) Close() error { return ec.err } |
||||||
|
func (ec errorConnection) Add(item *Item) error { return ec.err } |
||||||
|
func (ec errorConnection) Set(item *Item) error { return ec.err } |
||||||
|
func (ec errorConnection) Replace(item *Item) error { return ec.err } |
||||||
|
func (ec errorConnection) CompareAndSwap(item *Item) error { return ec.err } |
||||||
|
func (ec errorConnection) Get(key string) (*Item, error) { return nil, ec.err } |
||||||
|
func (ec errorConnection) GetMulti(keys []string) (map[string]*Item, error) { return nil, ec.err } |
||||||
|
func (ec errorConnection) Touch(key string, timeout int32) error { return ec.err } |
||||||
|
func (ec errorConnection) Delete(key string) error { return ec.err } |
||||||
|
func (ec errorConnection) Increment(key string, delta uint64) (uint64, error) { return 0, ec.err } |
||||||
|
func (ec errorConnection) Decrement(key string, delta uint64) (uint64, error) { return 0, ec.err } |
||||||
|
func (ec errorConnection) Scan(item *Item, v interface{}) error { return ec.err } |
||||||
|
func (ec errorConnection) WithContext(ctx context.Context) Conn { return ec } |
@ -0,0 +1,109 @@ |
|||||||
|
package memcache |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
"github.com/bilibili/kratos/pkg/net/trace" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
_traceFamily = "memcache" |
||||||
|
_traceSpanKind = "client" |
||||||
|
_traceComponentName = "library/cache/memcache" |
||||||
|
_tracePeerService = "memcache" |
||||||
|
_slowLogDuration = time.Millisecond * 250 |
||||||
|
) |
||||||
|
|
||||||
|
type traceConn struct { |
||||||
|
Conn |
||||||
|
ctx context.Context |
||||||
|
address string |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) setTrace(action, statement string) func(error) error { |
||||||
|
now := time.Now() |
||||||
|
parent, ok := trace.FromContext(t.ctx) |
||||||
|
if !ok { |
||||||
|
return func(err error) error { return err } |
||||||
|
} |
||||||
|
span := parent.Fork(_traceFamily, "Memcache:"+action) |
||||||
|
span.SetTag( |
||||||
|
trace.String(trace.TagSpanKind, _traceSpanKind), |
||||||
|
trace.String(trace.TagComponent, _traceComponentName), |
||||||
|
trace.String(trace.TagPeerService, _tracePeerService), |
||||||
|
trace.String(trace.TagPeerAddress, t.address), |
||||||
|
trace.String(trace.TagDBStatement, action+" "+statement), |
||||||
|
) |
||||||
|
return func(err error) error { |
||||||
|
span.Finish(&err) |
||||||
|
t := time.Since(now) |
||||||
|
if t > _slowLogDuration { |
||||||
|
log.Warn("%s slow log action: %s key: %s time: %v", _traceFamily, action, statement, t) |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) WithContext(ctx context.Context) Conn { |
||||||
|
t.ctx = ctx |
||||||
|
t.Conn = t.Conn.WithContext(ctx) |
||||||
|
return t |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) Add(item *Item) error { |
||||||
|
finishFn := t.setTrace("Add", item.Key) |
||||||
|
return finishFn(t.Conn.Add(item)) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) Set(item *Item) error { |
||||||
|
finishFn := t.setTrace("Set", item.Key) |
||||||
|
return finishFn(t.Conn.Set(item)) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) Replace(item *Item) error { |
||||||
|
finishFn := t.setTrace("Replace", item.Key) |
||||||
|
return finishFn(t.Conn.Replace(item)) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) Get(key string) (*Item, error) { |
||||||
|
finishFn := t.setTrace("Get", key) |
||||||
|
item, err := t.Conn.Get(key) |
||||||
|
return item, finishFn(err) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) GetMulti(keys []string) (map[string]*Item, error) { |
||||||
|
finishFn := t.setTrace("GetMulti", strings.Join(keys, " ")) |
||||||
|
items, err := t.Conn.GetMulti(keys) |
||||||
|
return items, finishFn(err) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) Delete(key string) error { |
||||||
|
finishFn := t.setTrace("Delete", key) |
||||||
|
return finishFn(t.Conn.Delete(key)) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) Increment(key string, delta uint64) (newValue uint64, err error) { |
||||||
|
finishFn := t.setTrace("Increment", key+" "+strconv.FormatUint(delta, 10)) |
||||||
|
newValue, err = t.Conn.Increment(key, delta) |
||||||
|
return newValue, finishFn(err) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) Decrement(key string, delta uint64) (newValue uint64, err error) { |
||||||
|
finishFn := t.setTrace("Decrement", key+" "+strconv.FormatUint(delta, 10)) |
||||||
|
newValue, err = t.Conn.Decrement(key, delta) |
||||||
|
return newValue, finishFn(err) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) CompareAndSwap(item *Item) error { |
||||||
|
finishFn := t.setTrace("CompareAndSwap", item.Key) |
||||||
|
return finishFn(t.Conn.CompareAndSwap(item)) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) Touch(key string, seconds int32) (err error) { |
||||||
|
finishFn := t.setTrace("Touch", key+" "+strconv.Itoa(int(seconds))) |
||||||
|
return finishFn(t.Conn.Touch(key, seconds)) |
||||||
|
} |
@ -0,0 +1,32 @@ |
|||||||
|
package memcache |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/gogo/protobuf/proto" |
||||||
|
) |
||||||
|
|
||||||
|
// RawItem item with FlagRAW flag.
|
||||||
|
//
|
||||||
|
// Expiration is the cache expiration time, in seconds: either a relative
|
||||||
|
// time from now (up to 1 month), or an absolute Unix epoch time.
|
||||||
|
// Zero means the Item has no expiration time.
|
||||||
|
func RawItem(key string, data []byte, flags uint32, expiration int32) *Item { |
||||||
|
return &Item{Key: key, Flags: flags | FlagRAW, Value: data, Expiration: expiration} |
||||||
|
} |
||||||
|
|
||||||
|
// JSONItem item with FlagJSON flag.
|
||||||
|
//
|
||||||
|
// Expiration is the cache expiration time, in seconds: either a relative
|
||||||
|
// time from now (up to 1 month), or an absolute Unix epoch time.
|
||||||
|
// Zero means the Item has no expiration time.
|
||||||
|
func JSONItem(key string, v interface{}, flags uint32, expiration int32) *Item { |
||||||
|
return &Item{Key: key, Flags: flags | FlagJSON, Object: v, Expiration: expiration} |
||||||
|
} |
||||||
|
|
||||||
|
// ProtobufItem item with FlagProtobuf flag.
|
||||||
|
//
|
||||||
|
// Expiration is the cache expiration time, in seconds: either a relative
|
||||||
|
// time from now (up to 1 month), or an absolute Unix epoch time.
|
||||||
|
// Zero means the Item has no expiration time.
|
||||||
|
func ProtobufItem(key string, message proto.Message, flags uint32, expiration int32) *Item { |
||||||
|
return &Item{Key: key, Flags: flags | FlagProtobuf, Object: message, Expiration: expiration} |
||||||
|
} |
@ -0,0 +1,7 @@ |
|||||||
|
# cache/redis |
||||||
|
|
||||||
|
##### 项目简介 |
||||||
|
1. 提供redis接口 |
||||||
|
|
||||||
|
#### 使用方式 |
||||||
|
请参考doc.go |
@ -0,0 +1,57 @@ |
|||||||
|
// Copyright 2014 Gary Burd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations
|
||||||
|
// under the License.
|
||||||
|
|
||||||
|
package redis |
||||||
|
|
||||||
|
import ( |
||||||
|
"strings" |
||||||
|
) |
||||||
|
|
||||||
|
// redis state
|
||||||
|
const ( |
||||||
|
WatchState = 1 << iota |
||||||
|
MultiState |
||||||
|
SubscribeState |
||||||
|
MonitorState |
||||||
|
) |
||||||
|
|
||||||
|
// CommandInfo command info.
|
||||||
|
type CommandInfo struct { |
||||||
|
Set, Clear int |
||||||
|
} |
||||||
|
|
||||||
|
var commandInfos = map[string]CommandInfo{ |
||||||
|
"WATCH": {Set: WatchState}, |
||||||
|
"UNWATCH": {Clear: WatchState}, |
||||||
|
"MULTI": {Set: MultiState}, |
||||||
|
"EXEC": {Clear: WatchState | MultiState}, |
||||||
|
"DISCARD": {Clear: WatchState | MultiState}, |
||||||
|
"PSUBSCRIBE": {Set: SubscribeState}, |
||||||
|
"SUBSCRIBE": {Set: SubscribeState}, |
||||||
|
"MONITOR": {Set: MonitorState}, |
||||||
|
} |
||||||
|
|
||||||
|
func init() { |
||||||
|
for n, ci := range commandInfos { |
||||||
|
commandInfos[strings.ToLower(n)] = ci |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// LookupCommandInfo get command info.
|
||||||
|
func LookupCommandInfo(commandName string) CommandInfo { |
||||||
|
if ci, ok := commandInfos[commandName]; ok { |
||||||
|
return ci |
||||||
|
} |
||||||
|
return commandInfos[strings.ToUpper(commandName)] |
||||||
|
} |
@ -0,0 +1,597 @@ |
|||||||
|
// Copyright 2012 Gary Burd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations
|
||||||
|
// under the License.
|
||||||
|
|
||||||
|
package redis |
||||||
|
|
||||||
|
import ( |
||||||
|
"bufio" |
||||||
|
"bytes" |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"net" |
||||||
|
"net/url" |
||||||
|
"regexp" |
||||||
|
"strconv" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/stat" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var stats = stat.Cache |
||||||
|
|
||||||
|
// conn is the low-level implementation of Conn
|
||||||
|
type conn struct { |
||||||
|
|
||||||
|
// Shared
|
||||||
|
mu sync.Mutex |
||||||
|
pending int |
||||||
|
err error |
||||||
|
conn net.Conn |
||||||
|
|
||||||
|
// Read
|
||||||
|
readTimeout time.Duration |
||||||
|
br *bufio.Reader |
||||||
|
|
||||||
|
// Write
|
||||||
|
writeTimeout time.Duration |
||||||
|
bw *bufio.Writer |
||||||
|
|
||||||
|
// Scratch space for formatting argument length.
|
||||||
|
// '*' or '$', length, "\r\n"
|
||||||
|
lenScratch [32]byte |
||||||
|
|
||||||
|
// Scratch space for formatting integers and floats.
|
||||||
|
numScratch [40]byte |
||||||
|
// stat func,default prom
|
||||||
|
stat func(string, *error) func() |
||||||
|
} |
||||||
|
|
||||||
|
func statfunc(cmd string, err *error) func() { |
||||||
|
now := time.Now() |
||||||
|
return func() { |
||||||
|
stats.Timing(fmt.Sprintf("redis:%s", cmd), int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
if msg := formatErr(*err); msg != "" { |
||||||
|
stats.Incr("redis", msg) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// DialTimeout acts like Dial but takes timeouts for establishing the
|
||||||
|
// connection to the server, writing a command and reading a reply.
|
||||||
|
//
|
||||||
|
// Deprecated: Use Dial with options instead.
|
||||||
|
func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) { |
||||||
|
return Dial(network, address, |
||||||
|
DialConnectTimeout(connectTimeout), |
||||||
|
DialReadTimeout(readTimeout), |
||||||
|
DialWriteTimeout(writeTimeout)) |
||||||
|
} |
||||||
|
|
||||||
|
// DialOption specifies an option for dialing a Redis server.
|
||||||
|
type DialOption struct { |
||||||
|
f func(*dialOptions) |
||||||
|
} |
||||||
|
|
||||||
|
type dialOptions struct { |
||||||
|
readTimeout time.Duration |
||||||
|
writeTimeout time.Duration |
||||||
|
dial func(network, addr string) (net.Conn, error) |
||||||
|
db int |
||||||
|
password string |
||||||
|
stat func(string, *error) func() |
||||||
|
} |
||||||
|
|
||||||
|
// DialStats specifies stat func for stats.default statfunc.
|
||||||
|
func DialStats(fn func(string, *error) func()) DialOption { |
||||||
|
return DialOption{func(do *dialOptions) { |
||||||
|
do.stat = fn |
||||||
|
}} |
||||||
|
} |
||||||
|
|
||||||
|
// DialReadTimeout specifies the timeout for reading a single command reply.
|
||||||
|
func DialReadTimeout(d time.Duration) DialOption { |
||||||
|
return DialOption{func(do *dialOptions) { |
||||||
|
do.readTimeout = d |
||||||
|
}} |
||||||
|
} |
||||||
|
|
||||||
|
// DialWriteTimeout specifies the timeout for writing a single command.
|
||||||
|
func DialWriteTimeout(d time.Duration) DialOption { |
||||||
|
return DialOption{func(do *dialOptions) { |
||||||
|
do.writeTimeout = d |
||||||
|
}} |
||||||
|
} |
||||||
|
|
||||||
|
// DialConnectTimeout specifies the timeout for connecting to the Redis server.
|
||||||
|
func DialConnectTimeout(d time.Duration) DialOption { |
||||||
|
return DialOption{func(do *dialOptions) { |
||||||
|
dialer := net.Dialer{Timeout: d} |
||||||
|
do.dial = dialer.Dial |
||||||
|
}} |
||||||
|
} |
||||||
|
|
||||||
|
// DialNetDial specifies a custom dial function for creating TCP
|
||||||
|
// connections. If this option is left out, then net.Dial is
|
||||||
|
// used. DialNetDial overrides DialConnectTimeout.
|
||||||
|
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption { |
||||||
|
return DialOption{func(do *dialOptions) { |
||||||
|
do.dial = dial |
||||||
|
}} |
||||||
|
} |
||||||
|
|
||||||
|
// DialDatabase specifies the database to select when dialing a connection.
|
||||||
|
func DialDatabase(db int) DialOption { |
||||||
|
return DialOption{func(do *dialOptions) { |
||||||
|
do.db = db |
||||||
|
}} |
||||||
|
} |
||||||
|
|
||||||
|
// DialPassword specifies the password to use when connecting to
|
||||||
|
// the Redis server.
|
||||||
|
func DialPassword(password string) DialOption { |
||||||
|
return DialOption{func(do *dialOptions) { |
||||||
|
do.password = password |
||||||
|
}} |
||||||
|
} |
||||||
|
|
||||||
|
// Dial connects to the Redis server at the given network and
|
||||||
|
// address using the specified options.
|
||||||
|
func Dial(network, address string, options ...DialOption) (Conn, error) { |
||||||
|
do := dialOptions{ |
||||||
|
dial: net.Dial, |
||||||
|
} |
||||||
|
for _, option := range options { |
||||||
|
option.f(&do) |
||||||
|
} |
||||||
|
|
||||||
|
netConn, err := do.dial(network, address) |
||||||
|
if err != nil { |
||||||
|
return nil, errors.WithStack(err) |
||||||
|
} |
||||||
|
c := &conn{ |
||||||
|
conn: netConn, |
||||||
|
bw: bufio.NewWriter(netConn), |
||||||
|
br: bufio.NewReader(netConn), |
||||||
|
readTimeout: do.readTimeout, |
||||||
|
writeTimeout: do.writeTimeout, |
||||||
|
stat: statfunc, |
||||||
|
} |
||||||
|
|
||||||
|
if do.password != "" { |
||||||
|
if _, err := c.Do("AUTH", do.password); err != nil { |
||||||
|
netConn.Close() |
||||||
|
return nil, errors.WithStack(err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if do.db != 0 { |
||||||
|
if _, err := c.Do("SELECT", do.db); err != nil { |
||||||
|
netConn.Close() |
||||||
|
return nil, errors.WithStack(err) |
||||||
|
} |
||||||
|
} |
||||||
|
if do.stat != nil { |
||||||
|
c.stat = do.stat |
||||||
|
} |
||||||
|
return c, nil |
||||||
|
} |
||||||
|
|
||||||
|
var pathDBRegexp = regexp.MustCompile(`/(\d+)\z`) |
||||||
|
|
||||||
|
// DialURL connects to a Redis server at the given URL using the Redis
|
||||||
|
// URI scheme. URLs should follow the draft IANA specification for the
|
||||||
|
// scheme (https://www.iana.org/assignments/uri-schemes/prov/redis).
|
||||||
|
func DialURL(rawurl string, options ...DialOption) (Conn, error) { |
||||||
|
u, err := url.Parse(rawurl) |
||||||
|
if err != nil { |
||||||
|
return nil, errors.WithStack(err) |
||||||
|
} |
||||||
|
|
||||||
|
if u.Scheme != "redis" { |
||||||
|
return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme) |
||||||
|
} |
||||||
|
|
||||||
|
// As per the IANA draft spec, the host defaults to localhost and
|
||||||
|
// the port defaults to 6379.
|
||||||
|
host, port, err := net.SplitHostPort(u.Host) |
||||||
|
if err != nil { |
||||||
|
// assume port is missing
|
||||||
|
host = u.Host |
||||||
|
port = "6379" |
||||||
|
} |
||||||
|
if host == "" { |
||||||
|
host = "localhost" |
||||||
|
} |
||||||
|
address := net.JoinHostPort(host, port) |
||||||
|
|
||||||
|
if u.User != nil { |
||||||
|
password, isSet := u.User.Password() |
||||||
|
if isSet { |
||||||
|
options = append(options, DialPassword(password)) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
match := pathDBRegexp.FindStringSubmatch(u.Path) |
||||||
|
if len(match) == 2 { |
||||||
|
db, err := strconv.Atoi(match[1]) |
||||||
|
if err != nil { |
||||||
|
return nil, errors.Errorf("invalid database: %s", u.Path[1:]) |
||||||
|
} |
||||||
|
if db != 0 { |
||||||
|
options = append(options, DialDatabase(db)) |
||||||
|
} |
||||||
|
} else if u.Path != "" { |
||||||
|
return nil, errors.Errorf("invalid database: %s", u.Path[1:]) |
||||||
|
} |
||||||
|
|
||||||
|
return Dial("tcp", address, options...) |
||||||
|
} |
||||||
|
|
||||||
|
// NewConn new a redis conn.
|
||||||
|
func NewConn(c *Config) (cn Conn, err error) { |
||||||
|
cnop := DialConnectTimeout(time.Duration(c.DialTimeout)) |
||||||
|
rdop := DialReadTimeout(time.Duration(c.ReadTimeout)) |
||||||
|
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout)) |
||||||
|
auop := DialPassword(c.Auth) |
||||||
|
// new conn
|
||||||
|
cn, err = Dial(c.Proto, c.Addr, cnop, rdop, wrop, auop) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Close() error { |
||||||
|
c.mu.Lock() |
||||||
|
err := c.err |
||||||
|
if c.err == nil { |
||||||
|
c.err = errors.New("redigo: closed") |
||||||
|
err = c.conn.Close() |
||||||
|
} |
||||||
|
c.mu.Unlock() |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) fatal(err error) error { |
||||||
|
c.mu.Lock() |
||||||
|
if c.err == nil { |
||||||
|
c.err = err |
||||||
|
// Close connection to force errors on subsequent calls and to unblock
|
||||||
|
// other reader or writer.
|
||||||
|
c.conn.Close() |
||||||
|
} |
||||||
|
c.mu.Unlock() |
||||||
|
return errors.WithStack(c.err) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Err() error { |
||||||
|
c.mu.Lock() |
||||||
|
err := c.err |
||||||
|
c.mu.Unlock() |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) writeLen(prefix byte, n int) error { |
||||||
|
c.lenScratch[len(c.lenScratch)-1] = '\n' |
||||||
|
c.lenScratch[len(c.lenScratch)-2] = '\r' |
||||||
|
i := len(c.lenScratch) - 3 |
||||||
|
for { |
||||||
|
c.lenScratch[i] = byte('0' + n%10) |
||||||
|
i-- |
||||||
|
n = n / 10 |
||||||
|
if n == 0 { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
c.lenScratch[i] = prefix |
||||||
|
_, err := c.bw.Write(c.lenScratch[i:]) |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) writeString(s string) error { |
||||||
|
c.writeLen('$', len(s)) |
||||||
|
c.bw.WriteString(s) |
||||||
|
_, err := c.bw.WriteString("\r\n") |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) writeBytes(p []byte) error { |
||||||
|
c.writeLen('$', len(p)) |
||||||
|
c.bw.Write(p) |
||||||
|
_, err := c.bw.WriteString("\r\n") |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) writeInt64(n int64) error { |
||||||
|
return errors.WithStack(c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) writeFloat64(n float64) error { |
||||||
|
return errors.WithStack(c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) writeCommand(cmd string, args []interface{}) (err error) { |
||||||
|
if c.writeTimeout != 0 { |
||||||
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||||
|
} |
||||||
|
c.writeLen('*', 1+len(args)) |
||||||
|
err = c.writeString(cmd) |
||||||
|
for _, arg := range args { |
||||||
|
if err != nil { |
||||||
|
break |
||||||
|
} |
||||||
|
switch arg := arg.(type) { |
||||||
|
case string: |
||||||
|
err = c.writeString(arg) |
||||||
|
case []byte: |
||||||
|
err = c.writeBytes(arg) |
||||||
|
case int: |
||||||
|
err = c.writeInt64(int64(arg)) |
||||||
|
case int64: |
||||||
|
err = c.writeInt64(arg) |
||||||
|
case float64: |
||||||
|
err = c.writeFloat64(arg) |
||||||
|
case bool: |
||||||
|
if arg { |
||||||
|
err = c.writeString("1") |
||||||
|
} else { |
||||||
|
err = c.writeString("0") |
||||||
|
} |
||||||
|
case nil: |
||||||
|
err = c.writeString("") |
||||||
|
default: |
||||||
|
var buf bytes.Buffer |
||||||
|
fmt.Fprint(&buf, arg) |
||||||
|
err = errors.WithStack(c.writeBytes(buf.Bytes())) |
||||||
|
} |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
type protocolError string |
||||||
|
|
||||||
|
func (pe protocolError) Error() string { |
||||||
|
return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", string(pe)) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) readLine() ([]byte, error) { |
||||||
|
p, err := c.br.ReadSlice('\n') |
||||||
|
if err == bufio.ErrBufferFull { |
||||||
|
return nil, errors.WithStack(protocolError("long response line")) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
i := len(p) - 2 |
||||||
|
if i < 0 || p[i] != '\r' { |
||||||
|
return nil, errors.WithStack(protocolError("bad response line terminator")) |
||||||
|
} |
||||||
|
return p[:i], nil |
||||||
|
} |
||||||
|
|
||||||
|
// parseLen parses bulk string and array lengths.
|
||||||
|
func parseLen(p []byte) (int, error) { |
||||||
|
if len(p) == 0 { |
||||||
|
return -1, errors.WithStack(protocolError("malformed length")) |
||||||
|
} |
||||||
|
|
||||||
|
if p[0] == '-' && len(p) == 2 && p[1] == '1' { |
||||||
|
// handle $-1 and $-1 null replies.
|
||||||
|
return -1, nil |
||||||
|
} |
||||||
|
|
||||||
|
var n int |
||||||
|
for _, b := range p { |
||||||
|
n *= 10 |
||||||
|
if b < '0' || b > '9' { |
||||||
|
return -1, errors.WithStack(protocolError("illegal bytes in length")) |
||||||
|
} |
||||||
|
n += int(b - '0') |
||||||
|
} |
||||||
|
|
||||||
|
return n, nil |
||||||
|
} |
||||||
|
|
||||||
|
// parseInt parses an integer reply.
|
||||||
|
func parseInt(p []byte) (interface{}, error) { |
||||||
|
if len(p) == 0 { |
||||||
|
return 0, errors.WithStack(protocolError("malformed integer")) |
||||||
|
} |
||||||
|
|
||||||
|
var negate bool |
||||||
|
if p[0] == '-' { |
||||||
|
negate = true |
||||||
|
p = p[1:] |
||||||
|
if len(p) == 0 { |
||||||
|
return 0, errors.WithStack(protocolError("malformed integer")) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
var n int64 |
||||||
|
for _, b := range p { |
||||||
|
n *= 10 |
||||||
|
if b < '0' || b > '9' { |
||||||
|
return 0, errors.WithStack(protocolError("illegal bytes in length")) |
||||||
|
} |
||||||
|
n += int64(b - '0') |
||||||
|
} |
||||||
|
|
||||||
|
if negate { |
||||||
|
n = -n |
||||||
|
} |
||||||
|
return n, nil |
||||||
|
} |
||||||
|
|
||||||
|
var ( |
||||||
|
okReply interface{} = "OK" |
||||||
|
pongReply interface{} = "PONG" |
||||||
|
) |
||||||
|
|
||||||
|
func (c *conn) readReply() (interface{}, error) { |
||||||
|
line, err := c.readLine() |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if len(line) == 0 { |
||||||
|
return nil, errors.WithStack(protocolError("short response line")) |
||||||
|
} |
||||||
|
switch line[0] { |
||||||
|
case '+': |
||||||
|
switch { |
||||||
|
case len(line) == 3 && line[1] == 'O' && line[2] == 'K': |
||||||
|
// Avoid allocation for frequent "+OK" response.
|
||||||
|
return okReply, nil |
||||||
|
case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G': |
||||||
|
// Avoid allocation in PING command benchmarks :)
|
||||||
|
return pongReply, nil |
||||||
|
default: |
||||||
|
return string(line[1:]), nil |
||||||
|
} |
||||||
|
case '-': |
||||||
|
return Error(string(line[1:])), nil |
||||||
|
case ':': |
||||||
|
return parseInt(line[1:]) |
||||||
|
case '$': |
||||||
|
n, err := parseLen(line[1:]) |
||||||
|
if n < 0 || err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
p := make([]byte, n) |
||||||
|
_, err = io.ReadFull(c.br, p) |
||||||
|
if err != nil { |
||||||
|
return nil, errors.WithStack(err) |
||||||
|
} |
||||||
|
if line1, err := c.readLine(); err != nil { |
||||||
|
return nil, err |
||||||
|
} else if len(line1) != 0 { |
||||||
|
return nil, errors.WithStack(protocolError("bad bulk string format")) |
||||||
|
} |
||||||
|
return p, nil |
||||||
|
case '*': |
||||||
|
n, err := parseLen(line[1:]) |
||||||
|
if n < 0 || err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
r := make([]interface{}, n) |
||||||
|
for i := range r { |
||||||
|
r[i], err = c.readReply() |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
} |
||||||
|
return r, nil |
||||||
|
} |
||||||
|
return nil, errors.WithStack(protocolError("unexpected response line")) |
||||||
|
} |
||||||
|
func (c *conn) Send(cmd string, args ...interface{}) (err error) { |
||||||
|
c.mu.Lock() |
||||||
|
c.pending++ |
||||||
|
c.mu.Unlock() |
||||||
|
if err = c.writeCommand(cmd, args); err != nil { |
||||||
|
c.fatal(err) |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Flush() (err error) { |
||||||
|
if c.writeTimeout != 0 { |
||||||
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||||
|
} |
||||||
|
if err = c.bw.Flush(); err != nil { |
||||||
|
c.fatal(err) |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Receive() (reply interface{}, err error) { |
||||||
|
if c.readTimeout != 0 { |
||||||
|
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) |
||||||
|
} |
||||||
|
if reply, err = c.readReply(); err != nil { |
||||||
|
return nil, c.fatal(err) |
||||||
|
} |
||||||
|
// When using pub/sub, the number of receives can be greater than the
|
||||||
|
// number of sends. To enable normal use of the connection after
|
||||||
|
// unsubscribing from all channels, we do not decrement pending to a
|
||||||
|
// negative value.
|
||||||
|
//
|
||||||
|
// The pending field is decremented after the reply is read to handle the
|
||||||
|
// case where Receive is called before Send.
|
||||||
|
c.mu.Lock() |
||||||
|
if c.pending > 0 { |
||||||
|
c.pending-- |
||||||
|
} |
||||||
|
c.mu.Unlock() |
||||||
|
if err, ok := reply.(Error); ok { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { |
||||||
|
c.mu.Lock() |
||||||
|
pending := c.pending |
||||||
|
c.pending = 0 |
||||||
|
c.mu.Unlock() |
||||||
|
if cmd == "" && pending == 0 { |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
var err error |
||||||
|
defer c.stat(cmd, &err)() |
||||||
|
if cmd != "" { |
||||||
|
err = c.writeCommand(cmd, args) |
||||||
|
} |
||||||
|
if err == nil { |
||||||
|
err = errors.WithStack(c.bw.Flush()) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
return nil, c.fatal(err) |
||||||
|
} |
||||||
|
if c.readTimeout != 0 { |
||||||
|
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) |
||||||
|
} |
||||||
|
if cmd == "" { |
||||||
|
reply := make([]interface{}, pending) |
||||||
|
for i := range reply { |
||||||
|
var r interface{} |
||||||
|
r, err = c.readReply() |
||||||
|
if err != nil { |
||||||
|
break |
||||||
|
} |
||||||
|
reply[i] = r |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
return nil, c.fatal(err) |
||||||
|
} |
||||||
|
return reply, nil |
||||||
|
} |
||||||
|
|
||||||
|
var reply interface{} |
||||||
|
for i := 0; i <= pending; i++ { |
||||||
|
var e error |
||||||
|
if reply, e = c.readReply(); e != nil { |
||||||
|
return nil, c.fatal(e) |
||||||
|
} |
||||||
|
if e, ok := reply.(Error); ok && err == nil { |
||||||
|
err = e |
||||||
|
} |
||||||
|
} |
||||||
|
return reply, err |
||||||
|
} |
||||||
|
|
||||||
|
// WithContext FIXME: implement WithContext
|
||||||
|
func (c *conn) WithContext(ctx context.Context) Conn { return c } |
@ -0,0 +1,169 @@ |
|||||||
|
// Copyright 2012 Gary Burd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations
|
||||||
|
// under the License.
|
||||||
|
|
||||||
|
// Package redis is a client for the Redis database.
|
||||||
|
//
|
||||||
|
// The Redigo FAQ (https://github.com/garyburd/redigo/wiki/FAQ) contains more
|
||||||
|
// documentation about this package.
|
||||||
|
//
|
||||||
|
// Connections
|
||||||
|
//
|
||||||
|
// The Conn interface is the primary interface for working with Redis.
|
||||||
|
// Applications create connections by calling the Dial, DialWithTimeout or
|
||||||
|
// NewConn functions. In the future, functions will be added for creating
|
||||||
|
// sharded and other types of connections.
|
||||||
|
//
|
||||||
|
// The application must call the connection Close method when the application
|
||||||
|
// is done with the connection.
|
||||||
|
//
|
||||||
|
// Executing Commands
|
||||||
|
//
|
||||||
|
// The Conn interface has a generic method for executing Redis commands:
|
||||||
|
//
|
||||||
|
// Do(commandName string, args ...interface{}) (reply interface{}, err error)
|
||||||
|
//
|
||||||
|
// The Redis command reference (http://redis.io/commands) lists the available
|
||||||
|
// commands. An example of using the Redis APPEND command is:
|
||||||
|
//
|
||||||
|
// n, err := conn.Do("APPEND", "key", "value")
|
||||||
|
//
|
||||||
|
// The Do method converts command arguments to binary strings for transmission
|
||||||
|
// to the server as follows:
|
||||||
|
//
|
||||||
|
// Go Type Conversion
|
||||||
|
// []byte Sent as is
|
||||||
|
// string Sent as is
|
||||||
|
// int, int64 strconv.FormatInt(v)
|
||||||
|
// float64 strconv.FormatFloat(v, 'g', -1, 64)
|
||||||
|
// bool true -> "1", false -> "0"
|
||||||
|
// nil ""
|
||||||
|
// all other types fmt.Print(v)
|
||||||
|
//
|
||||||
|
// Redis command reply types are represented using the following Go types:
|
||||||
|
//
|
||||||
|
// Redis type Go type
|
||||||
|
// error redis.Error
|
||||||
|
// integer int64
|
||||||
|
// simple string string
|
||||||
|
// bulk string []byte or nil if value not present.
|
||||||
|
// array []interface{} or nil if value not present.
|
||||||
|
//
|
||||||
|
// Use type assertions or the reply helper functions to convert from
|
||||||
|
// interface{} to the specific Go type for the command result.
|
||||||
|
//
|
||||||
|
// Pipelining
|
||||||
|
//
|
||||||
|
// Connections support pipelining using the Send, Flush and Receive methods.
|
||||||
|
//
|
||||||
|
// Send(commandName string, args ...interface{}) error
|
||||||
|
// Flush() error
|
||||||
|
// Receive() (reply interface{}, err error)
|
||||||
|
//
|
||||||
|
// Send writes the command to the connection's output buffer. Flush flushes the
|
||||||
|
// connection's output buffer to the server. Receive reads a single reply from
|
||||||
|
// the server. The following example shows a simple pipeline.
|
||||||
|
//
|
||||||
|
// c.Send("SET", "foo", "bar")
|
||||||
|
// c.Send("GET", "foo")
|
||||||
|
// c.Flush()
|
||||||
|
// c.Receive() // reply from SET
|
||||||
|
// v, err = c.Receive() // reply from GET
|
||||||
|
//
|
||||||
|
// The Do method combines the functionality of the Send, Flush and Receive
|
||||||
|
// methods. The Do method starts by writing the command and flushing the output
|
||||||
|
// buffer. Next, the Do method receives all pending replies including the reply
|
||||||
|
// for the command just sent by Do. If any of the received replies is an error,
|
||||||
|
// then Do returns the error. If there are no errors, then Do returns the last
|
||||||
|
// reply. If the command argument to the Do method is "", then the Do method
|
||||||
|
// will flush the output buffer and receive pending replies without sending a
|
||||||
|
// command.
|
||||||
|
//
|
||||||
|
// Use the Send and Do methods to implement pipelined transactions.
|
||||||
|
//
|
||||||
|
// c.Send("MULTI")
|
||||||
|
// c.Send("INCR", "foo")
|
||||||
|
// c.Send("INCR", "bar")
|
||||||
|
// r, err := c.Do("EXEC")
|
||||||
|
// fmt.Println(r) // prints [1, 1]
|
||||||
|
//
|
||||||
|
// Concurrency
|
||||||
|
//
|
||||||
|
// Connections do not support concurrent calls to the write methods (Send,
|
||||||
|
// Flush) or concurrent calls to the read method (Receive). Connections do
|
||||||
|
// allow a concurrent reader and writer.
|
||||||
|
//
|
||||||
|
// Because the Do method combines the functionality of Send, Flush and Receive,
|
||||||
|
// the Do method cannot be called concurrently with the other methods.
|
||||||
|
//
|
||||||
|
// For full concurrent access to Redis, use the thread-safe Pool to get and
|
||||||
|
// release connections from within a goroutine.
|
||||||
|
//
|
||||||
|
// Publish and Subscribe
|
||||||
|
//
|
||||||
|
// Use the Send, Flush and Receive methods to implement Pub/Sub subscribers.
|
||||||
|
//
|
||||||
|
// c.Send("SUBSCRIBE", "example")
|
||||||
|
// c.Flush()
|
||||||
|
// for {
|
||||||
|
// reply, err := c.Receive()
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// // process pushed message
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// The PubSubConn type wraps a Conn with convenience methods for implementing
|
||||||
|
// subscribers. The Subscribe, PSubscribe, Unsubscribe and PUnsubscribe methods
|
||||||
|
// send and flush a subscription management command. The receive method
|
||||||
|
// converts a pushed message to convenient types for use in a type switch.
|
||||||
|
//
|
||||||
|
// psc := redis.PubSubConn{c}
|
||||||
|
// psc.Subscribe("example")
|
||||||
|
// for {
|
||||||
|
// switch v := psc.Receive().(type) {
|
||||||
|
// case redis.Message:
|
||||||
|
// fmt.Printf("%s: message: %s\n", v.Channel, v.Data)
|
||||||
|
// case redis.Subscription:
|
||||||
|
// fmt.Printf("%s: %s %d\n", v.Channel, v.Kind, v.Count)
|
||||||
|
// case error:
|
||||||
|
// return v
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Reply Helpers
|
||||||
|
//
|
||||||
|
// The Bool, Int, Bytes, String, Strings and Values functions convert a reply
|
||||||
|
// to a value of a specific type. To allow convenient wrapping of calls to the
|
||||||
|
// connection Do and Receive methods, the functions take a second argument of
|
||||||
|
// type error. If the error is non-nil, then the helper function returns the
|
||||||
|
// error. If the error is nil, the function converts the reply to the specified
|
||||||
|
// type:
|
||||||
|
//
|
||||||
|
// exists, err := redis.Bool(c.Do("EXISTS", "foo"))
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error return from c.Do or type conversion error.
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// The Scan function converts elements of a array reply to Go types:
|
||||||
|
//
|
||||||
|
// var value1 int
|
||||||
|
// var value2 string
|
||||||
|
// reply, err := redis.Values(c.Do("MGET", "key1", "key2"))
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
// if _, err := redis.Scan(reply, &value1, &value2); err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
package redis |
@ -0,0 +1,33 @@ |
|||||||
|
package redis |
||||||
|
|
||||||
|
import ( |
||||||
|
"strings" |
||||||
|
|
||||||
|
pkgerr "github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
func formatErr(err error) string { |
||||||
|
e := pkgerr.Cause(err) |
||||||
|
switch e { |
||||||
|
case ErrNil, nil: |
||||||
|
return "" |
||||||
|
default: |
||||||
|
es := e.Error() |
||||||
|
switch { |
||||||
|
case strings.HasPrefix(es, "read"): |
||||||
|
return "read timeout" |
||||||
|
case strings.HasPrefix(es, "dial"): |
||||||
|
return "dial timeout" |
||||||
|
case strings.HasPrefix(es, "write"): |
||||||
|
return "write timeout" |
||||||
|
case strings.Contains(es, "EOF"): |
||||||
|
return "eof" |
||||||
|
case strings.Contains(es, "reset"): |
||||||
|
return "reset" |
||||||
|
case strings.Contains(es, "broken"): |
||||||
|
return "broken pipe" |
||||||
|
default: |
||||||
|
return "unexpected err" |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,117 @@ |
|||||||
|
// Copyright 2012 Gary Burd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations
|
||||||
|
// under the License.
|
||||||
|
|
||||||
|
package redis |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"fmt" |
||||||
|
"log" |
||||||
|
) |
||||||
|
|
||||||
|
// NewLoggingConn returns a logging wrapper around a connection.
|
||||||
|
func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn { |
||||||
|
if prefix != "" { |
||||||
|
prefix = prefix + "." |
||||||
|
} |
||||||
|
return &loggingConn{conn, logger, prefix} |
||||||
|
} |
||||||
|
|
||||||
|
type loggingConn struct { |
||||||
|
Conn |
||||||
|
logger *log.Logger |
||||||
|
prefix string |
||||||
|
} |
||||||
|
|
||||||
|
func (c *loggingConn) Close() error { |
||||||
|
err := c.Conn.Close() |
||||||
|
var buf bytes.Buffer |
||||||
|
fmt.Fprintf(&buf, "%sClose() -> (%v)", c.prefix, err) |
||||||
|
c.logger.Output(2, buf.String()) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *loggingConn) printValue(buf *bytes.Buffer, v interface{}) { |
||||||
|
const chop = 32 |
||||||
|
switch v := v.(type) { |
||||||
|
case []byte: |
||||||
|
if len(v) > chop { |
||||||
|
fmt.Fprintf(buf, "%q...", v[:chop]) |
||||||
|
} else { |
||||||
|
fmt.Fprintf(buf, "%q", v) |
||||||
|
} |
||||||
|
case string: |
||||||
|
if len(v) > chop { |
||||||
|
fmt.Fprintf(buf, "%q...", v[:chop]) |
||||||
|
} else { |
||||||
|
fmt.Fprintf(buf, "%q", v) |
||||||
|
} |
||||||
|
case []interface{}: |
||||||
|
if len(v) == 0 { |
||||||
|
buf.WriteString("[]") |
||||||
|
} else { |
||||||
|
sep := "[" |
||||||
|
fin := "]" |
||||||
|
if len(v) > chop { |
||||||
|
v = v[:chop] |
||||||
|
fin = "...]" |
||||||
|
} |
||||||
|
for _, vv := range v { |
||||||
|
buf.WriteString(sep) |
||||||
|
c.printValue(buf, vv) |
||||||
|
sep = ", " |
||||||
|
} |
||||||
|
buf.WriteString(fin) |
||||||
|
} |
||||||
|
default: |
||||||
|
fmt.Fprint(buf, v) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (c *loggingConn) print(method, commandName string, args []interface{}, reply interface{}, err error) { |
||||||
|
var buf bytes.Buffer |
||||||
|
fmt.Fprintf(&buf, "%s%s(", c.prefix, method) |
||||||
|
if method != "Receive" { |
||||||
|
buf.WriteString(commandName) |
||||||
|
for _, arg := range args { |
||||||
|
buf.WriteString(", ") |
||||||
|
c.printValue(&buf, arg) |
||||||
|
} |
||||||
|
} |
||||||
|
buf.WriteString(") -> (") |
||||||
|
if method != "Send" { |
||||||
|
c.printValue(&buf, reply) |
||||||
|
buf.WriteString(", ") |
||||||
|
} |
||||||
|
fmt.Fprintf(&buf, "%v)", err) |
||||||
|
c.logger.Output(3, buf.String()) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, error) { |
||||||
|
reply, err := c.Conn.Do(commandName, args...) |
||||||
|
c.print("Do", commandName, args, reply, err) |
||||||
|
return reply, err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *loggingConn) Send(commandName string, args ...interface{}) error { |
||||||
|
err := c.Conn.Send(commandName, args...) |
||||||
|
c.print("Send", commandName, args, nil, err) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *loggingConn) Receive() (interface{}, error) { |
||||||
|
reply, err := c.Conn.Receive() |
||||||
|
c.print("Receive", "", nil, reply, err) |
||||||
|
return reply, err |
||||||
|
} |
@ -0,0 +1,36 @@ |
|||||||
|
package redis |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
) |
||||||
|
|
||||||
|
// MockErr for unit test.
|
||||||
|
type MockErr struct { |
||||||
|
Error error |
||||||
|
} |
||||||
|
|
||||||
|
// MockWith return a mock conn.
|
||||||
|
func MockWith(err error) MockErr { |
||||||
|
return MockErr{Error: err} |
||||||
|
} |
||||||
|
|
||||||
|
// Err .
|
||||||
|
func (m MockErr) Err() error { return m.Error } |
||||||
|
|
||||||
|
// Close .
|
||||||
|
func (m MockErr) Close() error { return m.Error } |
||||||
|
|
||||||
|
// Do .
|
||||||
|
func (m MockErr) Do(commandName string, args ...interface{}) (interface{}, error) { return nil, m.Error } |
||||||
|
|
||||||
|
// Send .
|
||||||
|
func (m MockErr) Send(commandName string, args ...interface{}) error { return m.Error } |
||||||
|
|
||||||
|
// Flush .
|
||||||
|
func (m MockErr) Flush() error { return m.Error } |
||||||
|
|
||||||
|
// Receive .
|
||||||
|
func (m MockErr) Receive() (interface{}, error) { return nil, m.Error } |
||||||
|
|
||||||
|
// WithContext .
|
||||||
|
func (m MockErr) WithContext(context.Context) Conn { return m } |
@ -0,0 +1,218 @@ |
|||||||
|
// Copyright 2012 Gary Burd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations
|
||||||
|
// under the License.
|
||||||
|
|
||||||
|
package redis |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"context" |
||||||
|
"crypto/rand" |
||||||
|
"crypto/sha1" |
||||||
|
"errors" |
||||||
|
"io" |
||||||
|
"strconv" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/container/pool" |
||||||
|
"github.com/bilibili/kratos/pkg/net/trace" |
||||||
|
xtime "github.com/bilibili/kratos/pkg/time" |
||||||
|
) |
||||||
|
|
||||||
|
var beginTime, _ = time.Parse("2006-01-02 15:04:05", "2006-01-02 15:04:05") |
||||||
|
|
||||||
|
var ( |
||||||
|
errConnClosed = errors.New("redigo: connection closed") |
||||||
|
) |
||||||
|
|
||||||
|
// Pool .
|
||||||
|
type Pool struct { |
||||||
|
*pool.Slice |
||||||
|
// config
|
||||||
|
c *Config |
||||||
|
} |
||||||
|
|
||||||
|
// Config client settings.
|
||||||
|
type Config struct { |
||||||
|
*pool.Config |
||||||
|
|
||||||
|
Name string // redis name, for trace
|
||||||
|
Proto string |
||||||
|
Addr string |
||||||
|
Auth string |
||||||
|
DialTimeout xtime.Duration |
||||||
|
ReadTimeout xtime.Duration |
||||||
|
WriteTimeout xtime.Duration |
||||||
|
} |
||||||
|
|
||||||
|
// NewPool creates a new pool.
|
||||||
|
func NewPool(c *Config, options ...DialOption) (p *Pool) { |
||||||
|
if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 { |
||||||
|
panic("must config redis timeout") |
||||||
|
} |
||||||
|
p1 := pool.NewSlice(c.Config) |
||||||
|
cnop := DialConnectTimeout(time.Duration(c.DialTimeout)) |
||||||
|
options = append(options, cnop) |
||||||
|
rdop := DialReadTimeout(time.Duration(c.ReadTimeout)) |
||||||
|
options = append(options, rdop) |
||||||
|
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout)) |
||||||
|
options = append(options, wrop) |
||||||
|
auop := DialPassword(c.Auth) |
||||||
|
options = append(options, auop) |
||||||
|
// new pool
|
||||||
|
p1.New = func(ctx context.Context) (io.Closer, error) { |
||||||
|
conn, err := Dial(c.Proto, c.Addr, options...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return &traceConn{Conn: conn, connTags: []trace.Tag{trace.TagString(trace.TagPeerAddress, c.Addr)}}, nil |
||||||
|
} |
||||||
|
p = &Pool{Slice: p1, c: c} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Get gets a connection. The application must close the returned connection.
|
||||||
|
// This method always returns a valid connection so that applications can defer
|
||||||
|
// error handling to the first use of the connection. If there is an error
|
||||||
|
// getting an underlying connection, then the connection Err, Do, Send, Flush
|
||||||
|
// and Receive methods return that error.
|
||||||
|
func (p *Pool) Get(ctx context.Context) Conn { |
||||||
|
c, err := p.Slice.Get(ctx) |
||||||
|
if err != nil { |
||||||
|
return errorConnection{err} |
||||||
|
} |
||||||
|
c1, _ := c.(Conn) |
||||||
|
return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx, now: beginTime} |
||||||
|
} |
||||||
|
|
||||||
|
// Close releases the resources used by the pool.
|
||||||
|
func (p *Pool) Close() error { |
||||||
|
return p.Slice.Close() |
||||||
|
} |
||||||
|
|
||||||
|
type pooledConnection struct { |
||||||
|
p *Pool |
||||||
|
c Conn |
||||||
|
state int |
||||||
|
|
||||||
|
now time.Time |
||||||
|
cmds []string |
||||||
|
ctx context.Context |
||||||
|
} |
||||||
|
|
||||||
|
var ( |
||||||
|
sentinel []byte |
||||||
|
sentinelOnce sync.Once |
||||||
|
) |
||||||
|
|
||||||
|
func initSentinel() { |
||||||
|
p := make([]byte, 64) |
||||||
|
if _, err := rand.Read(p); err == nil { |
||||||
|
sentinel = p |
||||||
|
} else { |
||||||
|
h := sha1.New() |
||||||
|
io.WriteString(h, "Oops, rand failed. Use time instead.") |
||||||
|
io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10)) |
||||||
|
sentinel = h.Sum(nil) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Close() error { |
||||||
|
c := pc.c |
||||||
|
if _, ok := c.(errorConnection); ok { |
||||||
|
return nil |
||||||
|
} |
||||||
|
pc.c = errorConnection{errConnClosed} |
||||||
|
|
||||||
|
if pc.state&MultiState != 0 { |
||||||
|
c.Send("DISCARD") |
||||||
|
pc.state &^= (MultiState | WatchState) |
||||||
|
} else if pc.state&WatchState != 0 { |
||||||
|
c.Send("UNWATCH") |
||||||
|
pc.state &^= WatchState |
||||||
|
} |
||||||
|
if pc.state&SubscribeState != 0 { |
||||||
|
c.Send("UNSUBSCRIBE") |
||||||
|
c.Send("PUNSUBSCRIBE") |
||||||
|
// To detect the end of the message stream, ask the server to echo
|
||||||
|
// a sentinel value and read until we see that value.
|
||||||
|
sentinelOnce.Do(initSentinel) |
||||||
|
c.Send("ECHO", sentinel) |
||||||
|
c.Flush() |
||||||
|
for { |
||||||
|
p, err := c.Receive() |
||||||
|
if err != nil { |
||||||
|
break |
||||||
|
} |
||||||
|
if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) { |
||||||
|
pc.state &^= SubscribeState |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
_, err := c.Do("") |
||||||
|
pc.p.Slice.Put(context.Background(), c, pc.state != 0 || c.Err() != nil) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Err() error { |
||||||
|
return pc.c.Err() |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) { |
||||||
|
ci := LookupCommandInfo(commandName) |
||||||
|
pc.state = (pc.state | ci.Set) &^ ci.Clear |
||||||
|
reply, err = pc.c.Do(commandName, args...) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Send(commandName string, args ...interface{}) (err error) { |
||||||
|
ci := LookupCommandInfo(commandName) |
||||||
|
pc.state = (pc.state | ci.Set) &^ ci.Clear |
||||||
|
if pc.now.Equal(beginTime) { |
||||||
|
// mark first send time
|
||||||
|
pc.now = time.Now() |
||||||
|
} |
||||||
|
pc.cmds = append(pc.cmds, commandName) |
||||||
|
return pc.c.Send(commandName, args...) |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Flush() error { |
||||||
|
return pc.c.Flush() |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) Receive() (reply interface{}, err error) { |
||||||
|
reply, err = pc.c.Receive() |
||||||
|
if len(pc.cmds) > 0 { |
||||||
|
pc.cmds = pc.cmds[1:] |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (pc *pooledConnection) WithContext(ctx context.Context) Conn { |
||||||
|
pc.ctx = ctx |
||||||
|
return pc |
||||||
|
} |
||||||
|
|
||||||
|
type errorConnection struct{ err error } |
||||||
|
|
||||||
|
func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) { |
||||||
|
return nil, ec.err |
||||||
|
} |
||||||
|
func (ec errorConnection) Send(string, ...interface{}) error { return ec.err } |
||||||
|
func (ec errorConnection) Err() error { return ec.err } |
||||||
|
func (ec errorConnection) Close() error { return ec.err } |
||||||
|
func (ec errorConnection) Flush() error { return ec.err } |
||||||
|
func (ec errorConnection) Receive() (interface{}, error) { return nil, ec.err } |
||||||
|
func (ec errorConnection) WithContext(context.Context) Conn { return ec } |
@ -0,0 +1,152 @@ |
|||||||
|
// Copyright 2012 Gary Burd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations
|
||||||
|
// under the License.
|
||||||
|
|
||||||
|
package redis |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
|
||||||
|
pkgerr "github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
errPubSub = errors.New("redigo: unknown pubsub notification") |
||||||
|
) |
||||||
|
|
||||||
|
// Subscription represents a subscribe or unsubscribe notification.
|
||||||
|
type Subscription struct { |
||||||
|
|
||||||
|
// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe"
|
||||||
|
Kind string |
||||||
|
|
||||||
|
// The channel that was changed.
|
||||||
|
Channel string |
||||||
|
|
||||||
|
// The current number of subscriptions for connection.
|
||||||
|
Count int |
||||||
|
} |
||||||
|
|
||||||
|
// Message represents a message notification.
|
||||||
|
type Message struct { |
||||||
|
|
||||||
|
// The originating channel.
|
||||||
|
Channel string |
||||||
|
|
||||||
|
// The message data.
|
||||||
|
Data []byte |
||||||
|
} |
||||||
|
|
||||||
|
// PMessage represents a pmessage notification.
|
||||||
|
type PMessage struct { |
||||||
|
|
||||||
|
// The matched pattern.
|
||||||
|
Pattern string |
||||||
|
|
||||||
|
// The originating channel.
|
||||||
|
Channel string |
||||||
|
|
||||||
|
// The message data.
|
||||||
|
Data []byte |
||||||
|
} |
||||||
|
|
||||||
|
// Pong represents a pubsub pong notification.
|
||||||
|
type Pong struct { |
||||||
|
Data string |
||||||
|
} |
||||||
|
|
||||||
|
// PubSubConn wraps a Conn with convenience methods for subscribers.
|
||||||
|
type PubSubConn struct { |
||||||
|
Conn Conn |
||||||
|
} |
||||||
|
|
||||||
|
// Close closes the connection.
|
||||||
|
func (c PubSubConn) Close() error { |
||||||
|
return c.Conn.Close() |
||||||
|
} |
||||||
|
|
||||||
|
// Subscribe subscribes the connection to the specified channels.
|
||||||
|
func (c PubSubConn) Subscribe(channel ...interface{}) error { |
||||||
|
c.Conn.Send("SUBSCRIBE", channel...) |
||||||
|
return c.Conn.Flush() |
||||||
|
} |
||||||
|
|
||||||
|
// PSubscribe subscribes the connection to the given patterns.
|
||||||
|
func (c PubSubConn) PSubscribe(channel ...interface{}) error { |
||||||
|
c.Conn.Send("PSUBSCRIBE", channel...) |
||||||
|
return c.Conn.Flush() |
||||||
|
} |
||||||
|
|
||||||
|
// Unsubscribe unsubscribes the connection from the given channels, or from all
|
||||||
|
// of them if none is given.
|
||||||
|
func (c PubSubConn) Unsubscribe(channel ...interface{}) error { |
||||||
|
c.Conn.Send("UNSUBSCRIBE", channel...) |
||||||
|
return c.Conn.Flush() |
||||||
|
} |
||||||
|
|
||||||
|
// PUnsubscribe unsubscribes the connection from the given patterns, or from all
|
||||||
|
// of them if none is given.
|
||||||
|
func (c PubSubConn) PUnsubscribe(channel ...interface{}) error { |
||||||
|
c.Conn.Send("PUNSUBSCRIBE", channel...) |
||||||
|
return c.Conn.Flush() |
||||||
|
} |
||||||
|
|
||||||
|
// Ping sends a PING to the server with the specified data.
|
||||||
|
func (c PubSubConn) Ping(data string) error { |
||||||
|
c.Conn.Send("PING", data) |
||||||
|
return c.Conn.Flush() |
||||||
|
} |
||||||
|
|
||||||
|
// Receive returns a pushed message as a Subscription, Message, PMessage, Pong
|
||||||
|
// or error. The return value is intended to be used directly in a type switch
|
||||||
|
// as illustrated in the PubSubConn example.
|
||||||
|
func (c PubSubConn) Receive() interface{} { |
||||||
|
reply, err := Values(c.Conn.Receive()) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
var kind string |
||||||
|
reply, err = Scan(reply, &kind) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
switch kind { |
||||||
|
case "message": |
||||||
|
var m Message |
||||||
|
if _, err := Scan(reply, &m.Channel, &m.Data); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return m |
||||||
|
case "pmessage": |
||||||
|
var pm PMessage |
||||||
|
if _, err := Scan(reply, &pm.Pattern, &pm.Channel, &pm.Data); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return pm |
||||||
|
case "subscribe", "psubscribe", "unsubscribe", "punsubscribe": |
||||||
|
s := Subscription{Kind: kind} |
||||||
|
if _, err := Scan(reply, &s.Channel, &s.Count); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return s |
||||||
|
case "pong": |
||||||
|
var p Pong |
||||||
|
if _, err := Scan(reply, &p.Data); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return p |
||||||
|
} |
||||||
|
return pkgerr.WithStack(errPubSub) |
||||||
|
} |
@ -0,0 +1,51 @@ |
|||||||
|
// Copyright 2012 Gary Burd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations
|
||||||
|
// under the License.
|
||||||
|
|
||||||
|
package redis |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
) |
||||||
|
|
||||||
|
// Error represents an error returned in a command reply.
|
||||||
|
type Error string |
||||||
|
|
||||||
|
func (err Error) Error() string { return string(err) } |
||||||
|
|
||||||
|
// Conn represents a connection to a Redis server.
|
||||||
|
type Conn interface { |
||||||
|
// Close closes the connection.
|
||||||
|
Close() error |
||||||
|
|
||||||
|
// Err returns a non-nil value if the connection is broken. The returned
|
||||||
|
// value is either the first non-nil value returned from the underlying
|
||||||
|
// network connection or a protocol parsing error. Applications should
|
||||||
|
// close broken connections.
|
||||||
|
Err() error |
||||||
|
|
||||||
|
// Do sends a command to the server and returns the received reply.
|
||||||
|
Do(commandName string, args ...interface{}) (reply interface{}, err error) |
||||||
|
|
||||||
|
// Send writes the command to the client's output buffer.
|
||||||
|
Send(commandName string, args ...interface{}) error |
||||||
|
|
||||||
|
// Flush flushes the output buffer to the Redis server.
|
||||||
|
Flush() error |
||||||
|
|
||||||
|
// Receive receives a single reply from the Redis server
|
||||||
|
Receive() (reply interface{}, err error) |
||||||
|
|
||||||
|
// WithContext
|
||||||
|
WithContext(ctx context.Context) Conn |
||||||
|
} |
@ -0,0 +1,409 @@ |
|||||||
|
// Copyright 2012 Gary Burd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations
|
||||||
|
// under the License.
|
||||||
|
|
||||||
|
package redis |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
"strconv" |
||||||
|
|
||||||
|
pkgerr "github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// ErrNil indicates that a reply value is nil.
|
||||||
|
var ErrNil = errors.New("redigo: nil returned") |
||||||
|
|
||||||
|
// Int is a helper that converts a command reply to an integer. If err is not
|
||||||
|
// equal to nil, then Int returns 0, err. Otherwise, Int converts the
|
||||||
|
// reply to an int as follows:
|
||||||
|
//
|
||||||
|
// Reply type Result
|
||||||
|
// integer int(reply), nil
|
||||||
|
// bulk string parsed reply, nil
|
||||||
|
// nil 0, ErrNil
|
||||||
|
// other 0, error
|
||||||
|
func Int(reply interface{}, err error) (int, error) { |
||||||
|
if err != nil { |
||||||
|
return 0, err |
||||||
|
} |
||||||
|
switch reply := reply.(type) { |
||||||
|
case int64: |
||||||
|
x := int(reply) |
||||||
|
if int64(x) != reply { |
||||||
|
return 0, pkgerr.WithStack(strconv.ErrRange) |
||||||
|
} |
||||||
|
return x, nil |
||||||
|
case []byte: |
||||||
|
n, err := strconv.ParseInt(string(reply), 10, 0) |
||||||
|
return int(n), pkgerr.WithStack(err) |
||||||
|
case nil: |
||||||
|
return 0, ErrNil |
||||||
|
case Error: |
||||||
|
return 0, reply |
||||||
|
} |
||||||
|
return 0, pkgerr.Errorf("redigo: unexpected type for Int, got type %T", reply) |
||||||
|
} |
||||||
|
|
||||||
|
// Int64 is a helper that converts a command reply to 64 bit integer. If err is
|
||||||
|
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
|
||||||
|
// reply to an int64 as follows:
|
||||||
|
//
|
||||||
|
// Reply type Result
|
||||||
|
// integer reply, nil
|
||||||
|
// bulk string parsed reply, nil
|
||||||
|
// nil 0, ErrNil
|
||||||
|
// other 0, error
|
||||||
|
func Int64(reply interface{}, err error) (int64, error) { |
||||||
|
if err != nil { |
||||||
|
return 0, err |
||||||
|
} |
||||||
|
switch reply := reply.(type) { |
||||||
|
case int64: |
||||||
|
return reply, nil |
||||||
|
case []byte: |
||||||
|
n, err := strconv.ParseInt(string(reply), 10, 64) |
||||||
|
return n, pkgerr.WithStack(err) |
||||||
|
case nil: |
||||||
|
return 0, ErrNil |
||||||
|
case Error: |
||||||
|
return 0, reply |
||||||
|
} |
||||||
|
return 0, pkgerr.Errorf("redigo: unexpected type for Int64, got type %T", reply) |
||||||
|
} |
||||||
|
|
||||||
|
var errNegativeInt = errors.New("redigo: unexpected value for Uint64") |
||||||
|
|
||||||
|
// Uint64 is a helper that converts a command reply to 64 bit integer. If err is
|
||||||
|
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
|
||||||
|
// reply to an int64 as follows:
|
||||||
|
//
|
||||||
|
// Reply type Result
|
||||||
|
// integer reply, nil
|
||||||
|
// bulk string parsed reply, nil
|
||||||
|
// nil 0, ErrNil
|
||||||
|
// other 0, error
|
||||||
|
func Uint64(reply interface{}, err error) (uint64, error) { |
||||||
|
if err != nil { |
||||||
|
return 0, err |
||||||
|
} |
||||||
|
switch reply := reply.(type) { |
||||||
|
case int64: |
||||||
|
if reply < 0 { |
||||||
|
return 0, pkgerr.WithStack(errNegativeInt) |
||||||
|
} |
||||||
|
return uint64(reply), nil |
||||||
|
case []byte: |
||||||
|
n, err := strconv.ParseUint(string(reply), 10, 64) |
||||||
|
return n, err |
||||||
|
case nil: |
||||||
|
return 0, ErrNil |
||||||
|
case Error: |
||||||
|
return 0, reply |
||||||
|
} |
||||||
|
return 0, pkgerr.Errorf("redigo: unexpected type for Uint64, got type %T", reply) |
||||||
|
} |
||||||
|
|
||||||
|
// Float64 is a helper that converts a command reply to 64 bit float. If err is
|
||||||
|
// not equal to nil, then Float64 returns 0, err. Otherwise, Float64 converts
|
||||||
|
// the reply to an int as follows:
|
||||||
|
//
|
||||||
|
// Reply type Result
|
||||||
|
// bulk string parsed reply, nil
|
||||||
|
// nil 0, ErrNil
|
||||||
|
// other 0, error
|
||||||
|
func Float64(reply interface{}, err error) (float64, error) { |
||||||
|
if err != nil { |
||||||
|
return 0, err |
||||||
|
} |
||||||
|
switch reply := reply.(type) { |
||||||
|
case []byte: |
||||||
|
n, err := strconv.ParseFloat(string(reply), 64) |
||||||
|
return n, pkgerr.WithStack(err) |
||||||
|
case nil: |
||||||
|
return 0, ErrNil |
||||||
|
case Error: |
||||||
|
return 0, reply |
||||||
|
} |
||||||
|
return 0, pkgerr.Errorf("redigo: unexpected type for Float64, got type %T", reply) |
||||||
|
} |
||||||
|
|
||||||
|
// String is a helper that converts a command reply to a string. If err is not
|
||||||
|
// equal to nil, then String returns "", err. Otherwise String converts the
|
||||||
|
// reply to a string as follows:
|
||||||
|
//
|
||||||
|
// Reply type Result
|
||||||
|
// bulk string string(reply), nil
|
||||||
|
// simple string reply, nil
|
||||||
|
// nil "", ErrNil
|
||||||
|
// other "", error
|
||||||
|
func String(reply interface{}, err error) (string, error) { |
||||||
|
if err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
switch reply := reply.(type) { |
||||||
|
case []byte: |
||||||
|
return string(reply), nil |
||||||
|
case string: |
||||||
|
return reply, nil |
||||||
|
case nil: |
||||||
|
return "", ErrNil |
||||||
|
case Error: |
||||||
|
return "", reply |
||||||
|
} |
||||||
|
return "", pkgerr.Errorf("redigo: unexpected type for String, got type %T", reply) |
||||||
|
} |
||||||
|
|
||||||
|
// Bytes is a helper that converts a command reply to a slice of bytes. If err
|
||||||
|
// is not equal to nil, then Bytes returns nil, err. Otherwise Bytes converts
|
||||||
|
// the reply to a slice of bytes as follows:
|
||||||
|
//
|
||||||
|
// Reply type Result
|
||||||
|
// bulk string reply, nil
|
||||||
|
// simple string []byte(reply), nil
|
||||||
|
// nil nil, ErrNil
|
||||||
|
// other nil, error
|
||||||
|
func Bytes(reply interface{}, err error) ([]byte, error) { |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
switch reply := reply.(type) { |
||||||
|
case []byte: |
||||||
|
return reply, nil |
||||||
|
case string: |
||||||
|
return []byte(reply), nil |
||||||
|
case nil: |
||||||
|
return nil, ErrNil |
||||||
|
case Error: |
||||||
|
return nil, reply |
||||||
|
} |
||||||
|
return nil, pkgerr.Errorf("redigo: unexpected type for Bytes, got type %T", reply) |
||||||
|
} |
||||||
|
|
||||||
|
// Bool is a helper that converts a command reply to a boolean. If err is not
|
||||||
|
// equal to nil, then Bool returns false, err. Otherwise Bool converts the
|
||||||
|
// reply to boolean as follows:
|
||||||
|
//
|
||||||
|
// Reply type Result
|
||||||
|
// integer value != 0, nil
|
||||||
|
// bulk string strconv.ParseBool(reply)
|
||||||
|
// nil false, ErrNil
|
||||||
|
// other false, error
|
||||||
|
func Bool(reply interface{}, err error) (bool, error) { |
||||||
|
if err != nil { |
||||||
|
return false, err |
||||||
|
} |
||||||
|
switch reply := reply.(type) { |
||||||
|
case int64: |
||||||
|
return reply != 0, nil |
||||||
|
case []byte: |
||||||
|
b, e := strconv.ParseBool(string(reply)) |
||||||
|
return b, pkgerr.WithStack(e) |
||||||
|
case nil: |
||||||
|
return false, ErrNil |
||||||
|
case Error: |
||||||
|
return false, reply |
||||||
|
} |
||||||
|
return false, pkgerr.Errorf("redigo: unexpected type for Bool, got type %T", reply) |
||||||
|
} |
||||||
|
|
||||||
|
// MultiBulk is a helper that converts an array command reply to a []interface{}.
|
||||||
|
//
|
||||||
|
// Deprecated: Use Values instead.
|
||||||
|
func MultiBulk(reply interface{}, err error) ([]interface{}, error) { return Values(reply, err) } |
||||||
|
|
||||||
|
// Values is a helper that converts an array command reply to a []interface{}.
|
||||||
|
// If err is not equal to nil, then Values returns nil, err. Otherwise, Values
|
||||||
|
// converts the reply as follows:
|
||||||
|
//
|
||||||
|
// Reply type Result
|
||||||
|
// array reply, nil
|
||||||
|
// nil nil, ErrNil
|
||||||
|
// other nil, error
|
||||||
|
func Values(reply interface{}, err error) ([]interface{}, error) { |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
switch reply := reply.(type) { |
||||||
|
case []interface{}: |
||||||
|
return reply, nil |
||||||
|
case nil: |
||||||
|
return nil, ErrNil |
||||||
|
case Error: |
||||||
|
return nil, reply |
||||||
|
} |
||||||
|
return nil, pkgerr.Errorf("redigo: unexpected type for Values, got type %T", reply) |
||||||
|
} |
||||||
|
|
||||||
|
// Strings is a helper that converts an array command reply to a []string. If
|
||||||
|
// err is not equal to nil, then Strings returns nil, err. Nil array items are
|
||||||
|
// converted to "" in the output slice. Strings returns an error if an array
|
||||||
|
// item is not a bulk string or nil.
|
||||||
|
func Strings(reply interface{}, err error) ([]string, error) { |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
switch reply := reply.(type) { |
||||||
|
case []interface{}: |
||||||
|
result := make([]string, len(reply)) |
||||||
|
for i := range reply { |
||||||
|
if reply[i] == nil { |
||||||
|
continue |
||||||
|
} |
||||||
|
p, ok := reply[i].([]byte) |
||||||
|
if !ok { |
||||||
|
return nil, pkgerr.Errorf("redigo: unexpected element type for Strings, got type %T", reply[i]) |
||||||
|
} |
||||||
|
result[i] = string(p) |
||||||
|
} |
||||||
|
return result, nil |
||||||
|
case nil: |
||||||
|
return nil, ErrNil |
||||||
|
case Error: |
||||||
|
return nil, reply |
||||||
|
} |
||||||
|
return nil, pkgerr.Errorf("redigo: unexpected type for Strings, got type %T", reply) |
||||||
|
} |
||||||
|
|
||||||
|
// ByteSlices is a helper that converts an array command reply to a [][]byte.
|
||||||
|
// If err is not equal to nil, then ByteSlices returns nil, err. Nil array
|
||||||
|
// items are stay nil. ByteSlices returns an error if an array item is not a
|
||||||
|
// bulk string or nil.
|
||||||
|
func ByteSlices(reply interface{}, err error) ([][]byte, error) { |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
switch reply := reply.(type) { |
||||||
|
case []interface{}: |
||||||
|
result := make([][]byte, len(reply)) |
||||||
|
for i := range reply { |
||||||
|
if reply[i] == nil { |
||||||
|
continue |
||||||
|
} |
||||||
|
p, ok := reply[i].([]byte) |
||||||
|
if !ok { |
||||||
|
return nil, pkgerr.Errorf("redigo: unexpected element type for ByteSlices, got type %T", reply[i]) |
||||||
|
} |
||||||
|
result[i] = p |
||||||
|
} |
||||||
|
return result, nil |
||||||
|
case nil: |
||||||
|
return nil, ErrNil |
||||||
|
case Error: |
||||||
|
return nil, reply |
||||||
|
} |
||||||
|
return nil, pkgerr.Errorf("redigo: unexpected type for ByteSlices, got type %T", reply) |
||||||
|
} |
||||||
|
|
||||||
|
// Ints is a helper that converts an array command reply to a []int. If
|
||||||
|
// err is not equal to nil, then Ints returns nil, err.
|
||||||
|
func Ints(reply interface{}, err error) ([]int, error) { |
||||||
|
var ints []int |
||||||
|
values, err := Values(reply, err) |
||||||
|
if err != nil { |
||||||
|
return ints, err |
||||||
|
} |
||||||
|
if err := ScanSlice(values, &ints); err != nil { |
||||||
|
return ints, err |
||||||
|
} |
||||||
|
return ints, nil |
||||||
|
} |
||||||
|
|
||||||
|
// Int64s is a helper that converts an array command reply to a []int64. If
|
||||||
|
// err is not equal to nil, then Int64s returns nil, err.
|
||||||
|
func Int64s(reply interface{}, err error) ([]int64, error) { |
||||||
|
var int64s []int64 |
||||||
|
values, err := Values(reply, err) |
||||||
|
if err != nil { |
||||||
|
return int64s, err |
||||||
|
} |
||||||
|
if err := ScanSlice(values, &int64s); err != nil { |
||||||
|
return int64s, err |
||||||
|
} |
||||||
|
return int64s, nil |
||||||
|
} |
||||||
|
|
||||||
|
// StringMap is a helper that converts an array of strings (alternating key, value)
|
||||||
|
// into a map[string]string. The HGETALL and CONFIG GET commands return replies in this format.
|
||||||
|
// Requires an even number of values in result.
|
||||||
|
func StringMap(result interface{}, err error) (map[string]string, error) { |
||||||
|
values, err := Values(result, err) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if len(values)%2 != 0 { |
||||||
|
return nil, pkgerr.New("redigo: StringMap expects even number of values result") |
||||||
|
} |
||||||
|
m := make(map[string]string, len(values)/2) |
||||||
|
for i := 0; i < len(values); i += 2 { |
||||||
|
key, okKey := values[i].([]byte) |
||||||
|
value, okValue := values[i+1].([]byte) |
||||||
|
if !okKey || !okValue { |
||||||
|
return nil, pkgerr.New("redigo: ScanMap key not a bulk string value") |
||||||
|
} |
||||||
|
m[string(key)] = string(value) |
||||||
|
} |
||||||
|
return m, nil |
||||||
|
} |
||||||
|
|
||||||
|
// IntMap is a helper that converts an array of strings (alternating key, value)
|
||||||
|
// into a map[string]int. The HGETALL commands return replies in this format.
|
||||||
|
// Requires an even number of values in result.
|
||||||
|
func IntMap(result interface{}, err error) (map[string]int, error) { |
||||||
|
values, err := Values(result, err) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if len(values)%2 != 0 { |
||||||
|
return nil, pkgerr.New("redigo: IntMap expects even number of values result") |
||||||
|
} |
||||||
|
m := make(map[string]int, len(values)/2) |
||||||
|
for i := 0; i < len(values); i += 2 { |
||||||
|
key, ok := values[i].([]byte) |
||||||
|
if !ok { |
||||||
|
return nil, pkgerr.New("redigo: ScanMap key not a bulk string value") |
||||||
|
} |
||||||
|
value, err := Int(values[i+1], nil) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
m[string(key)] = value |
||||||
|
} |
||||||
|
return m, nil |
||||||
|
} |
||||||
|
|
||||||
|
// Int64Map is a helper that converts an array of strings (alternating key, value)
|
||||||
|
// into a map[string]int64. The HGETALL commands return replies in this format.
|
||||||
|
// Requires an even number of values in result.
|
||||||
|
func Int64Map(result interface{}, err error) (map[string]int64, error) { |
||||||
|
values, err := Values(result, err) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if len(values)%2 != 0 { |
||||||
|
return nil, pkgerr.New("redigo: Int64Map expects even number of values result") |
||||||
|
} |
||||||
|
m := make(map[string]int64, len(values)/2) |
||||||
|
for i := 0; i < len(values); i += 2 { |
||||||
|
key, ok := values[i].([]byte) |
||||||
|
if !ok { |
||||||
|
return nil, pkgerr.New("redigo: ScanMap key not a bulk string value") |
||||||
|
} |
||||||
|
value, err := Int64(values[i+1], nil) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
m[string(key)] = value |
||||||
|
} |
||||||
|
return m, nil |
||||||
|
} |
@ -0,0 +1,559 @@ |
|||||||
|
// Copyright 2012 Gary Burd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations
|
||||||
|
// under the License.
|
||||||
|
|
||||||
|
package redis |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"reflect" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
|
||||||
|
pkgerr "github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
func ensureLen(d reflect.Value, n int) { |
||||||
|
if n > d.Cap() { |
||||||
|
d.Set(reflect.MakeSlice(d.Type(), n, n)) |
||||||
|
} else { |
||||||
|
d.SetLen(n) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func cannotConvert(d reflect.Value, s interface{}) error { |
||||||
|
var sname string |
||||||
|
switch s.(type) { |
||||||
|
case string: |
||||||
|
sname = "Redis simple string" |
||||||
|
case Error: |
||||||
|
sname = "Redis error" |
||||||
|
case int64: |
||||||
|
sname = "Redis integer" |
||||||
|
case []byte: |
||||||
|
sname = "Redis bulk string" |
||||||
|
case []interface{}: |
||||||
|
sname = "Redis array" |
||||||
|
default: |
||||||
|
sname = reflect.TypeOf(s).String() |
||||||
|
} |
||||||
|
return pkgerr.Errorf("cannot convert from %s to %s", sname, d.Type()) |
||||||
|
} |
||||||
|
|
||||||
|
func convertAssignBulkString(d reflect.Value, s []byte) (err error) { |
||||||
|
switch d.Type().Kind() { |
||||||
|
case reflect.Float32, reflect.Float64: |
||||||
|
var x float64 |
||||||
|
x, err = strconv.ParseFloat(string(s), d.Type().Bits()) |
||||||
|
d.SetFloat(x) |
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
||||||
|
var x int64 |
||||||
|
x, err = strconv.ParseInt(string(s), 10, d.Type().Bits()) |
||||||
|
d.SetInt(x) |
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
||||||
|
var x uint64 |
||||||
|
x, err = strconv.ParseUint(string(s), 10, d.Type().Bits()) |
||||||
|
d.SetUint(x) |
||||||
|
case reflect.Bool: |
||||||
|
var x bool |
||||||
|
x, err = strconv.ParseBool(string(s)) |
||||||
|
d.SetBool(x) |
||||||
|
case reflect.String: |
||||||
|
d.SetString(string(s)) |
||||||
|
case reflect.Slice: |
||||||
|
if d.Type().Elem().Kind() != reflect.Uint8 { |
||||||
|
err = cannotConvert(d, s) |
||||||
|
} else { |
||||||
|
d.SetBytes(s) |
||||||
|
} |
||||||
|
default: |
||||||
|
err = cannotConvert(d, s) |
||||||
|
} |
||||||
|
err = pkgerr.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func convertAssignInt(d reflect.Value, s int64) (err error) { |
||||||
|
switch d.Type().Kind() { |
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
||||||
|
d.SetInt(s) |
||||||
|
if d.Int() != s { |
||||||
|
err = strconv.ErrRange |
||||||
|
d.SetInt(0) |
||||||
|
} |
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
||||||
|
if s < 0 { |
||||||
|
err = strconv.ErrRange |
||||||
|
} else { |
||||||
|
x := uint64(s) |
||||||
|
d.SetUint(x) |
||||||
|
if d.Uint() != x { |
||||||
|
err = strconv.ErrRange |
||||||
|
d.SetUint(0) |
||||||
|
} |
||||||
|
} |
||||||
|
case reflect.Bool: |
||||||
|
d.SetBool(s != 0) |
||||||
|
default: |
||||||
|
err = cannotConvert(d, s) |
||||||
|
} |
||||||
|
err = pkgerr.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func convertAssignValue(d reflect.Value, s interface{}) (err error) { |
||||||
|
switch s := s.(type) { |
||||||
|
case []byte: |
||||||
|
err = convertAssignBulkString(d, s) |
||||||
|
case int64: |
||||||
|
err = convertAssignInt(d, s) |
||||||
|
default: |
||||||
|
err = cannotConvert(d, s) |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func convertAssignArray(d reflect.Value, s []interface{}) error { |
||||||
|
if d.Type().Kind() != reflect.Slice { |
||||||
|
return cannotConvert(d, s) |
||||||
|
} |
||||||
|
ensureLen(d, len(s)) |
||||||
|
for i := 0; i < len(s); i++ { |
||||||
|
if err := convertAssignValue(d.Index(i), s[i]); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func convertAssign(d interface{}, s interface{}) (err error) { |
||||||
|
// Handle the most common destination types using type switches and
|
||||||
|
// fall back to reflection for all other types.
|
||||||
|
switch s := s.(type) { |
||||||
|
case nil: |
||||||
|
// ingore
|
||||||
|
case []byte: |
||||||
|
switch d := d.(type) { |
||||||
|
case *string: |
||||||
|
*d = string(s) |
||||||
|
case *int: |
||||||
|
*d, err = strconv.Atoi(string(s)) |
||||||
|
case *bool: |
||||||
|
*d, err = strconv.ParseBool(string(s)) |
||||||
|
case *[]byte: |
||||||
|
*d = s |
||||||
|
case *interface{}: |
||||||
|
*d = s |
||||||
|
case nil: |
||||||
|
// skip value
|
||||||
|
default: |
||||||
|
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr { |
||||||
|
err = cannotConvert(d, s) |
||||||
|
} else { |
||||||
|
err = convertAssignBulkString(d.Elem(), s) |
||||||
|
} |
||||||
|
} |
||||||
|
case int64: |
||||||
|
switch d := d.(type) { |
||||||
|
case *int: |
||||||
|
x := int(s) |
||||||
|
if int64(x) != s { |
||||||
|
err = strconv.ErrRange |
||||||
|
x = 0 |
||||||
|
} |
||||||
|
*d = x |
||||||
|
case *bool: |
||||||
|
*d = s != 0 |
||||||
|
case *interface{}: |
||||||
|
*d = s |
||||||
|
case nil: |
||||||
|
// skip value
|
||||||
|
default: |
||||||
|
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr { |
||||||
|
err = cannotConvert(d, s) |
||||||
|
} else { |
||||||
|
err = convertAssignInt(d.Elem(), s) |
||||||
|
} |
||||||
|
} |
||||||
|
case string: |
||||||
|
switch d := d.(type) { |
||||||
|
case *string: |
||||||
|
*d = string(s) |
||||||
|
default: |
||||||
|
err = cannotConvert(reflect.ValueOf(d), s) |
||||||
|
} |
||||||
|
case []interface{}: |
||||||
|
switch d := d.(type) { |
||||||
|
case *[]interface{}: |
||||||
|
*d = s |
||||||
|
case *interface{}: |
||||||
|
*d = s |
||||||
|
case nil: |
||||||
|
// skip value
|
||||||
|
default: |
||||||
|
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr { |
||||||
|
err = cannotConvert(d, s) |
||||||
|
} else { |
||||||
|
err = convertAssignArray(d.Elem(), s) |
||||||
|
} |
||||||
|
} |
||||||
|
case Error: |
||||||
|
err = s |
||||||
|
default: |
||||||
|
err = cannotConvert(reflect.ValueOf(d), s) |
||||||
|
} |
||||||
|
err = pkgerr.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Scan copies from src to the values pointed at by dest.
|
||||||
|
//
|
||||||
|
// The values pointed at by dest must be an integer, float, boolean, string,
|
||||||
|
// []byte, interface{} or slices of these types. Scan uses the standard strconv
|
||||||
|
// package to convert bulk strings to numeric and boolean types.
|
||||||
|
//
|
||||||
|
// If a dest value is nil, then the corresponding src value is skipped.
|
||||||
|
//
|
||||||
|
// If a src element is nil, then the corresponding dest value is not modified.
|
||||||
|
//
|
||||||
|
// To enable easy use of Scan in a loop, Scan returns the slice of src
|
||||||
|
// following the copied values.
|
||||||
|
func Scan(src []interface{}, dest ...interface{}) ([]interface{}, error) { |
||||||
|
if len(src) < len(dest) { |
||||||
|
return nil, pkgerr.New("redigo.Scan: array short") |
||||||
|
} |
||||||
|
var err error |
||||||
|
for i, d := range dest { |
||||||
|
err = convertAssign(d, src[i]) |
||||||
|
if err != nil { |
||||||
|
err = fmt.Errorf("redigo.Scan: cannot assign to dest %d: %v", i, err) |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
return src[len(dest):], err |
||||||
|
} |
||||||
|
|
||||||
|
type fieldSpec struct { |
||||||
|
name string |
||||||
|
index []int |
||||||
|
omitEmpty bool |
||||||
|
} |
||||||
|
|
||||||
|
type structSpec struct { |
||||||
|
m map[string]*fieldSpec |
||||||
|
l []*fieldSpec |
||||||
|
} |
||||||
|
|
||||||
|
func (ss *structSpec) fieldSpec(name []byte) *fieldSpec { |
||||||
|
return ss.m[string(name)] |
||||||
|
} |
||||||
|
|
||||||
|
func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) { |
||||||
|
for i := 0; i < t.NumField(); i++ { |
||||||
|
f := t.Field(i) |
||||||
|
switch { |
||||||
|
case f.PkgPath != "" && !f.Anonymous: |
||||||
|
// Ignore unexported fields.
|
||||||
|
case f.Anonymous: |
||||||
|
// TODO: Handle pointers. Requires change to decoder and
|
||||||
|
// protection against infinite recursion.
|
||||||
|
if f.Type.Kind() == reflect.Struct { |
||||||
|
compileStructSpec(f.Type, depth, append(index, i), ss) |
||||||
|
} |
||||||
|
default: |
||||||
|
fs := &fieldSpec{name: f.Name} |
||||||
|
tag := f.Tag.Get("redis") |
||||||
|
p := strings.Split(tag, ",") |
||||||
|
if len(p) > 0 { |
||||||
|
if p[0] == "-" { |
||||||
|
continue |
||||||
|
} |
||||||
|
if len(p[0]) > 0 { |
||||||
|
fs.name = p[0] |
||||||
|
} |
||||||
|
for _, s := range p[1:] { |
||||||
|
switch s { |
||||||
|
case "omitempty": |
||||||
|
fs.omitEmpty = true |
||||||
|
default: |
||||||
|
panic(fmt.Errorf("redigo: unknown field tag %s for type %s", s, t.Name())) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
d, found := depth[fs.name] |
||||||
|
if !found { |
||||||
|
d = 1 << 30 |
||||||
|
} |
||||||
|
switch { |
||||||
|
case len(index) == d: |
||||||
|
// At same depth, remove from result.
|
||||||
|
delete(ss.m, fs.name) |
||||||
|
j := 0 |
||||||
|
for i1 := 0; i1 < len(ss.l); i1++ { |
||||||
|
if fs.name != ss.l[i1].name { |
||||||
|
ss.l[j] = ss.l[i1] |
||||||
|
j++ |
||||||
|
} |
||||||
|
} |
||||||
|
ss.l = ss.l[:j] |
||||||
|
case len(index) < d: |
||||||
|
fs.index = make([]int, len(index)+1) |
||||||
|
copy(fs.index, index) |
||||||
|
fs.index[len(index)] = i |
||||||
|
depth[fs.name] = len(index) |
||||||
|
ss.m[fs.name] = fs |
||||||
|
ss.l = append(ss.l, fs) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
var ( |
||||||
|
structSpecMutex sync.RWMutex |
||||||
|
structSpecCache = make(map[reflect.Type]*structSpec) |
||||||
|
) |
||||||
|
|
||||||
|
func structSpecForType(t reflect.Type) *structSpec { |
||||||
|
|
||||||
|
structSpecMutex.RLock() |
||||||
|
ss, found := structSpecCache[t] |
||||||
|
structSpecMutex.RUnlock() |
||||||
|
if found { |
||||||
|
return ss |
||||||
|
} |
||||||
|
|
||||||
|
structSpecMutex.Lock() |
||||||
|
defer structSpecMutex.Unlock() |
||||||
|
ss, found = structSpecCache[t] |
||||||
|
if found { |
||||||
|
return ss |
||||||
|
} |
||||||
|
|
||||||
|
ss = &structSpec{m: make(map[string]*fieldSpec)} |
||||||
|
compileStructSpec(t, make(map[string]int), nil, ss) |
||||||
|
structSpecCache[t] = ss |
||||||
|
return ss |
||||||
|
} |
||||||
|
|
||||||
|
var errScanStructValue = errors.New("redigo.ScanStruct: value must be non-nil pointer to a struct") |
||||||
|
|
||||||
|
// ScanStruct scans alternating names and values from src to a struct. The
|
||||||
|
// HGETALL and CONFIG GET commands return replies in this format.
|
||||||
|
//
|
||||||
|
// ScanStruct uses exported field names to match values in the response. Use
|
||||||
|
// 'redis' field tag to override the name:
|
||||||
|
//
|
||||||
|
// Field int `redis:"myName"`
|
||||||
|
//
|
||||||
|
// Fields with the tag redis:"-" are ignored.
|
||||||
|
//
|
||||||
|
// Integer, float, boolean, string and []byte fields are supported. Scan uses the
|
||||||
|
// standard strconv package to convert bulk string values to numeric and
|
||||||
|
// boolean types.
|
||||||
|
//
|
||||||
|
// If a src element is nil, then the corresponding field is not modified.
|
||||||
|
func ScanStruct(src []interface{}, dest interface{}) error { |
||||||
|
d := reflect.ValueOf(dest) |
||||||
|
if d.Kind() != reflect.Ptr || d.IsNil() { |
||||||
|
return pkgerr.WithStack(errScanStructValue) |
||||||
|
} |
||||||
|
d = d.Elem() |
||||||
|
if d.Kind() != reflect.Struct { |
||||||
|
return pkgerr.WithStack(errScanStructValue) |
||||||
|
} |
||||||
|
ss := structSpecForType(d.Type()) |
||||||
|
|
||||||
|
if len(src)%2 != 0 { |
||||||
|
return pkgerr.New("redigo.ScanStruct: number of values not a multiple of 2") |
||||||
|
} |
||||||
|
|
||||||
|
for i := 0; i < len(src); i += 2 { |
||||||
|
s := src[i+1] |
||||||
|
if s == nil { |
||||||
|
continue |
||||||
|
} |
||||||
|
name, ok := src[i].([]byte) |
||||||
|
if !ok { |
||||||
|
return pkgerr.Errorf("redigo.ScanStruct: key %d not a bulk string value", i) |
||||||
|
} |
||||||
|
fs := ss.fieldSpec(name) |
||||||
|
if fs == nil { |
||||||
|
continue |
||||||
|
} |
||||||
|
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil { |
||||||
|
return pkgerr.Errorf("redigo.ScanStruct: cannot assign field %s: %v", fs.name, err) |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
var ( |
||||||
|
errScanSliceValue = errors.New("redigo.ScanSlice: dest must be non-nil pointer to a struct") |
||||||
|
) |
||||||
|
|
||||||
|
// ScanSlice scans src to the slice pointed to by dest. The elements the dest
|
||||||
|
// slice must be integer, float, boolean, string, struct or pointer to struct
|
||||||
|
// values.
|
||||||
|
//
|
||||||
|
// Struct fields must be integer, float, boolean or string values. All struct
|
||||||
|
// fields are used unless a subset is specified using fieldNames.
|
||||||
|
func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error { |
||||||
|
d := reflect.ValueOf(dest) |
||||||
|
if d.Kind() != reflect.Ptr || d.IsNil() { |
||||||
|
return pkgerr.WithStack(errScanSliceValue) |
||||||
|
} |
||||||
|
d = d.Elem() |
||||||
|
if d.Kind() != reflect.Slice { |
||||||
|
return pkgerr.WithStack(errScanSliceValue) |
||||||
|
} |
||||||
|
|
||||||
|
isPtr := false |
||||||
|
t := d.Type().Elem() |
||||||
|
if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct { |
||||||
|
isPtr = true |
||||||
|
t = t.Elem() |
||||||
|
} |
||||||
|
|
||||||
|
if t.Kind() != reflect.Struct { |
||||||
|
ensureLen(d, len(src)) |
||||||
|
for i, s := range src { |
||||||
|
if s == nil { |
||||||
|
continue |
||||||
|
} |
||||||
|
if err := convertAssignValue(d.Index(i), s); err != nil { |
||||||
|
return pkgerr.Errorf("redigo.ScanSlice: cannot assign element %d: %v", i, err) |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
ss := structSpecForType(t) |
||||||
|
fss := ss.l |
||||||
|
if len(fieldNames) > 0 { |
||||||
|
fss = make([]*fieldSpec, len(fieldNames)) |
||||||
|
for i, name := range fieldNames { |
||||||
|
fss[i] = ss.m[name] |
||||||
|
if fss[i] == nil { |
||||||
|
return pkgerr.Errorf("redigo.ScanSlice: ScanSlice bad field name %s", name) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if len(fss) == 0 { |
||||||
|
return pkgerr.New("redigo.ScanSlice: no struct fields") |
||||||
|
} |
||||||
|
|
||||||
|
n := len(src) / len(fss) |
||||||
|
if n*len(fss) != len(src) { |
||||||
|
return pkgerr.New("redigo.ScanSlice: length not a multiple of struct field count") |
||||||
|
} |
||||||
|
|
||||||
|
ensureLen(d, n) |
||||||
|
for i := 0; i < n; i++ { |
||||||
|
d1 := d.Index(i) |
||||||
|
if isPtr { |
||||||
|
if d1.IsNil() { |
||||||
|
d1.Set(reflect.New(t)) |
||||||
|
} |
||||||
|
d1 = d1.Elem() |
||||||
|
} |
||||||
|
for j, fs := range fss { |
||||||
|
s := src[i*len(fss)+j] |
||||||
|
if s == nil { |
||||||
|
continue |
||||||
|
} |
||||||
|
if err := convertAssignValue(d1.FieldByIndex(fs.index), s); err != nil { |
||||||
|
return pkgerr.Errorf("redigo.ScanSlice: cannot assign element %d to field %s: %v", i*len(fss)+j, fs.name, err) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Args is a helper for constructing command arguments from structured values.
|
||||||
|
type Args []interface{} |
||||||
|
|
||||||
|
// Add returns the result of appending value to args.
|
||||||
|
func (args Args) Add(value ...interface{}) Args { |
||||||
|
return append(args, value...) |
||||||
|
} |
||||||
|
|
||||||
|
// AddFlat returns the result of appending the flattened value of v to args.
|
||||||
|
//
|
||||||
|
// Maps are flattened by appending the alternating keys and map values to args.
|
||||||
|
//
|
||||||
|
// Slices are flattened by appending the slice elements to args.
|
||||||
|
//
|
||||||
|
// Structs are flattened by appending the alternating names and values of
|
||||||
|
// exported fields to args. If v is a nil struct pointer, then nothing is
|
||||||
|
// appended. The 'redis' field tag overrides struct field names. See ScanStruct
|
||||||
|
// for more information on the use of the 'redis' field tag.
|
||||||
|
//
|
||||||
|
// Other types are appended to args as is.
|
||||||
|
func (args Args) AddFlat(v interface{}) Args { |
||||||
|
rv := reflect.ValueOf(v) |
||||||
|
switch rv.Kind() { |
||||||
|
case reflect.Struct: |
||||||
|
args = flattenStruct(args, rv) |
||||||
|
case reflect.Slice: |
||||||
|
for i := 0; i < rv.Len(); i++ { |
||||||
|
args = append(args, rv.Index(i).Interface()) |
||||||
|
} |
||||||
|
case reflect.Map: |
||||||
|
for _, k := range rv.MapKeys() { |
||||||
|
args = append(args, k.Interface(), rv.MapIndex(k).Interface()) |
||||||
|
} |
||||||
|
case reflect.Ptr: |
||||||
|
if rv.Type().Elem().Kind() == reflect.Struct { |
||||||
|
if !rv.IsNil() { |
||||||
|
args = flattenStruct(args, rv.Elem()) |
||||||
|
} |
||||||
|
} else { |
||||||
|
args = append(args, v) |
||||||
|
} |
||||||
|
default: |
||||||
|
args = append(args, v) |
||||||
|
} |
||||||
|
return args |
||||||
|
} |
||||||
|
|
||||||
|
func flattenStruct(args Args, v reflect.Value) Args { |
||||||
|
ss := structSpecForType(v.Type()) |
||||||
|
for _, fs := range ss.l { |
||||||
|
fv := v.FieldByIndex(fs.index) |
||||||
|
if fs.omitEmpty { |
||||||
|
var empty = false |
||||||
|
switch fv.Kind() { |
||||||
|
case reflect.Array, reflect.Map, reflect.Slice, reflect.String: |
||||||
|
empty = fv.Len() == 0 |
||||||
|
case reflect.Bool: |
||||||
|
empty = !fv.Bool() |
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
||||||
|
empty = fv.Int() == 0 |
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: |
||||||
|
empty = fv.Uint() == 0 |
||||||
|
case reflect.Float32, reflect.Float64: |
||||||
|
empty = fv.Float() == 0 |
||||||
|
case reflect.Interface, reflect.Ptr: |
||||||
|
empty = fv.IsNil() |
||||||
|
} |
||||||
|
if empty { |
||||||
|
continue |
||||||
|
} |
||||||
|
} |
||||||
|
args = append(args, fs.name, fv.Interface()) |
||||||
|
} |
||||||
|
return args |
||||||
|
} |
@ -0,0 +1,86 @@ |
|||||||
|
// Copyright 2012 Gary Burd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations
|
||||||
|
// under the License.
|
||||||
|
|
||||||
|
package redis |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/sha1" |
||||||
|
"encoding/hex" |
||||||
|
"io" |
||||||
|
"strings" |
||||||
|
) |
||||||
|
|
||||||
|
// Script encapsulates the source, hash and key count for a Lua script. See
|
||||||
|
// http://redis.io/commands/eval for information on scripts in Redis.
|
||||||
|
type Script struct { |
||||||
|
keyCount int |
||||||
|
src string |
||||||
|
hash string |
||||||
|
} |
||||||
|
|
||||||
|
// NewScript returns a new script object. If keyCount is greater than or equal
|
||||||
|
// to zero, then the count is automatically inserted in the EVAL command
|
||||||
|
// argument list. If keyCount is less than zero, then the application supplies
|
||||||
|
// the count as the first value in the keysAndArgs argument to the Do, Send and
|
||||||
|
// SendHash methods.
|
||||||
|
func NewScript(keyCount int, src string) *Script { |
||||||
|
h := sha1.New() |
||||||
|
io.WriteString(h, src) |
||||||
|
return &Script{keyCount, src, hex.EncodeToString(h.Sum(nil))} |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Script) args(spec string, keysAndArgs []interface{}) []interface{} { |
||||||
|
var args []interface{} |
||||||
|
if s.keyCount < 0 { |
||||||
|
args = make([]interface{}, 1+len(keysAndArgs)) |
||||||
|
args[0] = spec |
||||||
|
copy(args[1:], keysAndArgs) |
||||||
|
} else { |
||||||
|
args = make([]interface{}, 2+len(keysAndArgs)) |
||||||
|
args[0] = spec |
||||||
|
args[1] = s.keyCount |
||||||
|
copy(args[2:], keysAndArgs) |
||||||
|
} |
||||||
|
return args |
||||||
|
} |
||||||
|
|
||||||
|
// Do evaluates the script. Under the covers, Do optimistically evaluates the
|
||||||
|
// script using the EVALSHA command. If the command fails because the script is
|
||||||
|
// not loaded, then Do evaluates the script using the EVAL command (thus
|
||||||
|
// causing the script to load).
|
||||||
|
func (s *Script) Do(c Conn, keysAndArgs ...interface{}) (interface{}, error) { |
||||||
|
v, err := c.Do("EVALSHA", s.args(s.hash, keysAndArgs)...) |
||||||
|
if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") { |
||||||
|
v, err = c.Do("EVAL", s.args(s.src, keysAndArgs)...) |
||||||
|
} |
||||||
|
return v, err |
||||||
|
} |
||||||
|
|
||||||
|
// SendHash evaluates the script without waiting for the reply. The script is
|
||||||
|
// evaluated with the EVALSHA command. The application must ensure that the
|
||||||
|
// script is loaded by a previous call to Send, Do or Load methods.
|
||||||
|
func (s *Script) SendHash(c Conn, keysAndArgs ...interface{}) error { |
||||||
|
return c.Send("EVALSHA", s.args(s.hash, keysAndArgs)...) |
||||||
|
} |
||||||
|
|
||||||
|
// Send evaluates the script without waiting for the reply.
|
||||||
|
func (s *Script) Send(c Conn, keysAndArgs ...interface{}) error { |
||||||
|
return c.Send("EVAL", s.args(s.src, keysAndArgs)...) |
||||||
|
} |
||||||
|
|
||||||
|
// Load loads the script without evaluating it.
|
||||||
|
func (s *Script) Load(c Conn) error { |
||||||
|
_, err := c.Do("SCRIPT", "LOAD", s.src) |
||||||
|
return err |
||||||
|
} |
@ -0,0 +1,142 @@ |
|||||||
|
package redis |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
"github.com/bilibili/kratos/pkg/net/trace" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
_traceComponentName = "pkg/cache/redis" |
||||||
|
_tracePeerService = "redis" |
||||||
|
_traceSpanKind = "client" |
||||||
|
_slowLogDuration = time.Millisecond * 250 |
||||||
|
) |
||||||
|
|
||||||
|
var _internalTags = []trace.Tag{ |
||||||
|
trace.TagString(trace.TagSpanKind, _traceSpanKind), |
||||||
|
trace.TagString(trace.TagComponent, _traceComponentName), |
||||||
|
trace.TagString(trace.TagPeerService, _tracePeerService), |
||||||
|
} |
||||||
|
|
||||||
|
type traceConn struct { |
||||||
|
// tr for pipeline, if tr != nil meaning on pipeline
|
||||||
|
tr trace.Trace |
||||||
|
ctx context.Context |
||||||
|
// connTag include e.g. ip,port
|
||||||
|
connTags []trace.Tag |
||||||
|
|
||||||
|
// origin redis conn
|
||||||
|
Conn |
||||||
|
pending int |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) { |
||||||
|
statement := getStatement(commandName, args...) |
||||||
|
defer slowLog(statement, time.Now()) |
||||||
|
root, ok := trace.FromContext(t.ctx) |
||||||
|
// NOTE: ignored empty commandName
|
||||||
|
// current sdk will Do empty command after pipeline finished
|
||||||
|
if !ok || commandName == "" { |
||||||
|
return t.Conn.Do(commandName, args...) |
||||||
|
} |
||||||
|
tr := root.Fork("", "Redis:"+commandName) |
||||||
|
tr.SetTag(_internalTags...) |
||||||
|
tr.SetTag(t.connTags...) |
||||||
|
tr.SetTag(trace.TagString(trace.TagDBStatement, statement)) |
||||||
|
reply, err = t.Conn.Do(commandName, args...) |
||||||
|
tr.Finish(&err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) Send(commandName string, args ...interface{}) error { |
||||||
|
statement := getStatement(commandName, args...) |
||||||
|
defer slowLog(statement, time.Now()) |
||||||
|
t.pending++ |
||||||
|
root, ok := trace.FromContext(t.ctx) |
||||||
|
if !ok { |
||||||
|
return t.Conn.Send(commandName, args...) |
||||||
|
} |
||||||
|
if t.tr == nil { |
||||||
|
t.tr = root.Fork("", "Redis:Pipeline") |
||||||
|
t.tr.SetTag(_internalTags...) |
||||||
|
t.tr.SetTag(t.connTags...) |
||||||
|
} |
||||||
|
t.tr.SetLog( |
||||||
|
trace.Log(trace.LogEvent, "Send"), |
||||||
|
trace.Log("db.statement", statement), |
||||||
|
) |
||||||
|
err := t.Conn.Send(commandName, args...) |
||||||
|
if err != nil { |
||||||
|
t.tr.SetTag(trace.TagBool(trace.TagError, true)) |
||||||
|
t.tr.SetLog( |
||||||
|
trace.Log(trace.LogEvent, "Send Fail"), |
||||||
|
trace.Log(trace.LogMessage, err.Error()), |
||||||
|
) |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) Flush() error { |
||||||
|
defer slowLog("Flush", time.Now()) |
||||||
|
if t.tr == nil { |
||||||
|
return t.Conn.Flush() |
||||||
|
} |
||||||
|
t.tr.SetLog(trace.Log(trace.LogEvent, "Flush")) |
||||||
|
err := t.Conn.Flush() |
||||||
|
if err != nil { |
||||||
|
t.tr.SetTag(trace.TagBool(trace.TagError, true)) |
||||||
|
t.tr.SetLog( |
||||||
|
trace.Log(trace.LogEvent, "Flush Fail"), |
||||||
|
trace.Log(trace.LogMessage, err.Error()), |
||||||
|
) |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) Receive() (reply interface{}, err error) { |
||||||
|
defer slowLog("Receive", time.Now()) |
||||||
|
if t.tr == nil { |
||||||
|
return t.Conn.Receive() |
||||||
|
} |
||||||
|
t.tr.SetLog(trace.Log(trace.LogEvent, "Receive")) |
||||||
|
reply, err = t.Conn.Receive() |
||||||
|
if err != nil { |
||||||
|
t.tr.SetTag(trace.TagBool(trace.TagError, true)) |
||||||
|
t.tr.SetLog( |
||||||
|
trace.Log(trace.LogEvent, "Receive Fail"), |
||||||
|
trace.Log(trace.LogMessage, err.Error()), |
||||||
|
) |
||||||
|
} |
||||||
|
if t.pending > 0 { |
||||||
|
t.pending-- |
||||||
|
} |
||||||
|
if t.pending == 0 { |
||||||
|
t.tr.Finish(nil) |
||||||
|
t.tr = nil |
||||||
|
} |
||||||
|
return reply, err |
||||||
|
} |
||||||
|
|
||||||
|
func (t *traceConn) WithContext(ctx context.Context) Conn { |
||||||
|
t.ctx = ctx |
||||||
|
return t |
||||||
|
} |
||||||
|
|
||||||
|
func slowLog(statement string, now time.Time) { |
||||||
|
du := time.Since(now) |
||||||
|
if du > _slowLogDuration { |
||||||
|
log.Warn("%s slow log statement: %s time: %v", _tracePeerService, statement, du) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func getStatement(commandName string, args ...interface{}) (res string) { |
||||||
|
res = commandName |
||||||
|
if len(args) > 0 { |
||||||
|
res = fmt.Sprintf("%s %v", commandName, args[0]) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,40 @@ |
|||||||
|
### database/hbase |
||||||
|
|
||||||
|
### 项目简介 |
||||||
|
|
||||||
|
Hbase Client,进行封装加入了链路追踪和统计。 |
||||||
|
|
||||||
|
### usage |
||||||
|
```go |
||||||
|
package main |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/database/hbase" |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
config := &hbase.Config{Zookeeper: &hbase.ZKConfig{Addrs: []string{"localhost"}}} |
||||||
|
client := hbase.NewClient(config) |
||||||
|
|
||||||
|
values := map[string]map[string][]byte{"name": {"firstname": []byte("hello"), "lastname": []byte("world")}} |
||||||
|
ctx := context.Background() |
||||||
|
|
||||||
|
_, err := client.PutStr(ctx, "user", "user1", values) |
||||||
|
if err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
|
||||||
|
result, err := client.GetStr(ctx, "user", "user1") |
||||||
|
if err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
fmt.Printf("%v", result) |
||||||
|
} |
||||||
|
``` |
||||||
|
|
||||||
|
##### 依赖包 |
||||||
|
|
||||||
|
1.[gohbase](https://github.com/tsuna/gohbase) |
@ -0,0 +1,23 @@ |
|||||||
|
package hbase |
||||||
|
|
||||||
|
import ( |
||||||
|
xtime "github.com/bilibili/kratos/pkg/time" |
||||||
|
) |
||||||
|
|
||||||
|
// ZKConfig Server&Client settings.
|
||||||
|
type ZKConfig struct { |
||||||
|
Root string |
||||||
|
Addrs []string |
||||||
|
Timeout xtime.Duration |
||||||
|
} |
||||||
|
|
||||||
|
// Config hbase config
|
||||||
|
type Config struct { |
||||||
|
Zookeeper *ZKConfig |
||||||
|
RPCQueueSize int |
||||||
|
FlushInterval xtime.Duration |
||||||
|
EffectiveUser string |
||||||
|
RegionLookupTimeout xtime.Duration |
||||||
|
RegionReadTimeout xtime.Duration |
||||||
|
TestRowKey string |
||||||
|
} |
@ -0,0 +1,297 @@ |
|||||||
|
package hbase |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"io" |
||||||
|
"strings" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/tsuna/gohbase" |
||||||
|
"github.com/tsuna/gohbase/hrpc" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
) |
||||||
|
|
||||||
|
// HookFunc hook function call before every method and hook return function will call after finish.
|
||||||
|
type HookFunc func(ctx context.Context, call hrpc.Call, customName string) func(err error) |
||||||
|
|
||||||
|
// Client hbase client.
|
||||||
|
type Client struct { |
||||||
|
hc gohbase.Client |
||||||
|
addr string |
||||||
|
config *Config |
||||||
|
hooks []HookFunc |
||||||
|
} |
||||||
|
|
||||||
|
// AddHook add hook function.
|
||||||
|
func (c *Client) AddHook(hookFn HookFunc) { |
||||||
|
c.hooks = append(c.hooks, hookFn) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Client) invokeHook(ctx context.Context, call hrpc.Call, customName string) func(error) { |
||||||
|
finishHooks := make([]func(error), 0, len(c.hooks)) |
||||||
|
for _, fn := range c.hooks { |
||||||
|
finishHooks = append(finishHooks, fn(ctx, call, customName)) |
||||||
|
} |
||||||
|
return func(err error) { |
||||||
|
for _, fn := range finishHooks { |
||||||
|
fn(err) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// NewClient new a hbase client.
|
||||||
|
func NewClient(config *Config, options ...gohbase.Option) *Client { |
||||||
|
rawcli := NewRawClient(config, options...) |
||||||
|
rawcli.AddHook(NewSlowLogHook(250 * time.Millisecond)) |
||||||
|
rawcli.AddHook(MetricsHook(nil)) |
||||||
|
rawcli.AddHook(TraceHook("database/hbase", strings.Join(config.Zookeeper.Addrs, ","))) |
||||||
|
return rawcli |
||||||
|
} |
||||||
|
|
||||||
|
// NewRawClient new a hbase client without prometheus metrics and dapper trace hook.
|
||||||
|
func NewRawClient(config *Config, options ...gohbase.Option) *Client { |
||||||
|
zkquorum := strings.Join(config.Zookeeper.Addrs, ",") |
||||||
|
if config.Zookeeper.Root != "" { |
||||||
|
options = append(options, gohbase.ZookeeperRoot(config.Zookeeper.Root)) |
||||||
|
} |
||||||
|
if config.Zookeeper.Timeout != 0 { |
||||||
|
options = append(options, gohbase.ZookeeperTimeout(time.Duration(config.Zookeeper.Timeout))) |
||||||
|
} |
||||||
|
|
||||||
|
if config.RPCQueueSize != 0 { |
||||||
|
log.Warn("RPCQueueSize configuration be ignored") |
||||||
|
} |
||||||
|
// force RpcQueueSize = 1, don't change it !!! it has reason (゜-゜)つロ
|
||||||
|
options = append(options, gohbase.RpcQueueSize(1)) |
||||||
|
|
||||||
|
if config.FlushInterval != 0 { |
||||||
|
options = append(options, gohbase.FlushInterval(time.Duration(config.FlushInterval))) |
||||||
|
} |
||||||
|
if config.EffectiveUser != "" { |
||||||
|
options = append(options, gohbase.EffectiveUser(config.EffectiveUser)) |
||||||
|
} |
||||||
|
if config.RegionLookupTimeout != 0 { |
||||||
|
options = append(options, gohbase.RegionLookupTimeout(time.Duration(config.RegionLookupTimeout))) |
||||||
|
} |
||||||
|
if config.RegionReadTimeout != 0 { |
||||||
|
options = append(options, gohbase.RegionReadTimeout(time.Duration(config.RegionReadTimeout))) |
||||||
|
} |
||||||
|
hc := gohbase.NewClient(zkquorum, options...) |
||||||
|
return &Client{ |
||||||
|
hc: hc, |
||||||
|
addr: zkquorum, |
||||||
|
config: config, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// ScanAll do scan command and return all result
|
||||||
|
// NOTE: if err != nil the results is safe for range operate even not result found
|
||||||
|
func (c *Client) ScanAll(ctx context.Context, table []byte, options ...func(hrpc.Call) error) (results []*hrpc.Result, err error) { |
||||||
|
cursor, err := c.Scan(ctx, table, options...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
for { |
||||||
|
result, err := cursor.Next() |
||||||
|
if err != nil { |
||||||
|
if err == io.EOF { |
||||||
|
break |
||||||
|
} |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
results = append(results, result) |
||||||
|
} |
||||||
|
return results, nil |
||||||
|
} |
||||||
|
|
||||||
|
type scanTrace struct { |
||||||
|
hrpc.Scanner |
||||||
|
finishHook func(error) |
||||||
|
} |
||||||
|
|
||||||
|
func (s *scanTrace) Next() (*hrpc.Result, error) { |
||||||
|
result, err := s.Scanner.Next() |
||||||
|
if err != nil { |
||||||
|
s.finishHook(err) |
||||||
|
} |
||||||
|
return result, err |
||||||
|
} |
||||||
|
|
||||||
|
func (s *scanTrace) Close() error { |
||||||
|
err := s.Scanner.Close() |
||||||
|
s.finishHook(err) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// Scan do a scan command.
|
||||||
|
func (c *Client) Scan(ctx context.Context, table []byte, options ...func(hrpc.Call) error) (scanner hrpc.Scanner, err error) { |
||||||
|
var scan *hrpc.Scan |
||||||
|
scan, err = hrpc.NewScan(ctx, table, options...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
st := &scanTrace{} |
||||||
|
st.finishHook = c.invokeHook(ctx, scan, "Scan") |
||||||
|
st.Scanner = c.hc.Scan(scan) |
||||||
|
return st, nil |
||||||
|
} |
||||||
|
|
||||||
|
// ScanStr scan string
|
||||||
|
func (c *Client) ScanStr(ctx context.Context, table string, options ...func(hrpc.Call) error) (hrpc.Scanner, error) { |
||||||
|
return c.Scan(ctx, []byte(table), options...) |
||||||
|
} |
||||||
|
|
||||||
|
// ScanStrAll scan string
|
||||||
|
// NOTE: if err != nil the results is safe for range operate even not result found
|
||||||
|
func (c *Client) ScanStrAll(ctx context.Context, table string, options ...func(hrpc.Call) error) ([]*hrpc.Result, error) { |
||||||
|
return c.ScanAll(ctx, []byte(table), options...) |
||||||
|
} |
||||||
|
|
||||||
|
// ScanRange get a scanner for the given table and key range.
|
||||||
|
// The range is half-open, i.e. [startRow; stopRow[ -- stopRow is not
|
||||||
|
// included in the range.
|
||||||
|
func (c *Client) ScanRange(ctx context.Context, table, startRow, stopRow []byte, options ...func(hrpc.Call) error) (scanner hrpc.Scanner, err error) { |
||||||
|
var scan *hrpc.Scan |
||||||
|
scan, err = hrpc.NewScanRange(ctx, table, startRow, stopRow, options...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
st := &scanTrace{} |
||||||
|
st.finishHook = c.invokeHook(ctx, scan, "ScanRange") |
||||||
|
st.Scanner = c.hc.Scan(scan) |
||||||
|
return st, nil |
||||||
|
} |
||||||
|
|
||||||
|
// ScanRangeStr get a scanner for the given table and key range.
|
||||||
|
// The range is half-open, i.e. [startRow; stopRow[ -- stopRow is not
|
||||||
|
// included in the range.
|
||||||
|
func (c *Client) ScanRangeStr(ctx context.Context, table, startRow, stopRow string, options ...func(hrpc.Call) error) (hrpc.Scanner, error) { |
||||||
|
return c.ScanRange(ctx, []byte(table), []byte(startRow), []byte(stopRow), options...) |
||||||
|
} |
||||||
|
|
||||||
|
// Get get result for the given table and row key.
|
||||||
|
// NOTE: if err != nil then result != nil, if result not exists result.Cells length is 0
|
||||||
|
func (c *Client) Get(ctx context.Context, table, key []byte, options ...func(hrpc.Call) error) (result *hrpc.Result, err error) { |
||||||
|
var get *hrpc.Get |
||||||
|
get, err = hrpc.NewGet(ctx, table, key, options...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
finishHook := c.invokeHook(ctx, get, "GET") |
||||||
|
result, err = c.hc.Get(get) |
||||||
|
finishHook(err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// GetStr do a get command.
|
||||||
|
// NOTE: if err != nil then result != nil, if result not exists result.Cells length is 0
|
||||||
|
func (c *Client) GetStr(ctx context.Context, table, key string, options ...func(hrpc.Call) error) (result *hrpc.Result, err error) { |
||||||
|
return c.Get(ctx, []byte(table), []byte(key), options...) |
||||||
|
} |
||||||
|
|
||||||
|
// PutStr insert the given family-column-values in the given row key of the given table.
|
||||||
|
func (c *Client) PutStr(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) { |
||||||
|
put, err := hrpc.NewPutStr(ctx, table, key, values, options...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
finishHook := c.invokeHook(ctx, put, "PUT") |
||||||
|
result, err := c.hc.Put(put) |
||||||
|
finishHook(err) |
||||||
|
return result, err |
||||||
|
} |
||||||
|
|
||||||
|
// Delete is used to perform Delete operations on a single row.
|
||||||
|
// To delete entire row, values should be nil.
|
||||||
|
//
|
||||||
|
// To delete specific families, qualifiers map should be nil:
|
||||||
|
// map[string]map[string][]byte{
|
||||||
|
// "cf1": nil,
|
||||||
|
// "cf2": nil,
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// To delete specific qualifiers:
|
||||||
|
// map[string]map[string][]byte{
|
||||||
|
// "cf": map[string][]byte{
|
||||||
|
// "q1": nil,
|
||||||
|
// "q2": nil,
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// To delete all versions before and at a timestamp, pass hrpc.Timestamp() option.
|
||||||
|
// By default all versions will be removed.
|
||||||
|
//
|
||||||
|
// To delete only a specific version at a timestamp, pass hrpc.DeleteOneVersion() option
|
||||||
|
// along with a timestamp. For delete specific qualifiers request, if timestamp is not
|
||||||
|
// passed, only the latest version will be removed. For delete specific families request,
|
||||||
|
// the timestamp should be passed or it will have no effect as it's an expensive
|
||||||
|
// operation to perform.
|
||||||
|
func (c *Client) Delete(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) { |
||||||
|
del, err := hrpc.NewDelStr(ctx, table, key, values, options...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
finishHook := c.invokeHook(ctx, del, "Delete") |
||||||
|
result, err := c.hc.Delete(del) |
||||||
|
finishHook(err) |
||||||
|
return result, err |
||||||
|
} |
||||||
|
|
||||||
|
// Append do a append command.
|
||||||
|
func (c *Client) Append(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) { |
||||||
|
appd, err := hrpc.NewAppStr(ctx, table, key, values, options...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
finishHook := c.invokeHook(ctx, appd, "Append") |
||||||
|
result, err := c.hc.Append(appd) |
||||||
|
finishHook(err) |
||||||
|
return result, err |
||||||
|
} |
||||||
|
|
||||||
|
// Increment the given values in HBase under the given table and key.
|
||||||
|
func (c *Client) Increment(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (int64, error) { |
||||||
|
increment, err := hrpc.NewIncStr(ctx, table, key, values, options...) |
||||||
|
if err != nil { |
||||||
|
return 0, err |
||||||
|
} |
||||||
|
finishHook := c.invokeHook(ctx, increment, "Increment") |
||||||
|
result, err := c.hc.Increment(increment) |
||||||
|
finishHook(err) |
||||||
|
return result, err |
||||||
|
} |
||||||
|
|
||||||
|
// IncrementSingle increment the given value by amount in HBase under the given table, key, family and qualifier.
|
||||||
|
func (c *Client) IncrementSingle(ctx context.Context, table string, key string, family string, qualifier string, amount int64, options ...func(hrpc.Call) error) (int64, error) { |
||||||
|
increment, err := hrpc.NewIncStrSingle(ctx, table, key, family, qualifier, amount, options...) |
||||||
|
if err != nil { |
||||||
|
return 0, err |
||||||
|
} |
||||||
|
|
||||||
|
finishHook := c.invokeHook(ctx, increment, "IncrementSingle") |
||||||
|
result, err := c.hc.Increment(increment) |
||||||
|
finishHook(err) |
||||||
|
return result, err |
||||||
|
} |
||||||
|
|
||||||
|
// Ping ping.
|
||||||
|
func (c *Client) Ping(ctx context.Context) (err error) { |
||||||
|
testRowKey := "test" |
||||||
|
if c.config.TestRowKey != "" { |
||||||
|
testRowKey = c.config.TestRowKey |
||||||
|
} |
||||||
|
values := map[string]map[string][]byte{"test": map[string][]byte{"test": []byte("test")}} |
||||||
|
_, err = c.PutStr(ctx, "test", testRowKey, values) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Close close client.
|
||||||
|
func (c *Client) Close() error { |
||||||
|
c.hc.Close() |
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,48 @@ |
|||||||
|
package hbase |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"io" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/tsuna/gohbase" |
||||||
|
"github.com/tsuna/gohbase/hrpc" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/stat" |
||||||
|
) |
||||||
|
|
||||||
|
func codeFromErr(err error) string { |
||||||
|
code := "unknown_error" |
||||||
|
switch err { |
||||||
|
case gohbase.ErrClientClosed: |
||||||
|
code = "client_closed" |
||||||
|
case gohbase.ErrConnotFindRegion: |
||||||
|
code = "connot_find_region" |
||||||
|
case gohbase.TableNotFound: |
||||||
|
code = "table_not_found" |
||||||
|
case gohbase.ErrRegionUnavailable: |
||||||
|
code = "region_unavailable" |
||||||
|
} |
||||||
|
return code |
||||||
|
} |
||||||
|
|
||||||
|
// MetricsHook if stats is nil use stat.DB as default.
|
||||||
|
func MetricsHook(stats stat.Stat) HookFunc { |
||||||
|
if stats == nil { |
||||||
|
stats = stat.DB |
||||||
|
} |
||||||
|
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) { |
||||||
|
now := time.Now() |
||||||
|
if customName == "" { |
||||||
|
customName = call.Name() |
||||||
|
} |
||||||
|
method := "hbase:" + customName |
||||||
|
return func(err error) { |
||||||
|
durationMs := int64(time.Since(now) / time.Millisecond) |
||||||
|
stats.Timing(method, durationMs) |
||||||
|
if err != nil && err != io.EOF { |
||||||
|
stats.Incr(method, codeFromErr(err)) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,24 @@ |
|||||||
|
package hbase |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/tsuna/gohbase/hrpc" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
) |
||||||
|
|
||||||
|
// NewSlowLogHook log slow operation.
|
||||||
|
func NewSlowLogHook(threshold time.Duration) HookFunc { |
||||||
|
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) { |
||||||
|
start := time.Now() |
||||||
|
return func(error) { |
||||||
|
duration := time.Since(start) |
||||||
|
if duration < threshold { |
||||||
|
return |
||||||
|
} |
||||||
|
log.Warn("hbase slow log: %s %s %s time: %s", customName, call.Table(), call.Key(), duration) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,40 @@ |
|||||||
|
package hbase |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"io" |
||||||
|
|
||||||
|
"github.com/tsuna/gohbase/hrpc" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/net/trace" |
||||||
|
) |
||||||
|
|
||||||
|
// TraceHook create new hbase trace hook.
|
||||||
|
func TraceHook(component, instance string) HookFunc { |
||||||
|
var internalTags []trace.Tag |
||||||
|
internalTags = append(internalTags, trace.TagString(trace.TagComponent, component)) |
||||||
|
internalTags = append(internalTags, trace.TagString(trace.TagDBInstance, instance)) |
||||||
|
internalTags = append(internalTags, trace.TagString(trace.TagPeerService, "hbase")) |
||||||
|
internalTags = append(internalTags, trace.TagString(trace.TagSpanKind, "client")) |
||||||
|
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) { |
||||||
|
noop := func(error) {} |
||||||
|
root, ok := trace.FromContext(ctx) |
||||||
|
if !ok { |
||||||
|
return noop |
||||||
|
} |
||||||
|
if customName == "" { |
||||||
|
customName = call.Name() |
||||||
|
} |
||||||
|
span := root.Fork("", "Hbase:"+customName) |
||||||
|
span.SetTag(internalTags...) |
||||||
|
statement := string(call.Table()) + " " + string(call.Key()) |
||||||
|
span.SetTag(trace.TagString(trace.TagDBStatement, statement)) |
||||||
|
return func(err error) { |
||||||
|
if err == io.EOF { |
||||||
|
// reset error for trace.
|
||||||
|
err = nil |
||||||
|
} |
||||||
|
span.Finish(&err) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,9 @@ |
|||||||
|
#### database/sql |
||||||
|
|
||||||
|
##### 项目简介 |
||||||
|
MySQL数据库驱动,进行封装加入了链路追踪和统计。 |
||||||
|
|
||||||
|
如果需要SQL级别的超时管理 可以在业务代码里面使用context.WithDeadline实现 推荐超时配置放到application.toml里面 方便热加载 |
||||||
|
|
||||||
|
##### 依赖包 |
||||||
|
1. [Go-MySQL-Driver](https://github.com/go-sql-driver/mysql) |
@ -0,0 +1,40 @@ |
|||||||
|
package sql |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||||
|
"github.com/bilibili/kratos/pkg/stat" |
||||||
|
"github.com/bilibili/kratos/pkg/time" |
||||||
|
|
||||||
|
// database driver
|
||||||
|
_ "github.com/go-sql-driver/mysql" |
||||||
|
) |
||||||
|
|
||||||
|
var stats = stat.DB |
||||||
|
|
||||||
|
// Config mysql config.
|
||||||
|
type Config struct { |
||||||
|
Addr string // for trace
|
||||||
|
DSN string // write data source name.
|
||||||
|
ReadDSN []string // read data source name.
|
||||||
|
Active int // pool
|
||||||
|
Idle int // pool
|
||||||
|
IdleTimeout time.Duration // connect max life time.
|
||||||
|
QueryTimeout time.Duration // query sql timeout
|
||||||
|
ExecTimeout time.Duration // execute sql timeout
|
||||||
|
TranTimeout time.Duration // transaction sql timeout
|
||||||
|
Breaker *breaker.Config // breaker
|
||||||
|
} |
||||||
|
|
||||||
|
// NewMySQL new db and retry connection when has error.
|
||||||
|
func NewMySQL(c *Config) (db *DB) { |
||||||
|
if c.QueryTimeout == 0 || c.ExecTimeout == 0 || c.TranTimeout == 0 { |
||||||
|
panic("mysql must be set query/execute/transction timeout") |
||||||
|
} |
||||||
|
db, err := Open(c) |
||||||
|
if err != nil { |
||||||
|
log.Error("open mysql error(%v)", err) |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,678 @@ |
|||||||
|
package sql |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"database/sql" |
||||||
|
"fmt" |
||||||
|
"strings" |
||||||
|
"sync/atomic" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/ecode" |
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||||
|
"github.com/bilibili/kratos/pkg/net/trace" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
_family = "sql_client" |
||||||
|
_slowLogDuration = time.Millisecond * 250 |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
// ErrStmtNil prepared stmt error
|
||||||
|
ErrStmtNil = errors.New("sql: prepare failed and stmt nil") |
||||||
|
// ErrNoMaster is returned by Master when call master multiple times.
|
||||||
|
ErrNoMaster = errors.New("sql: no master instance") |
||||||
|
// ErrNoRows is returned by Scan when QueryRow doesn't return a row.
|
||||||
|
// In such a case, QueryRow returns a placeholder *Row value that defers
|
||||||
|
// this error until a Scan.
|
||||||
|
ErrNoRows = sql.ErrNoRows |
||||||
|
// ErrTxDone transaction done.
|
||||||
|
ErrTxDone = sql.ErrTxDone |
||||||
|
) |
||||||
|
|
||||||
|
// DB database.
|
||||||
|
type DB struct { |
||||||
|
write *conn |
||||||
|
read []*conn |
||||||
|
idx int64 |
||||||
|
master *DB |
||||||
|
} |
||||||
|
|
||||||
|
// conn database connection
|
||||||
|
type conn struct { |
||||||
|
*sql.DB |
||||||
|
breaker breaker.Breaker |
||||||
|
conf *Config |
||||||
|
} |
||||||
|
|
||||||
|
// Tx transaction.
|
||||||
|
type Tx struct { |
||||||
|
db *conn |
||||||
|
tx *sql.Tx |
||||||
|
t trace.Trace |
||||||
|
c context.Context |
||||||
|
cancel func() |
||||||
|
} |
||||||
|
|
||||||
|
// Row row.
|
||||||
|
type Row struct { |
||||||
|
err error |
||||||
|
*sql.Row |
||||||
|
db *conn |
||||||
|
query string |
||||||
|
args []interface{} |
||||||
|
t trace.Trace |
||||||
|
cancel func() |
||||||
|
} |
||||||
|
|
||||||
|
// Scan copies the columns from the matched row into the values pointed at by dest.
|
||||||
|
func (r *Row) Scan(dest ...interface{}) (err error) { |
||||||
|
defer slowLog(fmt.Sprintf("Scan query(%s) args(%+v)", r.query, r.args), time.Now()) |
||||||
|
if r.t != nil { |
||||||
|
defer r.t.Finish(&err) |
||||||
|
} |
||||||
|
if r.err != nil { |
||||||
|
err = r.err |
||||||
|
} else if r.Row == nil { |
||||||
|
err = ErrStmtNil |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
err = r.Row.Scan(dest...) |
||||||
|
if r.cancel != nil { |
||||||
|
r.cancel() |
||||||
|
} |
||||||
|
r.db.onBreaker(&err) |
||||||
|
if err != ErrNoRows { |
||||||
|
err = errors.Wrapf(err, "query %s args %+v", r.query, r.args) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Rows rows.
|
||||||
|
type Rows struct { |
||||||
|
*sql.Rows |
||||||
|
cancel func() |
||||||
|
} |
||||||
|
|
||||||
|
// Close closes the Rows, preventing further enumeration. If Next is called
|
||||||
|
// and returns false and there are no further result sets,
|
||||||
|
// the Rows are closed automatically and it will suffice to check the
|
||||||
|
// result of Err. Close is idempotent and does not affect the result of Err.
|
||||||
|
func (rs *Rows) Close() (err error) { |
||||||
|
err = errors.WithStack(rs.Rows.Close()) |
||||||
|
if rs.cancel != nil { |
||||||
|
rs.cancel() |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Stmt prepared stmt.
|
||||||
|
type Stmt struct { |
||||||
|
db *conn |
||||||
|
tx bool |
||||||
|
query string |
||||||
|
stmt atomic.Value |
||||||
|
t trace.Trace |
||||||
|
} |
||||||
|
|
||||||
|
// Open opens a database specified by its database driver name and a
|
||||||
|
// driver-specific data source name, usually consisting of at least a database
|
||||||
|
// name and connection information.
|
||||||
|
func Open(c *Config) (*DB, error) { |
||||||
|
db := new(DB) |
||||||
|
d, err := connect(c, c.DSN) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
brkGroup := breaker.NewGroup(c.Breaker) |
||||||
|
brk := brkGroup.Get(c.Addr) |
||||||
|
w := &conn{DB: d, breaker: brk, conf: c} |
||||||
|
rs := make([]*conn, 0, len(c.ReadDSN)) |
||||||
|
for _, rd := range c.ReadDSN { |
||||||
|
d, err := connect(c, rd) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
brk := brkGroup.Get(parseDSNAddr(rd)) |
||||||
|
r := &conn{DB: d, breaker: brk, conf: c} |
||||||
|
rs = append(rs, r) |
||||||
|
} |
||||||
|
db.write = w |
||||||
|
db.read = rs |
||||||
|
db.master = &DB{write: db.write} |
||||||
|
return db, nil |
||||||
|
} |
||||||
|
|
||||||
|
func connect(c *Config, dataSourceName string) (*sql.DB, error) { |
||||||
|
d, err := sql.Open("mysql", dataSourceName) |
||||||
|
if err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
d.SetMaxOpenConns(c.Active) |
||||||
|
d.SetMaxIdleConns(c.Idle) |
||||||
|
d.SetConnMaxLifetime(time.Duration(c.IdleTimeout)) |
||||||
|
return d, nil |
||||||
|
} |
||||||
|
|
||||||
|
// Begin starts a transaction. The isolation level is dependent on the driver.
|
||||||
|
func (db *DB) Begin(c context.Context) (tx *Tx, err error) { |
||||||
|
return db.write.begin(c) |
||||||
|
} |
||||||
|
|
||||||
|
// Exec executes a query without returning any rows.
|
||||||
|
// The args are for any placeholder parameters in the query.
|
||||||
|
func (db *DB) Exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) { |
||||||
|
return db.write.exec(c, query, args...) |
||||||
|
} |
||||||
|
|
||||||
|
// Prepare creates a prepared statement for later queries or executions.
|
||||||
|
// Multiple queries or executions may be run concurrently from the returned
|
||||||
|
// statement. The caller must call the statement's Close method when the
|
||||||
|
// statement is no longer needed.
|
||||||
|
func (db *DB) Prepare(query string) (*Stmt, error) { |
||||||
|
return db.write.prepare(query) |
||||||
|
} |
||||||
|
|
||||||
|
// Prepared creates a prepared statement for later queries or executions.
|
||||||
|
// Multiple queries or executions may be run concurrently from the returned
|
||||||
|
// statement. The caller must call the statement's Close method when the
|
||||||
|
// statement is no longer needed.
|
||||||
|
func (db *DB) Prepared(query string) (stmt *Stmt) { |
||||||
|
return db.write.prepared(query) |
||||||
|
} |
||||||
|
|
||||||
|
// Query executes a query that returns rows, typically a SELECT. The args are
|
||||||
|
// for any placeholder parameters in the query.
|
||||||
|
func (db *DB) Query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) { |
||||||
|
idx := db.readIndex() |
||||||
|
for i := range db.read { |
||||||
|
if rows, err = db.read[(idx+i)%len(db.read)].query(c, query, args...); !ecode.ServiceUnavailable.Equal(err) { |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
return db.write.query(c, query, args...) |
||||||
|
} |
||||||
|
|
||||||
|
// QueryRow executes a query that is expected to return at most one row.
|
||||||
|
// QueryRow always returns a non-nil value. Errors are deferred until Row's
|
||||||
|
// Scan method is called.
|
||||||
|
func (db *DB) QueryRow(c context.Context, query string, args ...interface{}) *Row { |
||||||
|
idx := db.readIndex() |
||||||
|
for i := range db.read { |
||||||
|
if row := db.read[(idx+i)%len(db.read)].queryRow(c, query, args...); !ecode.ServiceUnavailable.Equal(row.err) { |
||||||
|
return row |
||||||
|
} |
||||||
|
} |
||||||
|
return db.write.queryRow(c, query, args...) |
||||||
|
} |
||||||
|
|
||||||
|
func (db *DB) readIndex() int { |
||||||
|
if len(db.read) == 0 { |
||||||
|
return 0 |
||||||
|
} |
||||||
|
v := atomic.AddInt64(&db.idx, 1) |
||||||
|
return int(v) % len(db.read) |
||||||
|
} |
||||||
|
|
||||||
|
// Close closes the write and read database, releasing any open resources.
|
||||||
|
func (db *DB) Close() (err error) { |
||||||
|
if e := db.write.Close(); e != nil { |
||||||
|
err = errors.WithStack(e) |
||||||
|
} |
||||||
|
for _, rd := range db.read { |
||||||
|
if e := rd.Close(); e != nil { |
||||||
|
err = errors.WithStack(e) |
||||||
|
} |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Ping verifies a connection to the database is still alive, establishing a
|
||||||
|
// connection if necessary.
|
||||||
|
func (db *DB) Ping(c context.Context) (err error) { |
||||||
|
if err = db.write.ping(c); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
for _, rd := range db.read { |
||||||
|
if err = rd.ping(c); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Master return *DB instance direct use master conn
|
||||||
|
// use this *DB instance only when you have some reason need to get result without any delay.
|
||||||
|
func (db *DB) Master() *DB { |
||||||
|
if db.master == nil { |
||||||
|
panic(ErrNoMaster) |
||||||
|
} |
||||||
|
return db.master |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) onBreaker(err *error) { |
||||||
|
if err != nil && *err != nil && *err != sql.ErrNoRows && *err != sql.ErrTxDone { |
||||||
|
db.breaker.MarkFailed() |
||||||
|
} else { |
||||||
|
db.breaker.MarkSuccess() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) begin(c context.Context) (tx *Tx, err error) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog("Begin", now) |
||||||
|
t, ok := trace.FromContext(c) |
||||||
|
if ok { |
||||||
|
t = t.Fork(_family, "begin") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, "")) |
||||||
|
defer func() { |
||||||
|
if err != nil { |
||||||
|
t.Finish(&err) |
||||||
|
} |
||||||
|
}() |
||||||
|
} |
||||||
|
if err = db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("mysql:begin", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := db.conf.TranTimeout.Shrink(c) |
||||||
|
rtx, err := db.BeginTx(c, nil) |
||||||
|
stats.Timing("mysql:begin", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
cancel() |
||||||
|
return |
||||||
|
} |
||||||
|
tx = &Tx{tx: rtx, t: t, db: db, c: c, cancel: cancel} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", query, args), now) |
||||||
|
if t, ok := trace.FromContext(c); ok { |
||||||
|
t = t.Fork(_family, "exec") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query)) |
||||||
|
defer t.Finish(&err) |
||||||
|
} |
||||||
|
if err = db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("mysql:exec", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := db.conf.ExecTimeout.Shrink(c) |
||||||
|
res, err = db.ExecContext(c, query, args...) |
||||||
|
cancel() |
||||||
|
db.onBreaker(&err) |
||||||
|
stats.Timing("mysql:exec", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "exec:%s, args:%+v", query, args) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) ping(c context.Context) (err error) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog("Ping", now) |
||||||
|
if t, ok := trace.FromContext(c); ok { |
||||||
|
t = t.Fork(_family, "ping") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, "")) |
||||||
|
defer t.Finish(&err) |
||||||
|
} |
||||||
|
if err = db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("mysql:ping", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := db.conf.ExecTimeout.Shrink(c) |
||||||
|
err = db.PingContext(c) |
||||||
|
cancel() |
||||||
|
db.onBreaker(&err) |
||||||
|
stats.Timing("mysql:ping", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) prepare(query string) (*Stmt, error) { |
||||||
|
defer slowLog(fmt.Sprintf("Prepare query(%s)", query), time.Now()) |
||||||
|
stmt, err := db.Prepare(query) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "prepare %s", query) |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
st := &Stmt{query: query, db: db} |
||||||
|
st.stmt.Store(stmt) |
||||||
|
return st, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) prepared(query string) (stmt *Stmt) { |
||||||
|
defer slowLog(fmt.Sprintf("Prepared query(%s)", query), time.Now()) |
||||||
|
stmt = &Stmt{query: query, db: db} |
||||||
|
s, err := db.Prepare(query) |
||||||
|
if err == nil { |
||||||
|
stmt.stmt.Store(s) |
||||||
|
return |
||||||
|
} |
||||||
|
go func() { |
||||||
|
for { |
||||||
|
s, err := db.Prepare(query) |
||||||
|
if err != nil { |
||||||
|
time.Sleep(time.Second) |
||||||
|
continue |
||||||
|
} |
||||||
|
stmt.stmt.Store(s) |
||||||
|
return |
||||||
|
} |
||||||
|
}() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", query, args), now) |
||||||
|
if t, ok := trace.FromContext(c); ok { |
||||||
|
t = t.Fork(_family, "query") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query)) |
||||||
|
defer t.Finish(&err) |
||||||
|
} |
||||||
|
if err = db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("mysql:query", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := db.conf.QueryTimeout.Shrink(c) |
||||||
|
rs, err := db.DB.QueryContext(c, query, args...) |
||||||
|
db.onBreaker(&err) |
||||||
|
stats.Timing("mysql:query", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "query:%s, args:%+v", query, args) |
||||||
|
cancel() |
||||||
|
return |
||||||
|
} |
||||||
|
rows = &Rows{Rows: rs, cancel: cancel} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) queryRow(c context.Context, query string, args ...interface{}) *Row { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", query, args), now) |
||||||
|
t, ok := trace.FromContext(c) |
||||||
|
if ok { |
||||||
|
t = t.Fork(_family, "queryrow") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query)) |
||||||
|
} |
||||||
|
if err := db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("mysql:queryrow", "breaker") |
||||||
|
return &Row{db: db, t: t, err: err} |
||||||
|
} |
||||||
|
_, c, cancel := db.conf.QueryTimeout.Shrink(c) |
||||||
|
r := db.DB.QueryRowContext(c, query, args...) |
||||||
|
stats.Timing("mysql:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||||
|
return &Row{db: db, Row: r, query: query, args: args, t: t, cancel: cancel} |
||||||
|
} |
||||||
|
|
||||||
|
// Close closes the statement.
|
||||||
|
func (s *Stmt) Close() (err error) { |
||||||
|
if s == nil { |
||||||
|
err = ErrStmtNil |
||||||
|
return |
||||||
|
} |
||||||
|
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||||
|
if ok { |
||||||
|
err = errors.WithStack(stmt.Close()) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Exec executes a prepared statement with the given arguments and returns a
|
||||||
|
// Result summarizing the effect of the statement.
|
||||||
|
func (s *Stmt) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) { |
||||||
|
if s == nil { |
||||||
|
err = ErrStmtNil |
||||||
|
return |
||||||
|
} |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", s.query, args), now) |
||||||
|
if s.tx { |
||||||
|
if s.t != nil { |
||||||
|
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||||
|
} |
||||||
|
} else if t, ok := trace.FromContext(c); ok { |
||||||
|
t = t.Fork(_family, "exec") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query)) |
||||||
|
defer t.Finish(&err) |
||||||
|
} |
||||||
|
if err = s.db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("mysql:stmt:exec", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||||
|
if !ok { |
||||||
|
err = ErrStmtNil |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := s.db.conf.ExecTimeout.Shrink(c) |
||||||
|
res, err = stmt.ExecContext(c, args...) |
||||||
|
cancel() |
||||||
|
s.db.onBreaker(&err) |
||||||
|
stats.Timing("mysql:stmt:exec", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "exec:%s, args:%+v", s.query, args) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Query executes a prepared query statement with the given arguments and
|
||||||
|
// returns the query results as a *Rows.
|
||||||
|
func (s *Stmt) Query(c context.Context, args ...interface{}) (rows *Rows, err error) { |
||||||
|
if s == nil { |
||||||
|
err = ErrStmtNil |
||||||
|
return |
||||||
|
} |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", s.query, args), now) |
||||||
|
if s.tx { |
||||||
|
if s.t != nil { |
||||||
|
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||||
|
} |
||||||
|
} else if t, ok := trace.FromContext(c); ok { |
||||||
|
t = t.Fork(_family, "query") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query)) |
||||||
|
defer t.Finish(&err) |
||||||
|
} |
||||||
|
if err = s.db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("mysql:stmt:query", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||||
|
if !ok { |
||||||
|
err = ErrStmtNil |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c) |
||||||
|
rs, err := stmt.QueryContext(c, args...) |
||||||
|
s.db.onBreaker(&err) |
||||||
|
stats.Timing("mysql:stmt:query", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "query:%s, args:%+v", s.query, args) |
||||||
|
cancel() |
||||||
|
return |
||||||
|
} |
||||||
|
rows = &Rows{Rows: rs, cancel: cancel} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// QueryRow executes a prepared query statement with the given arguments.
|
||||||
|
// If an error occurs during the execution of the statement, that error will
|
||||||
|
// be returned by a call to Scan on the returned *Row, which is always non-nil.
|
||||||
|
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
|
||||||
|
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
|
||||||
|
func (s *Stmt) QueryRow(c context.Context, args ...interface{}) (row *Row) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", s.query, args), now) |
||||||
|
row = &Row{db: s.db, query: s.query, args: args} |
||||||
|
if s == nil { |
||||||
|
row.err = ErrStmtNil |
||||||
|
return |
||||||
|
} |
||||||
|
if s.tx { |
||||||
|
if s.t != nil { |
||||||
|
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||||
|
} |
||||||
|
} else if t, ok := trace.FromContext(c); ok { |
||||||
|
t = t.Fork(_family, "queryrow") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query)) |
||||||
|
row.t = t |
||||||
|
} |
||||||
|
if row.err = s.db.breaker.Allow(); row.err != nil { |
||||||
|
stats.Incr("mysql:stmt:queryrow", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||||
|
if !ok { |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c) |
||||||
|
row.Row = stmt.QueryRowContext(c, args...) |
||||||
|
row.cancel = cancel |
||||||
|
stats.Timing("mysql:stmt:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Commit commits the transaction.
|
||||||
|
func (tx *Tx) Commit() (err error) { |
||||||
|
err = tx.tx.Commit() |
||||||
|
tx.cancel() |
||||||
|
tx.db.onBreaker(&err) |
||||||
|
if tx.t != nil { |
||||||
|
tx.t.Finish(&err) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Rollback aborts the transaction.
|
||||||
|
func (tx *Tx) Rollback() (err error) { |
||||||
|
err = tx.tx.Rollback() |
||||||
|
tx.cancel() |
||||||
|
tx.db.onBreaker(&err) |
||||||
|
if tx.t != nil { |
||||||
|
tx.t.Finish(&err) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Exec executes a query that doesn't return rows. For example: an INSERT and
|
||||||
|
// UPDATE.
|
||||||
|
func (tx *Tx) Exec(query string, args ...interface{}) (res sql.Result, err error) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", query, args), now) |
||||||
|
if tx.t != nil { |
||||||
|
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("exec %s", query))) |
||||||
|
} |
||||||
|
res, err = tx.tx.ExecContext(tx.c, query, args...) |
||||||
|
stats.Timing("mysql:tx:exec", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "exec:%s, args:%+v", query, args) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Query executes a query that returns rows, typically a SELECT.
|
||||||
|
func (tx *Tx) Query(query string, args ...interface{}) (rows *Rows, err error) { |
||||||
|
if tx.t != nil { |
||||||
|
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("query %s", query))) |
||||||
|
} |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", query, args), now) |
||||||
|
defer func() { |
||||||
|
stats.Timing("mysql:tx:query", int64(time.Since(now)/time.Millisecond)) |
||||||
|
}() |
||||||
|
rs, err := tx.tx.QueryContext(tx.c, query, args...) |
||||||
|
if err == nil { |
||||||
|
rows = &Rows{Rows: rs} |
||||||
|
} else { |
||||||
|
err = errors.Wrapf(err, "query:%s, args:%+v", query, args) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// QueryRow executes a query that is expected to return at most one row.
|
||||||
|
// QueryRow always returns a non-nil value. Errors are deferred until Row's
|
||||||
|
// Scan method is called.
|
||||||
|
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { |
||||||
|
if tx.t != nil { |
||||||
|
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("queryrow %s", query))) |
||||||
|
} |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", query, args), now) |
||||||
|
defer func() { |
||||||
|
stats.Timing("mysql:tx:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||||
|
}() |
||||||
|
r := tx.tx.QueryRowContext(tx.c, query, args...) |
||||||
|
return &Row{Row: r, db: tx.db, query: query, args: args} |
||||||
|
} |
||||||
|
|
||||||
|
// Stmt returns a transaction-specific prepared statement from an existing statement.
|
||||||
|
func (tx *Tx) Stmt(stmt *Stmt) *Stmt { |
||||||
|
as, ok := stmt.stmt.Load().(*sql.Stmt) |
||||||
|
if !ok { |
||||||
|
return nil |
||||||
|
} |
||||||
|
ts := tx.tx.StmtContext(tx.c, as) |
||||||
|
st := &Stmt{query: stmt.query, tx: true, t: tx.t, db: tx.db} |
||||||
|
st.stmt.Store(ts) |
||||||
|
return st |
||||||
|
} |
||||||
|
|
||||||
|
// Prepare creates a prepared statement for use within a transaction.
|
||||||
|
// The returned statement operates within the transaction and can no longer be
|
||||||
|
// used once the transaction has been committed or rolled back.
|
||||||
|
// To use an existing prepared statement on this transaction, see Tx.Stmt.
|
||||||
|
func (tx *Tx) Prepare(query string) (*Stmt, error) { |
||||||
|
if tx.t != nil { |
||||||
|
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("prepare %s", query))) |
||||||
|
} |
||||||
|
defer slowLog(fmt.Sprintf("Prepare query(%s)", query), time.Now()) |
||||||
|
stmt, err := tx.tx.Prepare(query) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "prepare %s", query) |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
st := &Stmt{query: query, tx: true, t: tx.t, db: tx.db} |
||||||
|
st.stmt.Store(stmt) |
||||||
|
return st, nil |
||||||
|
} |
||||||
|
|
||||||
|
// parseDSNAddr parse dsn name and return addr.
|
||||||
|
func parseDSNAddr(dsn string) (addr string) { |
||||||
|
if dsn == "" { |
||||||
|
return |
||||||
|
} |
||||||
|
part0 := strings.Split(dsn, "@") |
||||||
|
if len(part0) > 1 { |
||||||
|
part1 := strings.Split(part0[1], "?") |
||||||
|
if len(part1) > 0 { |
||||||
|
addr = part1[0] |
||||||
|
} |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func slowLog(statement string, now time.Time) { |
||||||
|
du := time.Since(now) |
||||||
|
if du > _slowLogDuration { |
||||||
|
log.Warn("%s slow log statement: %s time: %v", _family, statement, du) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,14 @@ |
|||||||
|
#### database/tidb |
||||||
|
|
||||||
|
##### 项目简介 |
||||||
|
TiDB数据库驱动 对mysql驱动进行封装 |
||||||
|
|
||||||
|
##### 功能 |
||||||
|
1. 支持discovery服务发现 多节点直连 |
||||||
|
2. 支持通过lvs单一地址连接 |
||||||
|
3. 支持prepare绑定多个节点 |
||||||
|
4. 支持动态增减节点负载均衡 |
||||||
|
5. 日志区分运行节点 |
||||||
|
|
||||||
|
##### 依赖包 |
||||||
|
1.[Go-MySQL-Driver](https://github.com/go-sql-driver/mysql) |
@ -0,0 +1,58 @@ |
|||||||
|
package tidb |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
"strings" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/conf/env" |
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
"github.com/bilibili/kratos/pkg/naming" |
||||||
|
"github.com/bilibili/kratos/pkg/naming/discovery" |
||||||
|
) |
||||||
|
|
||||||
|
var _schema = "tidb://" |
||||||
|
|
||||||
|
func (db *DB) nodeList() (nodes []string) { |
||||||
|
var ( |
||||||
|
insInfo *naming.InstancesInfo |
||||||
|
insMap map[string][]*naming.Instance |
||||||
|
ins []*naming.Instance |
||||||
|
ok bool |
||||||
|
) |
||||||
|
if insInfo, ok = db.dis.Fetch(context.Background()); !ok { |
||||||
|
return |
||||||
|
} |
||||||
|
insMap = insInfo.Instances |
||||||
|
if ins, ok = insMap[env.Zone]; !ok || len(ins) == 0 { |
||||||
|
return |
||||||
|
} |
||||||
|
for _, in := range ins { |
||||||
|
for _, addr := range in.Addrs { |
||||||
|
if strings.HasPrefix(addr, _schema) { |
||||||
|
addr = strings.Replace(addr, _schema, "", -1) |
||||||
|
nodes = append(nodes, addr) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
log.Info("tidb get %s instances(%v)", db.appid, nodes) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *DB) disc() (nodes []string) { |
||||||
|
db.dis = discovery.Build(db.appid) |
||||||
|
e := db.dis.Watch() |
||||||
|
select { |
||||||
|
case <-e: |
||||||
|
nodes = db.nodeList() |
||||||
|
case <-time.After(10 * time.Second): |
||||||
|
panic("tidb init discovery err") |
||||||
|
} |
||||||
|
if len(nodes) == 0 { |
||||||
|
panic(fmt.Sprintf("tidb %s no instance", db.appid)) |
||||||
|
} |
||||||
|
go db.nodeproc(e) |
||||||
|
log.Info("init tidb discvoery info successfully") |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,82 @@ |
|||||||
|
package tidb |
||||||
|
|
||||||
|
import ( |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
) |
||||||
|
|
||||||
|
func (db *DB) nodeproc(e <-chan struct{}) { |
||||||
|
if db.dis == nil { |
||||||
|
return |
||||||
|
} |
||||||
|
for { |
||||||
|
<-e |
||||||
|
nodes := db.nodeList() |
||||||
|
if len(nodes) == 0 { |
||||||
|
continue |
||||||
|
} |
||||||
|
cm := make(map[string]*conn) |
||||||
|
var conns []*conn |
||||||
|
for _, conn := range db.conns { |
||||||
|
cm[conn.addr] = conn |
||||||
|
} |
||||||
|
for _, node := range nodes { |
||||||
|
if cm[node] != nil { |
||||||
|
conns = append(conns, cm[node]) |
||||||
|
continue |
||||||
|
} |
||||||
|
c, err := db.connectDSN(genDSN(db.conf.DSN, node)) |
||||||
|
if err == nil { |
||||||
|
conns = append(conns, c) |
||||||
|
} else { |
||||||
|
log.Error("tidb: connect addr: %s err: %+v", node, err) |
||||||
|
} |
||||||
|
} |
||||||
|
if len(conns) == 0 { |
||||||
|
log.Error("tidb: no nodes ignore event") |
||||||
|
continue |
||||||
|
} |
||||||
|
oldConns := db.conns |
||||||
|
db.mutex.Lock() |
||||||
|
db.conns = conns |
||||||
|
db.mutex.Unlock() |
||||||
|
log.Info("tidb: new nodes: %v", nodes) |
||||||
|
var removedConn []*conn |
||||||
|
for _, conn := range oldConns { |
||||||
|
var exist bool |
||||||
|
for _, c := range conns { |
||||||
|
if c.addr == conn.addr { |
||||||
|
exist = true |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
if !exist { |
||||||
|
removedConn = append(removedConn, conn) |
||||||
|
} |
||||||
|
} |
||||||
|
go db.closeConns(removedConn) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (db *DB) closeConns(conns []*conn) { |
||||||
|
if len(conns) == 0 { |
||||||
|
return |
||||||
|
} |
||||||
|
du := db.conf.QueryTimeout |
||||||
|
if db.conf.ExecTimeout > du { |
||||||
|
du = db.conf.ExecTimeout |
||||||
|
} |
||||||
|
if db.conf.TranTimeout > du { |
||||||
|
du = db.conf.TranTimeout |
||||||
|
} |
||||||
|
time.Sleep(time.Duration(du)) |
||||||
|
for _, conn := range conns { |
||||||
|
err := conn.Close() |
||||||
|
if err != nil { |
||||||
|
log.Error("tidb: close removed conn: %s err: %v", conn.addr, err) |
||||||
|
} else { |
||||||
|
log.Info("tidb: close removed conn: %s", conn.addr) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,739 @@ |
|||||||
|
package tidb |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"database/sql" |
||||||
|
"fmt" |
||||||
|
"sync" |
||||||
|
"sync/atomic" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
"github.com/bilibili/kratos/pkg/naming" |
||||||
|
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||||
|
"github.com/bilibili/kratos/pkg/net/trace" |
||||||
|
|
||||||
|
"github.com/go-sql-driver/mysql" |
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
_family = "tidb_client" |
||||||
|
_slowLogDuration = time.Millisecond * 250 |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
// ErrStmtNil prepared stmt error
|
||||||
|
ErrStmtNil = errors.New("sql: prepare failed and stmt nil") |
||||||
|
// ErrNoRows is returned by Scan when QueryRow doesn't return a row.
|
||||||
|
// In such a case, QueryRow returns a placeholder *Row value that defers
|
||||||
|
// this error until a Scan.
|
||||||
|
ErrNoRows = sql.ErrNoRows |
||||||
|
// ErrTxDone transaction done.
|
||||||
|
ErrTxDone = sql.ErrTxDone |
||||||
|
) |
||||||
|
|
||||||
|
// DB database.
|
||||||
|
type DB struct { |
||||||
|
conf *Config |
||||||
|
conns []*conn |
||||||
|
idx int64 |
||||||
|
dis naming.Resolver |
||||||
|
appid string |
||||||
|
mutex sync.RWMutex |
||||||
|
breakerGroup *breaker.Group |
||||||
|
} |
||||||
|
|
||||||
|
// conn database connection
|
||||||
|
type conn struct { |
||||||
|
*sql.DB |
||||||
|
breaker breaker.Breaker |
||||||
|
conf *Config |
||||||
|
addr string |
||||||
|
} |
||||||
|
|
||||||
|
// Tx transaction.
|
||||||
|
type Tx struct { |
||||||
|
db *conn |
||||||
|
tx *sql.Tx |
||||||
|
t trace.Trace |
||||||
|
c context.Context |
||||||
|
cancel func() |
||||||
|
} |
||||||
|
|
||||||
|
// Row row.
|
||||||
|
type Row struct { |
||||||
|
err error |
||||||
|
*sql.Row |
||||||
|
db *conn |
||||||
|
query string |
||||||
|
args []interface{} |
||||||
|
t trace.Trace |
||||||
|
cancel func() |
||||||
|
} |
||||||
|
|
||||||
|
// Scan copies the columns from the matched row into the values pointed at by dest.
|
||||||
|
func (r *Row) Scan(dest ...interface{}) (err error) { |
||||||
|
defer slowLog(fmt.Sprintf("Scan addr: %s query(%s) args(%+v)", r.db.addr, r.query, r.args), time.Now()) |
||||||
|
if r.t != nil { |
||||||
|
defer r.t.Finish(&err) |
||||||
|
} |
||||||
|
if r.err != nil { |
||||||
|
err = r.err |
||||||
|
} else if r.Row == nil { |
||||||
|
err = ErrStmtNil |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
err = r.Row.Scan(dest...) |
||||||
|
if r.cancel != nil { |
||||||
|
r.cancel() |
||||||
|
} |
||||||
|
r.db.onBreaker(&err) |
||||||
|
if err != ErrNoRows { |
||||||
|
err = errors.Wrapf(err, "addr: %s, query %s args %+v", r.db.addr, r.query, r.args) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Rows rows.
|
||||||
|
type Rows struct { |
||||||
|
*sql.Rows |
||||||
|
cancel func() |
||||||
|
} |
||||||
|
|
||||||
|
// Close closes the Rows, preventing further enumeration. If Next is called
|
||||||
|
// and returns false and there are no further result sets,
|
||||||
|
// the Rows are closed automatically and it will suffice to check the
|
||||||
|
// result of Err. Close is idempotent and does not affect the result of Err.
|
||||||
|
func (rs *Rows) Close() (err error) { |
||||||
|
err = errors.WithStack(rs.Rows.Close()) |
||||||
|
if rs.cancel != nil { |
||||||
|
rs.cancel() |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Stmt prepared stmt.
|
||||||
|
type Stmt struct { |
||||||
|
db *conn |
||||||
|
tx bool |
||||||
|
query string |
||||||
|
stmt atomic.Value |
||||||
|
t trace.Trace |
||||||
|
} |
||||||
|
|
||||||
|
// Stmts random prepared stmt.
|
||||||
|
type Stmts struct { |
||||||
|
query string |
||||||
|
sts map[string]*Stmt |
||||||
|
mu sync.RWMutex |
||||||
|
db *DB |
||||||
|
} |
||||||
|
|
||||||
|
// Open opens a database specified by its database driver name and a
|
||||||
|
// driver-specific data source name, usually consisting of at least a database
|
||||||
|
// name and connection information.
|
||||||
|
func Open(c *Config) (db *DB, err error) { |
||||||
|
db = &DB{conf: c, breakerGroup: breaker.NewGroup(c.Breaker)} |
||||||
|
cfg, err := mysql.ParseDSN(c.DSN) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
var dsns []string |
||||||
|
if cfg.Net == "discovery" { |
||||||
|
db.appid = cfg.Addr |
||||||
|
for _, addr := range db.disc() { |
||||||
|
dsns = append(dsns, genDSN(c.DSN, addr)) |
||||||
|
} |
||||||
|
} else { |
||||||
|
dsns = append(dsns, c.DSN) |
||||||
|
} |
||||||
|
|
||||||
|
cs := make([]*conn, 0, len(dsns)) |
||||||
|
for _, dsn := range dsns { |
||||||
|
r, err := db.connectDSN(dsn) |
||||||
|
if err != nil { |
||||||
|
return db, err |
||||||
|
} |
||||||
|
cs = append(cs, r) |
||||||
|
} |
||||||
|
db.conns = cs |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *DB) connectDSN(dsn string) (c *conn, err error) { |
||||||
|
d, err := connect(db.conf, dsn) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
addr := parseDSNAddr(dsn) |
||||||
|
brk := db.breakerGroup.Get(addr) |
||||||
|
c = &conn{DB: d, breaker: brk, conf: db.conf, addr: addr} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func connect(c *Config, dataSourceName string) (*sql.DB, error) { |
||||||
|
d, err := sql.Open("mysql", dataSourceName) |
||||||
|
if err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
d.SetMaxOpenConns(c.Active) |
||||||
|
d.SetMaxIdleConns(c.Idle) |
||||||
|
d.SetConnMaxLifetime(time.Duration(c.IdleTimeout)) |
||||||
|
return d, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (db *DB) conn() (c *conn) { |
||||||
|
db.mutex.RLock() |
||||||
|
c = db.conns[db.index()] |
||||||
|
db.mutex.RUnlock() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Begin starts a transaction. The isolation level is dependent on the driver.
|
||||||
|
func (db *DB) Begin(c context.Context) (tx *Tx, err error) { |
||||||
|
return db.conn().begin(c) |
||||||
|
} |
||||||
|
|
||||||
|
// Exec executes a query without returning any rows.
|
||||||
|
// The args are for any placeholder parameters in the query.
|
||||||
|
func (db *DB) Exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) { |
||||||
|
return db.conn().exec(c, query, args...) |
||||||
|
} |
||||||
|
|
||||||
|
// Prepare creates a prepared statement for later queries or executions.
|
||||||
|
// Multiple queries or executions may be run concurrently from the returned
|
||||||
|
// statement. The caller must call the statement's Close method when the
|
||||||
|
// statement is no longer needed.
|
||||||
|
func (db *DB) Prepare(query string) (*Stmt, error) { |
||||||
|
return db.conn().prepare(query) |
||||||
|
} |
||||||
|
|
||||||
|
// Prepared creates a prepared statement for later queries or executions.
|
||||||
|
// Multiple queries or executions may be run concurrently from the returned
|
||||||
|
// statement. The caller must call the statement's Close method when the
|
||||||
|
// statement is no longer needed.
|
||||||
|
func (db *DB) Prepared(query string) (s *Stmts) { |
||||||
|
s = &Stmts{query: query, sts: make(map[string]*Stmt), db: db} |
||||||
|
for _, c := range db.conns { |
||||||
|
st := c.prepared(query) |
||||||
|
s.mu.Lock() |
||||||
|
s.sts[c.addr] = st |
||||||
|
s.mu.Unlock() |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Query executes a query that returns rows, typically a SELECT. The args are
|
||||||
|
// for any placeholder parameters in the query.
|
||||||
|
func (db *DB) Query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) { |
||||||
|
return db.conn().query(c, query, args...) |
||||||
|
} |
||||||
|
|
||||||
|
// QueryRow executes a query that is expected to return at most one row.
|
||||||
|
// QueryRow always returns a non-nil value. Errors are deferred until Row's
|
||||||
|
// Scan method is called.
|
||||||
|
func (db *DB) QueryRow(c context.Context, query string, args ...interface{}) *Row { |
||||||
|
return db.conn().queryRow(c, query, args...) |
||||||
|
} |
||||||
|
|
||||||
|
func (db *DB) index() int { |
||||||
|
if len(db.conns) == 1 { |
||||||
|
return 0 |
||||||
|
} |
||||||
|
v := atomic.AddInt64(&db.idx, 1) |
||||||
|
return int(v) % len(db.conns) |
||||||
|
} |
||||||
|
|
||||||
|
// Close closes the databases, releasing any open resources.
|
||||||
|
func (db *DB) Close() (err error) { |
||||||
|
db.mutex.RLock() |
||||||
|
defer db.mutex.RUnlock() |
||||||
|
for _, d := range db.conns { |
||||||
|
if e := d.Close(); e != nil { |
||||||
|
err = errors.WithStack(e) |
||||||
|
} |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Ping verifies a connection to the database is still alive, establishing a
|
||||||
|
// connection if necessary.
|
||||||
|
func (db *DB) Ping(c context.Context) (err error) { |
||||||
|
if err = db.conn().ping(c); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) onBreaker(err *error) { |
||||||
|
if err != nil && *err != nil && *err != sql.ErrNoRows && *err != sql.ErrTxDone { |
||||||
|
db.breaker.MarkFailed() |
||||||
|
} else { |
||||||
|
db.breaker.MarkSuccess() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) begin(c context.Context) (tx *Tx, err error) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Begin addr: %s", db.addr), now) |
||||||
|
t, ok := trace.FromContext(c) |
||||||
|
if ok { |
||||||
|
t = t.Fork(_family, "begin") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, "")) |
||||||
|
defer func() { |
||||||
|
if err != nil { |
||||||
|
t.Finish(&err) |
||||||
|
} |
||||||
|
}() |
||||||
|
} |
||||||
|
if err = db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("tidb:begin", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := db.conf.TranTimeout.Shrink(c) |
||||||
|
rtx, err := db.BeginTx(c, nil) |
||||||
|
stats.Timing("tidb:begin", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
cancel() |
||||||
|
return |
||||||
|
} |
||||||
|
tx = &Tx{tx: rtx, t: t, db: db, c: c, cancel: cancel} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", db.addr, query, args), now) |
||||||
|
if t, ok := trace.FromContext(c); ok { |
||||||
|
t = t.Fork(_family, "exec") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query)) |
||||||
|
defer t.Finish(&err) |
||||||
|
} |
||||||
|
if err = db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("tidb:exec", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := db.conf.ExecTimeout.Shrink(c) |
||||||
|
res, err = db.ExecContext(c, query, args...) |
||||||
|
cancel() |
||||||
|
db.onBreaker(&err) |
||||||
|
stats.Timing("tidb:exec", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", db.addr, query, args) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) ping(c context.Context) (err error) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Ping addr: %s", db.addr), now) |
||||||
|
if t, ok := trace.FromContext(c); ok { |
||||||
|
t = t.Fork(_family, "ping") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, "")) |
||||||
|
defer t.Finish(&err) |
||||||
|
} |
||||||
|
if err = db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("tidb:ping", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := db.conf.ExecTimeout.Shrink(c) |
||||||
|
err = db.PingContext(c) |
||||||
|
cancel() |
||||||
|
db.onBreaker(&err) |
||||||
|
stats.Timing("tidb:ping", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) prepare(query string) (*Stmt, error) { |
||||||
|
defer slowLog(fmt.Sprintf("Prepare addr: %s query(%s)", db.addr, query), time.Now()) |
||||||
|
stmt, err := db.Prepare(query) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "addr: %s prepare %s", db.addr, query) |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
st := &Stmt{query: query, db: db} |
||||||
|
st.stmt.Store(stmt) |
||||||
|
return st, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) prepared(query string) (stmt *Stmt) { |
||||||
|
defer slowLog(fmt.Sprintf("Prepared addr: %s query(%s)", db.addr, query), time.Now()) |
||||||
|
stmt = &Stmt{query: query, db: db} |
||||||
|
s, err := db.Prepare(query) |
||||||
|
if err == nil { |
||||||
|
stmt.stmt.Store(s) |
||||||
|
return |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", db.addr, query, args), now) |
||||||
|
if t, ok := trace.FromContext(c); ok { |
||||||
|
t = t.Fork(_family, "query") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query)) |
||||||
|
defer t.Finish(&err) |
||||||
|
} |
||||||
|
if err = db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("tidb:query", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := db.conf.QueryTimeout.Shrink(c) |
||||||
|
rs, err := db.DB.QueryContext(c, query, args...) |
||||||
|
db.onBreaker(&err) |
||||||
|
stats.Timing("tidb:query", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", db.addr, query, args) |
||||||
|
cancel() |
||||||
|
return |
||||||
|
} |
||||||
|
rows = &Rows{Rows: rs, cancel: cancel} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (db *conn) queryRow(c context.Context, query string, args ...interface{}) *Row { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", db.addr, query, args), now) |
||||||
|
t, ok := trace.FromContext(c) |
||||||
|
if ok { |
||||||
|
t = t.Fork(_family, "queryrow") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query)) |
||||||
|
} |
||||||
|
if err := db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("tidb:queryrow", "breaker") |
||||||
|
return &Row{db: db, t: t, err: err} |
||||||
|
} |
||||||
|
_, c, cancel := db.conf.QueryTimeout.Shrink(c) |
||||||
|
r := db.DB.QueryRowContext(c, query, args...) |
||||||
|
stats.Timing("tidb:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||||
|
return &Row{db: db, Row: r, query: query, args: args, t: t, cancel: cancel} |
||||||
|
} |
||||||
|
|
||||||
|
// Close closes the statement.
|
||||||
|
func (s *Stmt) Close() (err error) { |
||||||
|
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||||
|
if ok { |
||||||
|
err = errors.WithStack(stmt.Close()) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Stmt) prepare() (st *sql.Stmt) { |
||||||
|
var ok bool |
||||||
|
if st, ok = s.stmt.Load().(*sql.Stmt); ok { |
||||||
|
return |
||||||
|
} |
||||||
|
var err error |
||||||
|
if st, err = s.db.Prepare(s.query); err == nil { |
||||||
|
s.stmt.Store(st) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Exec executes a prepared statement with the given arguments and returns a
|
||||||
|
// Result summarizing the effect of the statement.
|
||||||
|
func (s *Stmt) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now) |
||||||
|
if s.tx { |
||||||
|
if s.t != nil { |
||||||
|
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||||
|
} |
||||||
|
} else if t, ok := trace.FromContext(c); ok { |
||||||
|
t = t.Fork(_family, "exec") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query)) |
||||||
|
defer t.Finish(&err) |
||||||
|
} |
||||||
|
if err = s.db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("tidb:stmt:exec", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
stmt := s.prepare() |
||||||
|
if stmt == nil { |
||||||
|
err = ErrStmtNil |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := s.db.conf.ExecTimeout.Shrink(c) |
||||||
|
res, err = stmt.ExecContext(c, args...) |
||||||
|
cancel() |
||||||
|
s.db.onBreaker(&err) |
||||||
|
stats.Timing("tidb:stmt:exec", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", s.db.addr, s.query, args) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Query executes a prepared query statement with the given arguments and
|
||||||
|
// returns the query results as a *Rows.
|
||||||
|
func (s *Stmt) Query(c context.Context, args ...interface{}) (rows *Rows, err error) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now) |
||||||
|
if s.tx { |
||||||
|
if s.t != nil { |
||||||
|
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||||
|
} |
||||||
|
} else if t, ok := trace.FromContext(c); ok { |
||||||
|
t = t.Fork(_family, "query") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query)) |
||||||
|
defer t.Finish(&err) |
||||||
|
} |
||||||
|
if err = s.db.breaker.Allow(); err != nil { |
||||||
|
stats.Incr("tidb:stmt:query", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
stmt := s.prepare() |
||||||
|
if stmt == nil { |
||||||
|
err = ErrStmtNil |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c) |
||||||
|
rs, err := stmt.QueryContext(c, args...) |
||||||
|
s.db.onBreaker(&err) |
||||||
|
stats.Timing("tidb:stmt:query", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", s.db.addr, s.query, args) |
||||||
|
cancel() |
||||||
|
return |
||||||
|
} |
||||||
|
rows = &Rows{Rows: rs, cancel: cancel} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// QueryRow executes a prepared query statement with the given arguments.
|
||||||
|
// If an error occurs during the execution of the statement, that error will
|
||||||
|
// be returned by a call to Scan on the returned *Row, which is always non-nil.
|
||||||
|
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
|
||||||
|
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
|
||||||
|
func (s *Stmt) QueryRow(c context.Context, args ...interface{}) (row *Row) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now) |
||||||
|
row = &Row{db: s.db, query: s.query, args: args} |
||||||
|
if s.tx { |
||||||
|
if s.t != nil { |
||||||
|
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||||
|
} |
||||||
|
} else if t, ok := trace.FromContext(c); ok { |
||||||
|
t = t.Fork(_family, "queryrow") |
||||||
|
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query)) |
||||||
|
row.t = t |
||||||
|
} |
||||||
|
if row.err = s.db.breaker.Allow(); row.err != nil { |
||||||
|
stats.Incr("tidb:stmt:queryrow", "breaker") |
||||||
|
return |
||||||
|
} |
||||||
|
stmt := s.prepare() |
||||||
|
if stmt == nil { |
||||||
|
return |
||||||
|
} |
||||||
|
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c) |
||||||
|
row.Row = stmt.QueryRowContext(c, args...) |
||||||
|
row.cancel = cancel |
||||||
|
stats.Timing("tidb:stmt:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Stmts) prepare(conn *conn) (st *Stmt) { |
||||||
|
if conn == nil { |
||||||
|
conn = s.db.conn() |
||||||
|
} |
||||||
|
s.mu.RLock() |
||||||
|
st = s.sts[conn.addr] |
||||||
|
s.mu.RUnlock() |
||||||
|
if st == nil { |
||||||
|
st = conn.prepared(s.query) |
||||||
|
s.mu.Lock() |
||||||
|
s.sts[conn.addr] = st |
||||||
|
s.mu.Unlock() |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Exec executes a prepared statement with the given arguments and returns a
|
||||||
|
// Result summarizing the effect of the statement.
|
||||||
|
func (s *Stmts) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) { |
||||||
|
return s.prepare(nil).Exec(c, args...) |
||||||
|
} |
||||||
|
|
||||||
|
// Query executes a prepared query statement with the given arguments and
|
||||||
|
// returns the query results as a *Rows.
|
||||||
|
func (s *Stmts) Query(c context.Context, args ...interface{}) (rows *Rows, err error) { |
||||||
|
return s.prepare(nil).Query(c, args...) |
||||||
|
} |
||||||
|
|
||||||
|
// QueryRow executes a prepared query statement with the given arguments.
|
||||||
|
// If an error occurs during the execution of the statement, that error will
|
||||||
|
// be returned by a call to Scan on the returned *Row, which is always non-nil.
|
||||||
|
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
|
||||||
|
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
|
||||||
|
func (s *Stmts) QueryRow(c context.Context, args ...interface{}) (row *Row) { |
||||||
|
return s.prepare(nil).QueryRow(c, args...) |
||||||
|
} |
||||||
|
|
||||||
|
// Close closes the statement.
|
||||||
|
func (s *Stmts) Close() (err error) { |
||||||
|
for _, st := range s.sts { |
||||||
|
if err = errors.WithStack(st.Close()); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Commit commits the transaction.
|
||||||
|
func (tx *Tx) Commit() (err error) { |
||||||
|
err = tx.tx.Commit() |
||||||
|
tx.cancel() |
||||||
|
tx.db.onBreaker(&err) |
||||||
|
if tx.t != nil { |
||||||
|
tx.t.Finish(&err) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Rollback aborts the transaction.
|
||||||
|
func (tx *Tx) Rollback() (err error) { |
||||||
|
err = tx.tx.Rollback() |
||||||
|
tx.cancel() |
||||||
|
tx.db.onBreaker(&err) |
||||||
|
if tx.t != nil { |
||||||
|
tx.t.Finish(&err) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Exec executes a query that doesn't return rows. For example: an INSERT and
|
||||||
|
// UPDATE.
|
||||||
|
func (tx *Tx) Exec(query string, args ...interface{}) (res sql.Result, err error) { |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now) |
||||||
|
if tx.t != nil { |
||||||
|
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("exec %s", query))) |
||||||
|
} |
||||||
|
res, err = tx.tx.ExecContext(tx.c, query, args...) |
||||||
|
stats.Timing("tidb:tx:exec", int64(time.Since(now)/time.Millisecond)) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", tx.db.addr, query, args) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Query executes a query that returns rows, typically a SELECT.
|
||||||
|
func (tx *Tx) Query(query string, args ...interface{}) (rows *Rows, err error) { |
||||||
|
if tx.t != nil { |
||||||
|
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("query %s", query))) |
||||||
|
} |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now) |
||||||
|
defer func() { |
||||||
|
stats.Timing("tidb:tx:query", int64(time.Since(now)/time.Millisecond)) |
||||||
|
}() |
||||||
|
rs, err := tx.tx.QueryContext(tx.c, query, args...) |
||||||
|
if err == nil { |
||||||
|
rows = &Rows{Rows: rs} |
||||||
|
} else { |
||||||
|
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", tx.db.addr, query, args) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// QueryRow executes a query that is expected to return at most one row.
|
||||||
|
// QueryRow always returns a non-nil value. Errors are deferred until Row's
|
||||||
|
// Scan method is called.
|
||||||
|
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { |
||||||
|
if tx.t != nil { |
||||||
|
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("queryrow %s", query))) |
||||||
|
} |
||||||
|
now := time.Now() |
||||||
|
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now) |
||||||
|
defer func() { |
||||||
|
stats.Timing("tidb:tx:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||||
|
}() |
||||||
|
r := tx.tx.QueryRowContext(tx.c, query, args...) |
||||||
|
return &Row{Row: r, db: tx.db, query: query, args: args} |
||||||
|
} |
||||||
|
|
||||||
|
// Stmt returns a transaction-specific prepared statement from an existing statement.
|
||||||
|
func (tx *Tx) Stmt(stmt *Stmt) *Stmt { |
||||||
|
if stmt == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
as, ok := stmt.stmt.Load().(*sql.Stmt) |
||||||
|
if !ok { |
||||||
|
return nil |
||||||
|
} |
||||||
|
ts := tx.tx.StmtContext(tx.c, as) |
||||||
|
st := &Stmt{query: stmt.query, tx: true, t: tx.t, db: tx.db} |
||||||
|
st.stmt.Store(ts) |
||||||
|
return st |
||||||
|
} |
||||||
|
|
||||||
|
// Stmts returns a transaction-specific prepared statement from an existing statement.
|
||||||
|
func (tx *Tx) Stmts(stmt *Stmts) *Stmt { |
||||||
|
return tx.Stmt(stmt.prepare(tx.db)) |
||||||
|
} |
||||||
|
|
||||||
|
// Prepare creates a prepared statement for use within a transaction.
|
||||||
|
// The returned statement operates within the transaction and can no longer be
|
||||||
|
// used once the transaction has been committed or rolled back.
|
||||||
|
// To use an existing prepared statement on this transaction, see Tx.Stmt.
|
||||||
|
func (tx *Tx) Prepare(query string) (*Stmt, error) { |
||||||
|
if tx.t != nil { |
||||||
|
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("prepare %s", query))) |
||||||
|
} |
||||||
|
defer slowLog(fmt.Sprintf("Prepare addr: %s query(%s)", tx.db.addr, query), time.Now()) |
||||||
|
stmt, err := tx.tx.Prepare(query) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "addr: %s prepare %s", tx.db.addr, query) |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
st := &Stmt{query: query, tx: true, t: tx.t, db: tx.db} |
||||||
|
st.stmt.Store(stmt) |
||||||
|
return st, nil |
||||||
|
} |
||||||
|
|
||||||
|
// parseDSNAddr parse dsn name and return addr.
|
||||||
|
func parseDSNAddr(dsn string) (addr string) { |
||||||
|
if dsn == "" { |
||||||
|
return |
||||||
|
} |
||||||
|
cfg, err := mysql.ParseDSN(dsn) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
addr = cfg.Addr |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func genDSN(dsn, addr string) (res string) { |
||||||
|
cfg, err := mysql.ParseDSN(dsn) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
cfg.Addr = addr |
||||||
|
cfg.Net = "tcp" |
||||||
|
res = cfg.FormatDSN() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func slowLog(statement string, now time.Time) { |
||||||
|
du := time.Since(now) |
||||||
|
if du > _slowLogDuration { |
||||||
|
log.Warn("%s slow log statement: %s time: %v", _family, statement, du) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,38 @@ |
|||||||
|
package tidb |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||||
|
"github.com/bilibili/kratos/pkg/stat" |
||||||
|
"github.com/bilibili/kratos/pkg/time" |
||||||
|
|
||||||
|
// database driver
|
||||||
|
_ "github.com/go-sql-driver/mysql" |
||||||
|
) |
||||||
|
|
||||||
|
var stats = stat.DB |
||||||
|
|
||||||
|
// Config mysql config.
|
||||||
|
type Config struct { |
||||||
|
DSN string // dsn
|
||||||
|
Active int // pool
|
||||||
|
Idle int // pool
|
||||||
|
IdleTimeout time.Duration // connect max life time.
|
||||||
|
QueryTimeout time.Duration // query sql timeout
|
||||||
|
ExecTimeout time.Duration // execute sql timeout
|
||||||
|
TranTimeout time.Duration // transaction sql timeout
|
||||||
|
Breaker *breaker.Config // breaker
|
||||||
|
} |
||||||
|
|
||||||
|
// NewTiDB new db and retry connection when has error.
|
||||||
|
func NewTiDB(c *Config) (db *DB) { |
||||||
|
if c.QueryTimeout == 0 || c.ExecTimeout == 0 || c.TranTimeout == 0 { |
||||||
|
panic("tidb must be set query/execute/transction timeout") |
||||||
|
} |
||||||
|
db, err := Open(c) |
||||||
|
if err != nil { |
||||||
|
log.Error("open tidb error(%v)", err) |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,85 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
"strings" |
||||||
|
|
||||||
|
"gopkg.in/go-playground/validator.v9" |
||||||
|
) |
||||||
|
|
||||||
|
// MIME
|
||||||
|
const ( |
||||||
|
MIMEJSON = "application/json" |
||||||
|
MIMEHTML = "text/html" |
||||||
|
MIMEXML = "application/xml" |
||||||
|
MIMEXML2 = "text/xml" |
||||||
|
MIMEPlain = "text/plain" |
||||||
|
MIMEPOSTForm = "application/x-www-form-urlencoded" |
||||||
|
MIMEMultipartPOSTForm = "multipart/form-data" |
||||||
|
) |
||||||
|
|
||||||
|
// Binding http binding request interface.
|
||||||
|
type Binding interface { |
||||||
|
Name() string |
||||||
|
Bind(*http.Request, interface{}) error |
||||||
|
} |
||||||
|
|
||||||
|
// StructValidator http validator interface.
|
||||||
|
type StructValidator interface { |
||||||
|
// ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
|
||||||
|
// If the received type is not a struct, any validation should be skipped and nil must be returned.
|
||||||
|
// If the received type is a struct or pointer to a struct, the validation should be performed.
|
||||||
|
// If the struct is not valid or the validation itself fails, a descriptive error should be returned.
|
||||||
|
// Otherwise nil must be returned.
|
||||||
|
ValidateStruct(interface{}) error |
||||||
|
|
||||||
|
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
|
||||||
|
// NOTE: if the key already exists, the previous validation function will be replaced.
|
||||||
|
// NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
|
||||||
|
RegisterValidation(string, validator.Func) error |
||||||
|
} |
||||||
|
|
||||||
|
// Validator default validator.
|
||||||
|
var Validator StructValidator = &defaultValidator{} |
||||||
|
|
||||||
|
// Binding
|
||||||
|
var ( |
||||||
|
JSON = jsonBinding{} |
||||||
|
XML = xmlBinding{} |
||||||
|
Form = formBinding{} |
||||||
|
Query = queryBinding{} |
||||||
|
FormPost = formPostBinding{} |
||||||
|
FormMultipart = formMultipartBinding{} |
||||||
|
) |
||||||
|
|
||||||
|
// Default get by binding type by method and contexttype.
|
||||||
|
func Default(method, contentType string) Binding { |
||||||
|
if method == "GET" { |
||||||
|
return Form |
||||||
|
} |
||||||
|
|
||||||
|
contentType = stripContentTypeParam(contentType) |
||||||
|
switch contentType { |
||||||
|
case MIMEJSON: |
||||||
|
return JSON |
||||||
|
case MIMEXML, MIMEXML2: |
||||||
|
return XML |
||||||
|
default: //case MIMEPOSTForm, MIMEMultipartPOSTForm:
|
||||||
|
return Form |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func validate(obj interface{}) error { |
||||||
|
if Validator == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return Validator.ValidateStruct(obj) |
||||||
|
} |
||||||
|
|
||||||
|
func stripContentTypeParam(contentType string) string { |
||||||
|
i := strings.Index(contentType, ";") |
||||||
|
if i != -1 { |
||||||
|
contentType = contentType[:i] |
||||||
|
} |
||||||
|
return contentType |
||||||
|
} |
@ -0,0 +1,342 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"mime/multipart" |
||||||
|
"net/http" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert" |
||||||
|
) |
||||||
|
|
||||||
|
type FooStruct struct { |
||||||
|
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" validate:"required"` |
||||||
|
} |
||||||
|
|
||||||
|
type FooBarStruct struct { |
||||||
|
FooStruct |
||||||
|
Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" validate:"required"` |
||||||
|
Slice []string `form:"slice" validate:"max=10"` |
||||||
|
} |
||||||
|
|
||||||
|
type ComplexDefaultStruct struct { |
||||||
|
Int int `form:"int" default:"999"` |
||||||
|
String string `form:"string" default:"default-string"` |
||||||
|
Bool bool `form:"bool" default:"false"` |
||||||
|
Int64Slice []int64 `form:"int64_slice,split" default:"1,2,3,4"` |
||||||
|
Int8Slice []int8 `form:"int8_slice,split" default:"1,2,3,4"` |
||||||
|
} |
||||||
|
|
||||||
|
type Int8SliceStruct struct { |
||||||
|
State []int8 `form:"state,split"` |
||||||
|
} |
||||||
|
|
||||||
|
type Int64SliceStruct struct { |
||||||
|
State []int64 `form:"state,split"` |
||||||
|
} |
||||||
|
|
||||||
|
type StringSliceStruct struct { |
||||||
|
State []string `form:"state,split"` |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingDefault(t *testing.T) { |
||||||
|
assert.Equal(t, Default("GET", ""), Form) |
||||||
|
assert.Equal(t, Default("GET", MIMEJSON), Form) |
||||||
|
assert.Equal(t, Default("GET", MIMEJSON+"; charset=utf-8"), Form) |
||||||
|
|
||||||
|
assert.Equal(t, Default("POST", MIMEJSON), JSON) |
||||||
|
assert.Equal(t, Default("PUT", MIMEJSON), JSON) |
||||||
|
|
||||||
|
assert.Equal(t, Default("POST", MIMEJSON+"; charset=utf-8"), JSON) |
||||||
|
assert.Equal(t, Default("PUT", MIMEJSON+"; charset=utf-8"), JSON) |
||||||
|
|
||||||
|
assert.Equal(t, Default("POST", MIMEXML), XML) |
||||||
|
assert.Equal(t, Default("PUT", MIMEXML2), XML) |
||||||
|
|
||||||
|
assert.Equal(t, Default("POST", MIMEPOSTForm), Form) |
||||||
|
assert.Equal(t, Default("PUT", MIMEPOSTForm), Form) |
||||||
|
|
||||||
|
assert.Equal(t, Default("POST", MIMEPOSTForm+"; charset=utf-8"), Form) |
||||||
|
assert.Equal(t, Default("PUT", MIMEPOSTForm+"; charset=utf-8"), Form) |
||||||
|
|
||||||
|
assert.Equal(t, Default("POST", MIMEMultipartPOSTForm), Form) |
||||||
|
assert.Equal(t, Default("PUT", MIMEMultipartPOSTForm), Form) |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
func TestStripContentType(t *testing.T) { |
||||||
|
c1 := "application/vnd.mozilla.xul+xml" |
||||||
|
c2 := "application/vnd.mozilla.xul+xml; charset=utf-8" |
||||||
|
assert.Equal(t, stripContentTypeParam(c1), c1) |
||||||
|
assert.Equal(t, stripContentTypeParam(c2), "application/vnd.mozilla.xul+xml") |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindInt8Form(t *testing.T) { |
||||||
|
params := "state=1,2,3" |
||||||
|
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q := new(Int8SliceStruct) |
||||||
|
Form.Bind(req, q) |
||||||
|
assert.EqualValues(t, []int8{1, 2, 3}, q.State) |
||||||
|
|
||||||
|
params = "state=1,2,3,256" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(Int8SliceStruct) |
||||||
|
assert.Error(t, Form.Bind(req, q)) |
||||||
|
|
||||||
|
params = "state=" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(Int8SliceStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Len(t, q.State, 0) |
||||||
|
|
||||||
|
params = "state=1,,2" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(Int8SliceStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.EqualValues(t, []int8{1, 2}, q.State) |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindInt64Form(t *testing.T) { |
||||||
|
params := "state=1,2,3" |
||||||
|
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q := new(Int64SliceStruct) |
||||||
|
Form.Bind(req, q) |
||||||
|
assert.EqualValues(t, []int64{1, 2, 3}, q.State) |
||||||
|
|
||||||
|
params = "state=" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(Int64SliceStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Len(t, q.State, 0) |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindStringForm(t *testing.T) { |
||||||
|
params := "state=1,2,3" |
||||||
|
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q := new(StringSliceStruct) |
||||||
|
Form.Bind(req, q) |
||||||
|
assert.EqualValues(t, []string{"1", "2", "3"}, q.State) |
||||||
|
|
||||||
|
params = "state=" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(StringSliceStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Len(t, q.State, 0) |
||||||
|
|
||||||
|
params = "state=p,,p" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(StringSliceStruct) |
||||||
|
Form.Bind(req, q) |
||||||
|
assert.EqualValues(t, []string{"p", "p"}, q.State) |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingJSON(t *testing.T) { |
||||||
|
testBodyBinding(t, |
||||||
|
JSON, "json", |
||||||
|
"/", "/", |
||||||
|
`{"foo": "bar"}`, `{"bar": "foo"}`) |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingForm(t *testing.T) { |
||||||
|
testFormBinding(t, "POST", |
||||||
|
"/", "/", |
||||||
|
"foo=bar&bar=foo&slice=a&slice=b", "bar2=foo") |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingForm2(t *testing.T) { |
||||||
|
testFormBinding(t, "GET", |
||||||
|
"/?foo=bar&bar=foo", "/?bar2=foo", |
||||||
|
"", "") |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingQuery(t *testing.T) { |
||||||
|
testQueryBinding(t, "POST", |
||||||
|
"/?foo=bar&bar=foo", "/", |
||||||
|
"foo=unused", "bar2=foo") |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingQuery2(t *testing.T) { |
||||||
|
testQueryBinding(t, "GET", |
||||||
|
"/?foo=bar&bar=foo", "/?bar2=foo", |
||||||
|
"foo=unused", "") |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingXML(t *testing.T) { |
||||||
|
testBodyBinding(t, |
||||||
|
XML, "xml", |
||||||
|
"/", "/", |
||||||
|
"<map><foo>bar</foo></map>", "<map><bar>foo</bar></map>") |
||||||
|
} |
||||||
|
|
||||||
|
func createFormPostRequest() *http.Request { |
||||||
|
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo")) |
||||||
|
req.Header.Set("Content-Type", MIMEPOSTForm) |
||||||
|
return req |
||||||
|
} |
||||||
|
|
||||||
|
func createFormMultipartRequest() *http.Request { |
||||||
|
boundary := "--testboundary" |
||||||
|
body := new(bytes.Buffer) |
||||||
|
mw := multipart.NewWriter(body) |
||||||
|
defer mw.Close() |
||||||
|
|
||||||
|
mw.SetBoundary(boundary) |
||||||
|
mw.WriteField("foo", "bar") |
||||||
|
mw.WriteField("bar", "foo") |
||||||
|
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) |
||||||
|
req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) |
||||||
|
return req |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingFormPost(t *testing.T) { |
||||||
|
req := createFormPostRequest() |
||||||
|
var obj FooBarStruct |
||||||
|
FormPost.Bind(req, &obj) |
||||||
|
|
||||||
|
assert.Equal(t, obj.Foo, "bar") |
||||||
|
assert.Equal(t, obj.Bar, "foo") |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingFormMultipart(t *testing.T) { |
||||||
|
req := createFormMultipartRequest() |
||||||
|
var obj FooBarStruct |
||||||
|
FormMultipart.Bind(req, &obj) |
||||||
|
|
||||||
|
assert.Equal(t, obj.Foo, "bar") |
||||||
|
assert.Equal(t, obj.Bar, "foo") |
||||||
|
} |
||||||
|
|
||||||
|
func TestValidationFails(t *testing.T) { |
||||||
|
var obj FooStruct |
||||||
|
req := requestWithBody("POST", "/", `{"bar": "foo"}`) |
||||||
|
err := JSON.Bind(req, &obj) |
||||||
|
assert.Error(t, err) |
||||||
|
} |
||||||
|
|
||||||
|
func TestValidationDisabled(t *testing.T) { |
||||||
|
backup := Validator |
||||||
|
Validator = nil |
||||||
|
defer func() { Validator = backup }() |
||||||
|
|
||||||
|
var obj FooStruct |
||||||
|
req := requestWithBody("POST", "/", `{"bar": "foo"}`) |
||||||
|
err := JSON.Bind(req, &obj) |
||||||
|
assert.NoError(t, err) |
||||||
|
} |
||||||
|
|
||||||
|
func TestExistsSucceeds(t *testing.T) { |
||||||
|
type HogeStruct struct { |
||||||
|
Hoge *int `json:"hoge" binding:"exists"` |
||||||
|
} |
||||||
|
|
||||||
|
var obj HogeStruct |
||||||
|
req := requestWithBody("POST", "/", `{"hoge": 0}`) |
||||||
|
err := JSON.Bind(req, &obj) |
||||||
|
assert.NoError(t, err) |
||||||
|
} |
||||||
|
|
||||||
|
func TestFormDefaultValue(t *testing.T) { |
||||||
|
params := "int=333&string=hello&bool=true&int64_slice=5,6,7,8&int8_slice=5,6,7,8" |
||||||
|
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q := new(ComplexDefaultStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Equal(t, 333, q.Int) |
||||||
|
assert.Equal(t, "hello", q.String) |
||||||
|
assert.Equal(t, true, q.Bool) |
||||||
|
assert.EqualValues(t, []int64{5, 6, 7, 8}, q.Int64Slice) |
||||||
|
assert.EqualValues(t, []int8{5, 6, 7, 8}, q.Int8Slice) |
||||||
|
|
||||||
|
params = "string=hello&bool=false" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(ComplexDefaultStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Equal(t, 999, q.Int) |
||||||
|
assert.Equal(t, "hello", q.String) |
||||||
|
assert.Equal(t, false, q.Bool) |
||||||
|
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) |
||||||
|
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) |
||||||
|
|
||||||
|
params = "strings=hello" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(ComplexDefaultStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Equal(t, 999, q.Int) |
||||||
|
assert.Equal(t, "default-string", q.String) |
||||||
|
assert.Equal(t, false, q.Bool) |
||||||
|
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) |
||||||
|
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) |
||||||
|
|
||||||
|
params = "int=&string=&bool=true&int64_slice=&int8_slice=" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(ComplexDefaultStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Equal(t, 999, q.Int) |
||||||
|
assert.Equal(t, "default-string", q.String) |
||||||
|
assert.Equal(t, true, q.Bool) |
||||||
|
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) |
||||||
|
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) |
||||||
|
} |
||||||
|
|
||||||
|
func testFormBinding(t *testing.T, method, path, badPath, body, badBody string) { |
||||||
|
b := Form |
||||||
|
assert.Equal(t, b.Name(), "form") |
||||||
|
|
||||||
|
obj := FooBarStruct{} |
||||||
|
req := requestWithBody(method, path, body) |
||||||
|
if method == "POST" { |
||||||
|
req.Header.Add("Content-Type", MIMEPOSTForm) |
||||||
|
} |
||||||
|
err := b.Bind(req, &obj) |
||||||
|
assert.NoError(t, err) |
||||||
|
assert.Equal(t, obj.Foo, "bar") |
||||||
|
assert.Equal(t, obj.Bar, "foo") |
||||||
|
|
||||||
|
obj = FooBarStruct{} |
||||||
|
req = requestWithBody(method, badPath, badBody) |
||||||
|
err = JSON.Bind(req, &obj) |
||||||
|
assert.Error(t, err) |
||||||
|
} |
||||||
|
|
||||||
|
func testQueryBinding(t *testing.T, method, path, badPath, body, badBody string) { |
||||||
|
b := Query |
||||||
|
assert.Equal(t, b.Name(), "query") |
||||||
|
|
||||||
|
obj := FooBarStruct{} |
||||||
|
req := requestWithBody(method, path, body) |
||||||
|
if method == "POST" { |
||||||
|
req.Header.Add("Content-Type", MIMEPOSTForm) |
||||||
|
} |
||||||
|
err := b.Bind(req, &obj) |
||||||
|
assert.NoError(t, err) |
||||||
|
assert.Equal(t, obj.Foo, "bar") |
||||||
|
assert.Equal(t, obj.Bar, "foo") |
||||||
|
} |
||||||
|
|
||||||
|
func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { |
||||||
|
assert.Equal(t, b.Name(), name) |
||||||
|
|
||||||
|
obj := FooStruct{} |
||||||
|
req := requestWithBody("POST", path, body) |
||||||
|
err := b.Bind(req, &obj) |
||||||
|
assert.NoError(t, err) |
||||||
|
assert.Equal(t, obj.Foo, "bar") |
||||||
|
|
||||||
|
obj = FooStruct{} |
||||||
|
req = requestWithBody("POST", badPath, badBody) |
||||||
|
err = JSON.Bind(req, &obj) |
||||||
|
assert.Error(t, err) |
||||||
|
} |
||||||
|
|
||||||
|
func requestWithBody(method, path, body string) (req *http.Request) { |
||||||
|
req, _ = http.NewRequest(method, path, bytes.NewBufferString(body)) |
||||||
|
return |
||||||
|
} |
||||||
|
func BenchmarkBindingForm(b *testing.B) { |
||||||
|
req := requestWithBody("POST", "/", "foo=bar&bar=foo&slice=a&slice=b&slice=c&slice=w") |
||||||
|
req.Header.Add("Content-Type", MIMEPOSTForm) |
||||||
|
f := Form |
||||||
|
for i := 0; i < b.N; i++ { |
||||||
|
obj := FooBarStruct{} |
||||||
|
f.Bind(req, &obj) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,45 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"reflect" |
||||||
|
"sync" |
||||||
|
|
||||||
|
"gopkg.in/go-playground/validator.v9" |
||||||
|
) |
||||||
|
|
||||||
|
type defaultValidator struct { |
||||||
|
once sync.Once |
||||||
|
validate *validator.Validate |
||||||
|
} |
||||||
|
|
||||||
|
var _ StructValidator = &defaultValidator{} |
||||||
|
|
||||||
|
func (v *defaultValidator) ValidateStruct(obj interface{}) error { |
||||||
|
if kindOfData(obj) == reflect.Struct { |
||||||
|
v.lazyinit() |
||||||
|
if err := v.validate.Struct(obj); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (v *defaultValidator) RegisterValidation(key string, fn validator.Func) error { |
||||||
|
v.lazyinit() |
||||||
|
return v.validate.RegisterValidation(key, fn) |
||||||
|
} |
||||||
|
|
||||||
|
func (v *defaultValidator) lazyinit() { |
||||||
|
v.once.Do(func() { |
||||||
|
v.validate = validator.New() |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func kindOfData(data interface{}) reflect.Kind { |
||||||
|
value := reflect.ValueOf(data) |
||||||
|
valueType := value.Kind() |
||||||
|
if valueType == reflect.Ptr { |
||||||
|
valueType = value.Elem().Kind() |
||||||
|
} |
||||||
|
return valueType |
||||||
|
} |
@ -0,0 +1,113 @@ |
|||||||
|
// Code generated by protoc-gen-go.
|
||||||
|
// source: test.proto
|
||||||
|
// DO NOT EDIT!
|
||||||
|
|
||||||
|
/* |
||||||
|
Package example is a generated protocol buffer package. |
||||||
|
|
||||||
|
It is generated from these files: |
||||||
|
test.proto |
||||||
|
|
||||||
|
It has these top-level messages: |
||||||
|
Test |
||||||
|
*/ |
||||||
|
package example |
||||||
|
|
||||||
|
import proto "github.com/golang/protobuf/proto" |
||||||
|
import math "math" |
||||||
|
|
||||||
|
// Reference imports to suppress errors if they are not otherwise used.
|
||||||
|
var _ = proto.Marshal |
||||||
|
var _ = math.Inf |
||||||
|
|
||||||
|
type FOO int32 |
||||||
|
|
||||||
|
const ( |
||||||
|
FOO_X FOO = 17 |
||||||
|
) |
||||||
|
|
||||||
|
var FOO_name = map[int32]string{ |
||||||
|
17: "X", |
||||||
|
} |
||||||
|
var FOO_value = map[string]int32{ |
||||||
|
"X": 17, |
||||||
|
} |
||||||
|
|
||||||
|
func (x FOO) Enum() *FOO { |
||||||
|
p := new(FOO) |
||||||
|
*p = x |
||||||
|
return p |
||||||
|
} |
||||||
|
func (x FOO) String() string { |
||||||
|
return proto.EnumName(FOO_name, int32(x)) |
||||||
|
} |
||||||
|
func (x *FOO) UnmarshalJSON(data []byte) error { |
||||||
|
value, err := proto.UnmarshalJSONEnum(FOO_value, data, "FOO") |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
*x = FOO(value) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
type Test struct { |
||||||
|
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"` |
||||||
|
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"` |
||||||
|
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"` |
||||||
|
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"` |
||||||
|
XXX_unrecognized []byte `json:"-"` |
||||||
|
} |
||||||
|
|
||||||
|
func (m *Test) Reset() { *m = Test{} } |
||||||
|
func (m *Test) String() string { return proto.CompactTextString(m) } |
||||||
|
func (*Test) ProtoMessage() {} |
||||||
|
|
||||||
|
const Default_Test_Type int32 = 77 |
||||||
|
|
||||||
|
func (m *Test) GetLabel() string { |
||||||
|
if m != nil && m.Label != nil { |
||||||
|
return *m.Label |
||||||
|
} |
||||||
|
return "" |
||||||
|
} |
||||||
|
|
||||||
|
func (m *Test) GetType() int32 { |
||||||
|
if m != nil && m.Type != nil { |
||||||
|
return *m.Type |
||||||
|
} |
||||||
|
return Default_Test_Type |
||||||
|
} |
||||||
|
|
||||||
|
func (m *Test) GetReps() []int64 { |
||||||
|
if m != nil { |
||||||
|
return m.Reps |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (m *Test) GetOptionalgroup() *Test_OptionalGroup { |
||||||
|
if m != nil { |
||||||
|
return m.Optionalgroup |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
type Test_OptionalGroup struct { |
||||||
|
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"` |
||||||
|
XXX_unrecognized []byte `json:"-"` |
||||||
|
} |
||||||
|
|
||||||
|
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} } |
||||||
|
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) } |
||||||
|
func (*Test_OptionalGroup) ProtoMessage() {} |
||||||
|
|
||||||
|
func (m *Test_OptionalGroup) GetRequiredField() string { |
||||||
|
if m != nil && m.RequiredField != nil { |
||||||
|
return *m.RequiredField |
||||||
|
} |
||||||
|
return "" |
||||||
|
} |
||||||
|
|
||||||
|
func init() { |
||||||
|
proto.RegisterEnum("example.FOO", FOO_name, FOO_value) |
||||||
|
} |
@ -0,0 +1,12 @@ |
|||||||
|
package example; |
||||||
|
|
||||||
|
enum FOO {X=17;}; |
||||||
|
|
||||||
|
message Test { |
||||||
|
required string label = 1; |
||||||
|
optional int32 type = 2[default=77]; |
||||||
|
repeated int64 reps = 3; |
||||||
|
optional group OptionalGroup = 4{ |
||||||
|
required string RequiredField = 5; |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,36 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"log" |
||||||
|
"net/http" |
||||||
|
) |
||||||
|
|
||||||
|
type Arg struct { |
||||||
|
Max int64 `form:"max" validate:"max=10"` |
||||||
|
Min int64 `form:"min" validate:"min=2"` |
||||||
|
Range int64 `form:"range" validate:"min=1,max=10"` |
||||||
|
// use split option to split arg 1,2,3 into slice [1 2 3]
|
||||||
|
// otherwise slice type with parse url.Values (eg:a=b&a=c) default.
|
||||||
|
Slice []int64 `form:"slice,split" validate:"min=1"` |
||||||
|
} |
||||||
|
|
||||||
|
func ExampleBinding() { |
||||||
|
req := initHTTP("max=9&min=3&range=3&slice=1,2,3") |
||||||
|
arg := new(Arg) |
||||||
|
if err := Form.Bind(req, arg); err != nil { |
||||||
|
log.Fatal(err) |
||||||
|
} |
||||||
|
fmt.Printf("arg.Max %d\narg.Min %d\narg.Range %d\narg.Slice %v", arg.Max, arg.Min, arg.Range, arg.Slice) |
||||||
|
// Output:
|
||||||
|
// arg.Max 9
|
||||||
|
// arg.Min 3
|
||||||
|
// arg.Range 3
|
||||||
|
// arg.Slice [1 2 3]
|
||||||
|
} |
||||||
|
|
||||||
|
func initHTTP(params string) (req *http.Request) { |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
req.ParseForm() |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,55 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
const defaultMemory = 32 * 1024 * 1024 |
||||||
|
|
||||||
|
type formBinding struct{} |
||||||
|
type formPostBinding struct{} |
||||||
|
type formMultipartBinding struct{} |
||||||
|
|
||||||
|
func (f formBinding) Name() string { |
||||||
|
return "form" |
||||||
|
} |
||||||
|
|
||||||
|
func (f formBinding) Bind(req *http.Request, obj interface{}) error { |
||||||
|
if err := req.ParseForm(); err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
if err := mapForm(obj, req.Form); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return validate(obj) |
||||||
|
} |
||||||
|
|
||||||
|
func (f formPostBinding) Name() string { |
||||||
|
return "form-urlencoded" |
||||||
|
} |
||||||
|
|
||||||
|
func (f formPostBinding) Bind(req *http.Request, obj interface{}) error { |
||||||
|
if err := req.ParseForm(); err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
if err := mapForm(obj, req.PostForm); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return validate(obj) |
||||||
|
} |
||||||
|
|
||||||
|
func (f formMultipartBinding) Name() string { |
||||||
|
return "multipart/form-data" |
||||||
|
} |
||||||
|
|
||||||
|
func (f formMultipartBinding) Bind(req *http.Request, obj interface{}) error { |
||||||
|
if err := req.ParseMultipartForm(defaultMemory); err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
if err := mapForm(obj, req.MultipartForm.Value); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return validate(obj) |
||||||
|
} |
@ -0,0 +1,276 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"reflect" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// scache struct reflect type cache.
|
||||||
|
var scache = &cache{ |
||||||
|
data: make(map[reflect.Type]*sinfo), |
||||||
|
} |
||||||
|
|
||||||
|
type cache struct { |
||||||
|
data map[reflect.Type]*sinfo |
||||||
|
mutex sync.RWMutex |
||||||
|
} |
||||||
|
|
||||||
|
func (c *cache) get(obj reflect.Type) (s *sinfo) { |
||||||
|
var ok bool |
||||||
|
c.mutex.RLock() |
||||||
|
if s, ok = c.data[obj]; !ok { |
||||||
|
c.mutex.RUnlock() |
||||||
|
s = c.set(obj) |
||||||
|
return |
||||||
|
} |
||||||
|
c.mutex.RUnlock() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *cache) set(obj reflect.Type) (s *sinfo) { |
||||||
|
s = new(sinfo) |
||||||
|
tp := obj.Elem() |
||||||
|
for i := 0; i < tp.NumField(); i++ { |
||||||
|
fd := new(field) |
||||||
|
fd.tp = tp.Field(i) |
||||||
|
tag := fd.tp.Tag.Get("form") |
||||||
|
fd.name, fd.option = parseTag(tag) |
||||||
|
if defV := fd.tp.Tag.Get("default"); defV != "" { |
||||||
|
dv := reflect.New(fd.tp.Type).Elem() |
||||||
|
setWithProperType(fd.tp.Type.Kind(), []string{defV}, dv, fd.option) |
||||||
|
fd.hasDefault = true |
||||||
|
fd.defaultValue = dv |
||||||
|
} |
||||||
|
s.field = append(s.field, fd) |
||||||
|
} |
||||||
|
c.mutex.Lock() |
||||||
|
c.data[obj] = s |
||||||
|
c.mutex.Unlock() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
type sinfo struct { |
||||||
|
field []*field |
||||||
|
} |
||||||
|
|
||||||
|
type field struct { |
||||||
|
tp reflect.StructField |
||||||
|
name string |
||||||
|
option tagOptions |
||||||
|
|
||||||
|
hasDefault bool // if field had default value
|
||||||
|
defaultValue reflect.Value // field default value
|
||||||
|
} |
||||||
|
|
||||||
|
func mapForm(ptr interface{}, form map[string][]string) error { |
||||||
|
sinfo := scache.get(reflect.TypeOf(ptr)) |
||||||
|
val := reflect.ValueOf(ptr).Elem() |
||||||
|
for i, fd := range sinfo.field { |
||||||
|
typeField := fd.tp |
||||||
|
structField := val.Field(i) |
||||||
|
if !structField.CanSet() { |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
structFieldKind := structField.Kind() |
||||||
|
inputFieldName := fd.name |
||||||
|
if inputFieldName == "" { |
||||||
|
inputFieldName = typeField.Name |
||||||
|
|
||||||
|
// if "form" tag is nil, we inspect if the field is a struct.
|
||||||
|
// this would not make sense for JSON parsing but it does for a form
|
||||||
|
// since data is flatten
|
||||||
|
if structFieldKind == reflect.Struct { |
||||||
|
err := mapForm(structField.Addr().Interface(), form) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
continue |
||||||
|
} |
||||||
|
} |
||||||
|
inputValue, exists := form[inputFieldName] |
||||||
|
if !exists { |
||||||
|
// Set the field as default value when the input value is not exist
|
||||||
|
if fd.hasDefault { |
||||||
|
structField.Set(fd.defaultValue) |
||||||
|
} |
||||||
|
continue |
||||||
|
} |
||||||
|
// Set the field as default value when the input value is empty
|
||||||
|
if fd.hasDefault && inputValue[0] == "" { |
||||||
|
structField.Set(fd.defaultValue) |
||||||
|
continue |
||||||
|
} |
||||||
|
if _, isTime := structField.Interface().(time.Time); isTime { |
||||||
|
if err := setTimeField(inputValue[0], typeField, structField); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
continue |
||||||
|
} |
||||||
|
if err := setWithProperType(typeField.Type.Kind(), inputValue, structField, fd.option); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func setWithProperType(valueKind reflect.Kind, val []string, structField reflect.Value, option tagOptions) error { |
||||||
|
switch valueKind { |
||||||
|
case reflect.Int: |
||||||
|
return setIntField(val[0], 0, structField) |
||||||
|
case reflect.Int8: |
||||||
|
return setIntField(val[0], 8, structField) |
||||||
|
case reflect.Int16: |
||||||
|
return setIntField(val[0], 16, structField) |
||||||
|
case reflect.Int32: |
||||||
|
return setIntField(val[0], 32, structField) |
||||||
|
case reflect.Int64: |
||||||
|
return setIntField(val[0], 64, structField) |
||||||
|
case reflect.Uint: |
||||||
|
return setUintField(val[0], 0, structField) |
||||||
|
case reflect.Uint8: |
||||||
|
return setUintField(val[0], 8, structField) |
||||||
|
case reflect.Uint16: |
||||||
|
return setUintField(val[0], 16, structField) |
||||||
|
case reflect.Uint32: |
||||||
|
return setUintField(val[0], 32, structField) |
||||||
|
case reflect.Uint64: |
||||||
|
return setUintField(val[0], 64, structField) |
||||||
|
case reflect.Bool: |
||||||
|
return setBoolField(val[0], structField) |
||||||
|
case reflect.Float32: |
||||||
|
return setFloatField(val[0], 32, structField) |
||||||
|
case reflect.Float64: |
||||||
|
return setFloatField(val[0], 64, structField) |
||||||
|
case reflect.String: |
||||||
|
structField.SetString(val[0]) |
||||||
|
case reflect.Slice: |
||||||
|
if option.Contains("split") { |
||||||
|
val = strings.Split(val[0], ",") |
||||||
|
} |
||||||
|
filtered := filterEmpty(val) |
||||||
|
switch structField.Type().Elem().Kind() { |
||||||
|
case reflect.Int64: |
||||||
|
valSli := make([]int64, 0, len(filtered)) |
||||||
|
for i := 0; i < len(filtered); i++ { |
||||||
|
d, err := strconv.ParseInt(filtered[i], 10, 64) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
valSli = append(valSli, d) |
||||||
|
} |
||||||
|
structField.Set(reflect.ValueOf(valSli)) |
||||||
|
case reflect.String: |
||||||
|
valSli := make([]string, 0, len(filtered)) |
||||||
|
for i := 0; i < len(filtered); i++ { |
||||||
|
valSli = append(valSli, filtered[i]) |
||||||
|
} |
||||||
|
structField.Set(reflect.ValueOf(valSli)) |
||||||
|
default: |
||||||
|
sliceOf := structField.Type().Elem().Kind() |
||||||
|
numElems := len(filtered) |
||||||
|
slice := reflect.MakeSlice(structField.Type(), len(filtered), len(filtered)) |
||||||
|
for i := 0; i < numElems; i++ { |
||||||
|
if err := setWithProperType(sliceOf, filtered[i:], slice.Index(i), ""); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
structField.Set(slice) |
||||||
|
} |
||||||
|
default: |
||||||
|
return errors.New("Unknown type") |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func setIntField(val string, bitSize int, field reflect.Value) error { |
||||||
|
if val == "" { |
||||||
|
val = "0" |
||||||
|
} |
||||||
|
intVal, err := strconv.ParseInt(val, 10, bitSize) |
||||||
|
if err == nil { |
||||||
|
field.SetInt(intVal) |
||||||
|
} |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
|
||||||
|
func setUintField(val string, bitSize int, field reflect.Value) error { |
||||||
|
if val == "" { |
||||||
|
val = "0" |
||||||
|
} |
||||||
|
uintVal, err := strconv.ParseUint(val, 10, bitSize) |
||||||
|
if err == nil { |
||||||
|
field.SetUint(uintVal) |
||||||
|
} |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
|
||||||
|
func setBoolField(val string, field reflect.Value) error { |
||||||
|
if val == "" { |
||||||
|
val = "false" |
||||||
|
} |
||||||
|
boolVal, err := strconv.ParseBool(val) |
||||||
|
if err == nil { |
||||||
|
field.SetBool(boolVal) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func setFloatField(val string, bitSize int, field reflect.Value) error { |
||||||
|
if val == "" { |
||||||
|
val = "0.0" |
||||||
|
} |
||||||
|
floatVal, err := strconv.ParseFloat(val, bitSize) |
||||||
|
if err == nil { |
||||||
|
field.SetFloat(floatVal) |
||||||
|
} |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
|
||||||
|
func setTimeField(val string, structField reflect.StructField, value reflect.Value) error { |
||||||
|
timeFormat := structField.Tag.Get("time_format") |
||||||
|
if timeFormat == "" { |
||||||
|
return errors.New("Blank time format") |
||||||
|
} |
||||||
|
|
||||||
|
if val == "" { |
||||||
|
value.Set(reflect.ValueOf(time.Time{})) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
l := time.Local |
||||||
|
if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC { |
||||||
|
l = time.UTC |
||||||
|
} |
||||||
|
|
||||||
|
if locTag := structField.Tag.Get("time_location"); locTag != "" { |
||||||
|
loc, err := time.LoadLocation(locTag) |
||||||
|
if err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
l = loc |
||||||
|
} |
||||||
|
|
||||||
|
t, err := time.ParseInLocation(timeFormat, val, l) |
||||||
|
if err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
|
||||||
|
value.Set(reflect.ValueOf(t)) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func filterEmpty(val []string) []string { |
||||||
|
filtered := make([]string, 0, len(val)) |
||||||
|
for _, v := range val { |
||||||
|
if v != "" { |
||||||
|
filtered = append(filtered, v) |
||||||
|
} |
||||||
|
} |
||||||
|
return filtered |
||||||
|
} |
@ -0,0 +1,22 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
type jsonBinding struct{} |
||||||
|
|
||||||
|
func (jsonBinding) Name() string { |
||||||
|
return "json" |
||||||
|
} |
||||||
|
|
||||||
|
func (jsonBinding) Bind(req *http.Request, obj interface{}) error { |
||||||
|
decoder := json.NewDecoder(req.Body) |
||||||
|
if err := decoder.Decode(obj); err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
return validate(obj) |
||||||
|
} |
@ -0,0 +1,19 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
) |
||||||
|
|
||||||
|
type queryBinding struct{} |
||||||
|
|
||||||
|
func (queryBinding) Name() string { |
||||||
|
return "query" |
||||||
|
} |
||||||
|
|
||||||
|
func (queryBinding) Bind(req *http.Request, obj interface{}) error { |
||||||
|
values := req.URL.Query() |
||||||
|
if err := mapForm(obj, values); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return validate(obj) |
||||||
|
} |
@ -0,0 +1,44 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"strings" |
||||||
|
) |
||||||
|
|
||||||
|
// tagOptions is the string following a comma in a struct field's "json"
|
||||||
|
// tag, or the empty string. It does not include the leading comma.
|
||||||
|
type tagOptions string |
||||||
|
|
||||||
|
// parseTag splits a struct field's json tag into its name and
|
||||||
|
// comma-separated options.
|
||||||
|
func parseTag(tag string) (string, tagOptions) { |
||||||
|
if idx := strings.Index(tag, ","); idx != -1 { |
||||||
|
return tag[:idx], tagOptions(tag[idx+1:]) |
||||||
|
} |
||||||
|
return tag, tagOptions("") |
||||||
|
} |
||||||
|
|
||||||
|
// Contains reports whether a comma-separated list of options
|
||||||
|
// contains a particular substr flag. substr must be surrounded by a
|
||||||
|
// string boundary or commas.
|
||||||
|
func (o tagOptions) Contains(optionName string) bool { |
||||||
|
if len(o) == 0 { |
||||||
|
return false |
||||||
|
} |
||||||
|
s := string(o) |
||||||
|
for s != "" { |
||||||
|
var next string |
||||||
|
i := strings.Index(s, ",") |
||||||
|
if i >= 0 { |
||||||
|
s, next = s[:i], s[i+1:] |
||||||
|
} |
||||||
|
if s == optionName { |
||||||
|
return true |
||||||
|
} |
||||||
|
s = next |
||||||
|
} |
||||||
|
return false |
||||||
|
} |
@ -0,0 +1,209 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"testing" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert" |
||||||
|
) |
||||||
|
|
||||||
|
type testInterface interface { |
||||||
|
String() string |
||||||
|
} |
||||||
|
|
||||||
|
type substructNoValidation struct { |
||||||
|
IString string |
||||||
|
IInt int |
||||||
|
} |
||||||
|
|
||||||
|
type mapNoValidationSub map[string]substructNoValidation |
||||||
|
|
||||||
|
type structNoValidationValues struct { |
||||||
|
substructNoValidation |
||||||
|
|
||||||
|
Boolean bool |
||||||
|
|
||||||
|
Uinteger uint |
||||||
|
Integer int |
||||||
|
Integer8 int8 |
||||||
|
Integer16 int16 |
||||||
|
Integer32 int32 |
||||||
|
Integer64 int64 |
||||||
|
Uinteger8 uint8 |
||||||
|
Uinteger16 uint16 |
||||||
|
Uinteger32 uint32 |
||||||
|
Uinteger64 uint64 |
||||||
|
|
||||||
|
Float32 float32 |
||||||
|
Float64 float64 |
||||||
|
|
||||||
|
String string |
||||||
|
|
||||||
|
Date time.Time |
||||||
|
|
||||||
|
Struct substructNoValidation |
||||||
|
InlinedStruct struct { |
||||||
|
String []string |
||||||
|
Integer int |
||||||
|
} |
||||||
|
|
||||||
|
IntSlice []int |
||||||
|
IntPointerSlice []*int |
||||||
|
StructPointerSlice []*substructNoValidation |
||||||
|
StructSlice []substructNoValidation |
||||||
|
InterfaceSlice []testInterface |
||||||
|
|
||||||
|
UniversalInterface interface{} |
||||||
|
CustomInterface testInterface |
||||||
|
|
||||||
|
FloatMap map[string]float32 |
||||||
|
StructMap mapNoValidationSub |
||||||
|
} |
||||||
|
|
||||||
|
func createNoValidationValues() structNoValidationValues { |
||||||
|
integer := 1 |
||||||
|
s := structNoValidationValues{ |
||||||
|
Boolean: true, |
||||||
|
Uinteger: 1 << 29, |
||||||
|
Integer: -10000, |
||||||
|
Integer8: 120, |
||||||
|
Integer16: -20000, |
||||||
|
Integer32: 1 << 29, |
||||||
|
Integer64: 1 << 61, |
||||||
|
Uinteger8: 250, |
||||||
|
Uinteger16: 50000, |
||||||
|
Uinteger32: 1 << 31, |
||||||
|
Uinteger64: 1 << 62, |
||||||
|
Float32: 123.456, |
||||||
|
Float64: 123.456789, |
||||||
|
String: "text", |
||||||
|
Date: time.Time{}, |
||||||
|
CustomInterface: &bytes.Buffer{}, |
||||||
|
Struct: substructNoValidation{}, |
||||||
|
IntSlice: []int{-3, -2, 1, 0, 1, 2, 3}, |
||||||
|
IntPointerSlice: []*int{&integer}, |
||||||
|
StructSlice: []substructNoValidation{}, |
||||||
|
UniversalInterface: 1.2, |
||||||
|
FloatMap: map[string]float32{ |
||||||
|
"foo": 1.23, |
||||||
|
"bar": 232.323, |
||||||
|
}, |
||||||
|
StructMap: mapNoValidationSub{ |
||||||
|
"foo": substructNoValidation{}, |
||||||
|
"bar": substructNoValidation{}, |
||||||
|
}, |
||||||
|
// StructPointerSlice []noValidationSub
|
||||||
|
// InterfaceSlice []testInterface
|
||||||
|
} |
||||||
|
s.InlinedStruct.Integer = 1000 |
||||||
|
s.InlinedStruct.String = []string{"first", "second"} |
||||||
|
s.IString = "substring" |
||||||
|
s.IInt = 987654 |
||||||
|
return s |
||||||
|
} |
||||||
|
|
||||||
|
func TestValidateNoValidationValues(t *testing.T) { |
||||||
|
origin := createNoValidationValues() |
||||||
|
test := createNoValidationValues() |
||||||
|
empty := structNoValidationValues{} |
||||||
|
|
||||||
|
assert.Nil(t, validate(test)) |
||||||
|
assert.Nil(t, validate(&test)) |
||||||
|
assert.Nil(t, validate(empty)) |
||||||
|
assert.Nil(t, validate(&empty)) |
||||||
|
|
||||||
|
assert.Equal(t, origin, test) |
||||||
|
} |
||||||
|
|
||||||
|
type structNoValidationPointer struct { |
||||||
|
// substructNoValidation
|
||||||
|
|
||||||
|
Boolean bool |
||||||
|
|
||||||
|
Uinteger *uint |
||||||
|
Integer *int |
||||||
|
Integer8 *int8 |
||||||
|
Integer16 *int16 |
||||||
|
Integer32 *int32 |
||||||
|
Integer64 *int64 |
||||||
|
Uinteger8 *uint8 |
||||||
|
Uinteger16 *uint16 |
||||||
|
Uinteger32 *uint32 |
||||||
|
Uinteger64 *uint64 |
||||||
|
|
||||||
|
Float32 *float32 |
||||||
|
Float64 *float64 |
||||||
|
|
||||||
|
String *string |
||||||
|
|
||||||
|
Date *time.Time |
||||||
|
|
||||||
|
Struct *substructNoValidation |
||||||
|
|
||||||
|
IntSlice *[]int |
||||||
|
IntPointerSlice *[]*int |
||||||
|
StructPointerSlice *[]*substructNoValidation |
||||||
|
StructSlice *[]substructNoValidation |
||||||
|
InterfaceSlice *[]testInterface |
||||||
|
|
||||||
|
FloatMap *map[string]float32 |
||||||
|
StructMap *mapNoValidationSub |
||||||
|
} |
||||||
|
|
||||||
|
func TestValidateNoValidationPointers(t *testing.T) { |
||||||
|
//origin := createNoValidation_values()
|
||||||
|
//test := createNoValidation_values()
|
||||||
|
empty := structNoValidationPointer{} |
||||||
|
|
||||||
|
//assert.Nil(t, validate(test))
|
||||||
|
//assert.Nil(t, validate(&test))
|
||||||
|
assert.Nil(t, validate(empty)) |
||||||
|
assert.Nil(t, validate(&empty)) |
||||||
|
|
||||||
|
//assert.Equal(t, origin, test)
|
||||||
|
} |
||||||
|
|
||||||
|
type Object map[string]interface{} |
||||||
|
|
||||||
|
func TestValidatePrimitives(t *testing.T) { |
||||||
|
obj := Object{"foo": "bar", "bar": 1} |
||||||
|
assert.NoError(t, validate(obj)) |
||||||
|
assert.NoError(t, validate(&obj)) |
||||||
|
assert.Equal(t, obj, Object{"foo": "bar", "bar": 1}) |
||||||
|
|
||||||
|
obj2 := []Object{{"foo": "bar", "bar": 1}, {"foo": "bar", "bar": 1}} |
||||||
|
assert.NoError(t, validate(obj2)) |
||||||
|
assert.NoError(t, validate(&obj2)) |
||||||
|
|
||||||
|
nu := 10 |
||||||
|
assert.NoError(t, validate(nu)) |
||||||
|
assert.NoError(t, validate(&nu)) |
||||||
|
assert.Equal(t, nu, 10) |
||||||
|
|
||||||
|
str := "value" |
||||||
|
assert.NoError(t, validate(str)) |
||||||
|
assert.NoError(t, validate(&str)) |
||||||
|
assert.Equal(t, str, "value") |
||||||
|
} |
||||||
|
|
||||||
|
// structCustomValidation is a helper struct we use to check that
|
||||||
|
// custom validation can be registered on it.
|
||||||
|
// The `notone` binding directive is for custom validation and registered later.
|
||||||
|
// type structCustomValidation struct {
|
||||||
|
// Integer int `binding:"notone"`
|
||||||
|
// }
|
||||||
|
|
||||||
|
// notOne is a custom validator meant to be used with `validator.v8` library.
|
||||||
|
// The method signature for `v9` is significantly different and this function
|
||||||
|
// would need to be changed for tests to pass after upgrade.
|
||||||
|
// See https://github.com/gin-gonic/gin/pull/1015.
|
||||||
|
// func notOne(
|
||||||
|
// v *validator.Validate, topStruct reflect.Value, currentStructOrField reflect.Value,
|
||||||
|
// field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string,
|
||||||
|
// ) bool {
|
||||||
|
// if val, ok := field.Interface().(int); ok {
|
||||||
|
// return val != 1
|
||||||
|
// }
|
||||||
|
// return false
|
||||||
|
// }
|
@ -0,0 +1,22 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/xml" |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
type xmlBinding struct{} |
||||||
|
|
||||||
|
func (xmlBinding) Name() string { |
||||||
|
return "xml" |
||||||
|
} |
||||||
|
|
||||||
|
func (xmlBinding) Bind(req *http.Request, obj interface{}) error { |
||||||
|
decoder := xml.NewDecoder(req.Body) |
||||||
|
if err := decoder.Decode(obj); err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
return validate(obj) |
||||||
|
} |
@ -0,0 +1,306 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"math" |
||||||
|
"net/http" |
||||||
|
"strconv" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/ecode" |
||||||
|
"github.com/bilibili/kratos/pkg/net/http/blademaster/binding" |
||||||
|
"github.com/bilibili/kratos/pkg/net/http/blademaster/render" |
||||||
|
|
||||||
|
"github.com/gogo/protobuf/proto" |
||||||
|
"github.com/gogo/protobuf/types" |
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
_abortIndex int8 = math.MaxInt8 / 2 |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
_openParen = []byte("(") |
||||||
|
_closeParen = []byte(")") |
||||||
|
) |
||||||
|
|
||||||
|
// Context is the most important part. It allows us to pass variables between
|
||||||
|
// middleware, manage the flow, validate the JSON of a request and render a
|
||||||
|
// JSON response for example.
|
||||||
|
type Context struct { |
||||||
|
context.Context |
||||||
|
|
||||||
|
Request *http.Request |
||||||
|
Writer http.ResponseWriter |
||||||
|
|
||||||
|
// flow control
|
||||||
|
index int8 |
||||||
|
handlers []HandlerFunc |
||||||
|
|
||||||
|
// Keys is a key/value pair exclusively for the context of each request.
|
||||||
|
Keys map[string]interface{} |
||||||
|
|
||||||
|
Error error |
||||||
|
|
||||||
|
method string |
||||||
|
engine *Engine |
||||||
|
} |
||||||
|
|
||||||
|
/************************************/ |
||||||
|
/*********** FLOW CONTROL ***********/ |
||||||
|
/************************************/ |
||||||
|
|
||||||
|
// Next should be used only inside middleware.
|
||||||
|
// It executes the pending handlers in the chain inside the calling handler.
|
||||||
|
// See example in godoc.
|
||||||
|
func (c *Context) Next() { |
||||||
|
c.index++ |
||||||
|
s := int8(len(c.handlers)) |
||||||
|
for ; c.index < s; c.index++ { |
||||||
|
// only check method on last handler, otherwise middlewares
|
||||||
|
// will never be effected if request method is not matched
|
||||||
|
if c.index == s-1 && c.method != c.Request.Method { |
||||||
|
code := http.StatusMethodNotAllowed |
||||||
|
c.Error = ecode.MethodNotAllowed |
||||||
|
http.Error(c.Writer, http.StatusText(code), code) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
c.handlers[c.index](c) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Abort prevents pending handlers from being called. Note that this will not stop the current handler.
|
||||||
|
// Let's say you have an authorization middleware that validates that the current request is authorized.
|
||||||
|
// If the authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers
|
||||||
|
// for this request are not called.
|
||||||
|
func (c *Context) Abort() { |
||||||
|
c.index = _abortIndex |
||||||
|
} |
||||||
|
|
||||||
|
// AbortWithStatus calls `Abort()` and writes the headers with the specified status code.
|
||||||
|
// For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401).
|
||||||
|
func (c *Context) AbortWithStatus(code int) { |
||||||
|
c.Status(code) |
||||||
|
c.Abort() |
||||||
|
} |
||||||
|
|
||||||
|
// IsAborted returns true if the current context was aborted.
|
||||||
|
func (c *Context) IsAborted() bool { |
||||||
|
return c.index >= _abortIndex |
||||||
|
} |
||||||
|
|
||||||
|
/************************************/ |
||||||
|
/******** METADATA MANAGEMENT********/ |
||||||
|
/************************************/ |
||||||
|
|
||||||
|
// Set is used to store a new key/value pair exclusively for this context.
|
||||||
|
// It also lazy initializes c.Keys if it was not used previously.
|
||||||
|
func (c *Context) Set(key string, value interface{}) { |
||||||
|
if c.Keys == nil { |
||||||
|
c.Keys = make(map[string]interface{}) |
||||||
|
} |
||||||
|
c.Keys[key] = value |
||||||
|
} |
||||||
|
|
||||||
|
// Get returns the value for the given key, ie: (value, true).
|
||||||
|
// If the value does not exists it returns (nil, false)
|
||||||
|
func (c *Context) Get(key string) (value interface{}, exists bool) { |
||||||
|
value, exists = c.Keys[key] |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
/************************************/ |
||||||
|
/******** RESPONSE RENDERING ********/ |
||||||
|
/************************************/ |
||||||
|
|
||||||
|
// bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function.
|
||||||
|
func bodyAllowedForStatus(status int) bool { |
||||||
|
switch { |
||||||
|
case status >= 100 && status <= 199: |
||||||
|
return false |
||||||
|
case status == 204: |
||||||
|
return false |
||||||
|
case status == 304: |
||||||
|
return false |
||||||
|
} |
||||||
|
return true |
||||||
|
} |
||||||
|
|
||||||
|
// Status sets the HTTP response code.
|
||||||
|
func (c *Context) Status(code int) { |
||||||
|
c.Writer.WriteHeader(code) |
||||||
|
} |
||||||
|
|
||||||
|
// Render http response with http code by a render instance.
|
||||||
|
func (c *Context) Render(code int, r render.Render) { |
||||||
|
r.WriteContentType(c.Writer) |
||||||
|
if code > 0 { |
||||||
|
c.Status(code) |
||||||
|
} |
||||||
|
|
||||||
|
if !bodyAllowedForStatus(code) { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
params := c.Request.Form |
||||||
|
|
||||||
|
cb := params.Get("callback") |
||||||
|
jsonp := cb != "" && params.Get("jsonp") == "jsonp" |
||||||
|
if jsonp { |
||||||
|
c.Writer.Write([]byte(cb)) |
||||||
|
c.Writer.Write(_openParen) |
||||||
|
} |
||||||
|
|
||||||
|
if err := r.Render(c.Writer); err != nil { |
||||||
|
c.Error = err |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
if jsonp { |
||||||
|
if _, err := c.Writer.Write(_closeParen); err != nil { |
||||||
|
c.Error = errors.WithStack(err) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// JSON serializes the given struct as JSON into the response body.
|
||||||
|
// It also sets the Content-Type as "application/json".
|
||||||
|
func (c *Context) JSON(data interface{}, err error) { |
||||||
|
code := http.StatusOK |
||||||
|
c.Error = err |
||||||
|
bcode := ecode.Cause(err) |
||||||
|
// TODO app allow 5xx?
|
||||||
|
/* |
||||||
|
if bcode.Code() == -500 { |
||||||
|
code = http.StatusServiceUnavailable |
||||||
|
} |
||||||
|
*/ |
||||||
|
writeStatusCode(c.Writer, bcode.Code()) |
||||||
|
c.Render(code, render.JSON{ |
||||||
|
Code: bcode.Code(), |
||||||
|
Message: bcode.Message(), |
||||||
|
Data: data, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// JSONMap serializes the given map as map JSON into the response body.
|
||||||
|
// It also sets the Content-Type as "application/json".
|
||||||
|
func (c *Context) JSONMap(data map[string]interface{}, err error) { |
||||||
|
code := http.StatusOK |
||||||
|
c.Error = err |
||||||
|
bcode := ecode.Cause(err) |
||||||
|
// TODO app allow 5xx?
|
||||||
|
/* |
||||||
|
if bcode.Code() == -500 { |
||||||
|
code = http.StatusServiceUnavailable |
||||||
|
} |
||||||
|
*/ |
||||||
|
writeStatusCode(c.Writer, bcode.Code()) |
||||||
|
data["code"] = bcode.Code() |
||||||
|
if _, ok := data["message"]; !ok { |
||||||
|
data["message"] = bcode.Message() |
||||||
|
} |
||||||
|
c.Render(code, render.MapJSON(data)) |
||||||
|
} |
||||||
|
|
||||||
|
// XML serializes the given struct as XML into the response body.
|
||||||
|
// It also sets the Content-Type as "application/xml".
|
||||||
|
func (c *Context) XML(data interface{}, err error) { |
||||||
|
code := http.StatusOK |
||||||
|
c.Error = err |
||||||
|
bcode := ecode.Cause(err) |
||||||
|
// TODO app allow 5xx?
|
||||||
|
/* |
||||||
|
if bcode.Code() == -500 { |
||||||
|
code = http.StatusServiceUnavailable |
||||||
|
} |
||||||
|
*/ |
||||||
|
writeStatusCode(c.Writer, bcode.Code()) |
||||||
|
c.Render(code, render.XML{ |
||||||
|
Code: bcode.Code(), |
||||||
|
Message: bcode.Message(), |
||||||
|
Data: data, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// Protobuf serializes the given struct as PB into the response body.
|
||||||
|
// It also sets the ContentType as "application/x-protobuf".
|
||||||
|
func (c *Context) Protobuf(data proto.Message, err error) { |
||||||
|
var ( |
||||||
|
bytes []byte |
||||||
|
) |
||||||
|
|
||||||
|
code := http.StatusOK |
||||||
|
c.Error = err |
||||||
|
bcode := ecode.Cause(err) |
||||||
|
|
||||||
|
any := new(types.Any) |
||||||
|
if data != nil { |
||||||
|
if bytes, err = proto.Marshal(data); err != nil { |
||||||
|
c.Error = errors.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
any.TypeUrl = "type.googleapis.com/" + proto.MessageName(data) |
||||||
|
any.Value = bytes |
||||||
|
} |
||||||
|
writeStatusCode(c.Writer, bcode.Code()) |
||||||
|
c.Render(code, render.PB{ |
||||||
|
Code: int64(bcode.Code()), |
||||||
|
Message: bcode.Message(), |
||||||
|
Data: any, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// Bytes writes some data into the body stream and updates the HTTP code.
|
||||||
|
func (c *Context) Bytes(code int, contentType string, data ...[]byte) { |
||||||
|
c.Render(code, render.Data{ |
||||||
|
ContentType: contentType, |
||||||
|
Data: data, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// String writes the given string into the response body.
|
||||||
|
func (c *Context) String(code int, format string, values ...interface{}) { |
||||||
|
c.Render(code, render.String{Format: format, Data: values}) |
||||||
|
} |
||||||
|
|
||||||
|
// Redirect returns a HTTP redirect to the specific location.
|
||||||
|
func (c *Context) Redirect(code int, location string) { |
||||||
|
c.Render(-1, render.Redirect{ |
||||||
|
Code: code, |
||||||
|
Location: location, |
||||||
|
Request: c.Request, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// BindWith bind req arg with parser.
|
||||||
|
func (c *Context) BindWith(obj interface{}, b binding.Binding) error { |
||||||
|
return c.mustBindWith(obj, b) |
||||||
|
} |
||||||
|
|
||||||
|
// Bind bind req arg with defult form binding.
|
||||||
|
func (c *Context) Bind(obj interface{}) error { |
||||||
|
return c.mustBindWith(obj, binding.Form) |
||||||
|
} |
||||||
|
|
||||||
|
// mustBindWith binds the passed struct pointer using the specified binding engine.
|
||||||
|
// It will abort the request with HTTP 400 if any error ocurrs.
|
||||||
|
// See the binding package.
|
||||||
|
func (c *Context) mustBindWith(obj interface{}, b binding.Binding) (err error) { |
||||||
|
if err = b.Bind(c.Request, obj); err != nil { |
||||||
|
c.Error = ecode.RequestErr |
||||||
|
c.Render(http.StatusOK, render.JSON{ |
||||||
|
Code: ecode.RequestErr.Code(), |
||||||
|
Message: err.Error(), |
||||||
|
Data: nil, |
||||||
|
}) |
||||||
|
c.Abort() |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func writeStatusCode(w http.ResponseWriter, ecode int) { |
||||||
|
header := w.Header() |
||||||
|
header.Set("kratos-status-code", strconv.FormatInt(int64(ecode), 10)) |
||||||
|
} |
@ -0,0 +1,249 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// CORSConfig represents all available options for the middleware.
|
||||||
|
type CORSConfig struct { |
||||||
|
AllowAllOrigins bool |
||||||
|
|
||||||
|
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
|
||||||
|
// If the special "*" value is present in the list, all origins will be allowed.
|
||||||
|
// Default value is []
|
||||||
|
AllowOrigins []string |
||||||
|
|
||||||
|
// AllowOriginFunc is a custom function to validate the origin. It take the origin
|
||||||
|
// as argument and returns true if allowed or false otherwise. If this option is
|
||||||
|
// set, the content of AllowedOrigins is ignored.
|
||||||
|
AllowOriginFunc func(origin string) bool |
||||||
|
|
||||||
|
// AllowedMethods is a list of methods the client is allowed to use with
|
||||||
|
// cross-domain requests. Default value is simple methods (GET and POST)
|
||||||
|
AllowMethods []string |
||||||
|
|
||||||
|
// AllowedHeaders is list of non simple headers the client is allowed to use with
|
||||||
|
// cross-domain requests.
|
||||||
|
AllowHeaders []string |
||||||
|
|
||||||
|
// AllowCredentials indicates whether the request can include user credentials like
|
||||||
|
// cookies, HTTP authentication or client side SSL certificates.
|
||||||
|
AllowCredentials bool |
||||||
|
|
||||||
|
// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
|
||||||
|
// API specification
|
||||||
|
ExposeHeaders []string |
||||||
|
|
||||||
|
// MaxAge indicates how long (in seconds) the results of a preflight request
|
||||||
|
// can be cached
|
||||||
|
MaxAge time.Duration |
||||||
|
} |
||||||
|
|
||||||
|
type cors struct { |
||||||
|
allowAllOrigins bool |
||||||
|
allowCredentials bool |
||||||
|
allowOriginFunc func(string) bool |
||||||
|
allowOrigins []string |
||||||
|
normalHeaders http.Header |
||||||
|
preflightHeaders http.Header |
||||||
|
} |
||||||
|
|
||||||
|
type converter func(string) string |
||||||
|
|
||||||
|
// Validate is check configuration of user defined.
|
||||||
|
func (c *CORSConfig) Validate() error { |
||||||
|
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { |
||||||
|
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed") |
||||||
|
} |
||||||
|
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 { |
||||||
|
return errors.New("conflict settings: all origins disabled") |
||||||
|
} |
||||||
|
for _, origin := range c.AllowOrigins { |
||||||
|
if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") { |
||||||
|
return errors.New("bad origin: origins must either be '*' or include http:// or https://") |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// CORS returns the location middleware with default configuration.
|
||||||
|
func CORS(allowOriginHosts []string) HandlerFunc { |
||||||
|
config := &CORSConfig{ |
||||||
|
AllowMethods: []string{"GET", "POST"}, |
||||||
|
AllowHeaders: []string{"Origin", "Content-Length", "Content-Type"}, |
||||||
|
AllowCredentials: true, |
||||||
|
MaxAge: time.Duration(0), |
||||||
|
AllowOriginFunc: func(origin string) bool { |
||||||
|
for _, host := range allowOriginHosts { |
||||||
|
if strings.HasSuffix(strings.ToLower(origin), host) { |
||||||
|
return true |
||||||
|
} |
||||||
|
} |
||||||
|
return false |
||||||
|
}, |
||||||
|
} |
||||||
|
return newCORS(config) |
||||||
|
} |
||||||
|
|
||||||
|
// newCORS returns the location middleware with user-defined custom configuration.
|
||||||
|
func newCORS(config *CORSConfig) HandlerFunc { |
||||||
|
if err := config.Validate(); err != nil { |
||||||
|
panic(err.Error()) |
||||||
|
} |
||||||
|
cors := &cors{ |
||||||
|
allowOriginFunc: config.AllowOriginFunc, |
||||||
|
allowAllOrigins: config.AllowAllOrigins, |
||||||
|
allowCredentials: config.AllowCredentials, |
||||||
|
allowOrigins: normalize(config.AllowOrigins), |
||||||
|
normalHeaders: generateNormalHeaders(config), |
||||||
|
preflightHeaders: generatePreflightHeaders(config), |
||||||
|
} |
||||||
|
|
||||||
|
return func(c *Context) { |
||||||
|
cors.applyCORS(c) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (cors *cors) applyCORS(c *Context) { |
||||||
|
origin := c.Request.Header.Get("Origin") |
||||||
|
if len(origin) == 0 { |
||||||
|
// request is not a CORS request
|
||||||
|
return |
||||||
|
} |
||||||
|
if !cors.validateOrigin(origin) { |
||||||
|
log.V(5).Info("The request's Origin header `%s` does not match any of allowed origins.", origin) |
||||||
|
c.AbortWithStatus(http.StatusForbidden) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
if c.Request.Method == "OPTIONS" { |
||||||
|
cors.handlePreflight(c) |
||||||
|
defer c.AbortWithStatus(200) |
||||||
|
} else { |
||||||
|
cors.handleNormal(c) |
||||||
|
} |
||||||
|
|
||||||
|
if !cors.allowAllOrigins { |
||||||
|
header := c.Writer.Header() |
||||||
|
header.Set("Access-Control-Allow-Origin", origin) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (cors *cors) validateOrigin(origin string) bool { |
||||||
|
if cors.allowAllOrigins { |
||||||
|
return true |
||||||
|
} |
||||||
|
for _, value := range cors.allowOrigins { |
||||||
|
if value == origin { |
||||||
|
return true |
||||||
|
} |
||||||
|
} |
||||||
|
if cors.allowOriginFunc != nil { |
||||||
|
return cors.allowOriginFunc(origin) |
||||||
|
} |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
func (cors *cors) handlePreflight(c *Context) { |
||||||
|
header := c.Writer.Header() |
||||||
|
for key, value := range cors.preflightHeaders { |
||||||
|
header[key] = value |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (cors *cors) handleNormal(c *Context) { |
||||||
|
header := c.Writer.Header() |
||||||
|
for key, value := range cors.normalHeaders { |
||||||
|
header[key] = value |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func generateNormalHeaders(c *CORSConfig) http.Header { |
||||||
|
headers := make(http.Header) |
||||||
|
if c.AllowCredentials { |
||||||
|
headers.Set("Access-Control-Allow-Credentials", "true") |
||||||
|
} |
||||||
|
|
||||||
|
// backport support for early browsers
|
||||||
|
if len(c.AllowMethods) > 0 { |
||||||
|
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper) |
||||||
|
value := strings.Join(allowMethods, ",") |
||||||
|
headers.Set("Access-Control-Allow-Methods", value) |
||||||
|
} |
||||||
|
|
||||||
|
if len(c.ExposeHeaders) > 0 { |
||||||
|
exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey) |
||||||
|
headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ",")) |
||||||
|
} |
||||||
|
if c.AllowAllOrigins { |
||||||
|
headers.Set("Access-Control-Allow-Origin", "*") |
||||||
|
} else { |
||||||
|
headers.Set("Vary", "Origin") |
||||||
|
} |
||||||
|
return headers |
||||||
|
} |
||||||
|
|
||||||
|
func generatePreflightHeaders(c *CORSConfig) http.Header { |
||||||
|
headers := make(http.Header) |
||||||
|
if c.AllowCredentials { |
||||||
|
headers.Set("Access-Control-Allow-Credentials", "true") |
||||||
|
} |
||||||
|
if len(c.AllowMethods) > 0 { |
||||||
|
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper) |
||||||
|
value := strings.Join(allowMethods, ",") |
||||||
|
headers.Set("Access-Control-Allow-Methods", value) |
||||||
|
} |
||||||
|
if len(c.AllowHeaders) > 0 { |
||||||
|
allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey) |
||||||
|
value := strings.Join(allowHeaders, ",") |
||||||
|
headers.Set("Access-Control-Allow-Headers", value) |
||||||
|
} |
||||||
|
if c.MaxAge > time.Duration(0) { |
||||||
|
value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10) |
||||||
|
headers.Set("Access-Control-Max-Age", value) |
||||||
|
} |
||||||
|
if c.AllowAllOrigins { |
||||||
|
headers.Set("Access-Control-Allow-Origin", "*") |
||||||
|
} else { |
||||||
|
// Always set Vary headers
|
||||||
|
// see https://github.com/rs/cors/issues/10,
|
||||||
|
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
|
||||||
|
|
||||||
|
headers.Add("Vary", "Origin") |
||||||
|
headers.Add("Vary", "Access-Control-Request-Method") |
||||||
|
headers.Add("Vary", "Access-Control-Request-Headers") |
||||||
|
} |
||||||
|
return headers |
||||||
|
} |
||||||
|
|
||||||
|
func normalize(values []string) []string { |
||||||
|
if values == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
distinctMap := make(map[string]bool, len(values)) |
||||||
|
normalized := make([]string, 0, len(values)) |
||||||
|
for _, value := range values { |
||||||
|
value = strings.TrimSpace(value) |
||||||
|
value = strings.ToLower(value) |
||||||
|
if _, seen := distinctMap[value]; !seen { |
||||||
|
normalized = append(normalized, value) |
||||||
|
distinctMap[value] = true |
||||||
|
} |
||||||
|
} |
||||||
|
return normalized |
||||||
|
} |
||||||
|
|
||||||
|
func convert(s []string, c converter) []string { |
||||||
|
var out []string |
||||||
|
for _, i := range s { |
||||||
|
out = append(out, c(i)) |
||||||
|
} |
||||||
|
return out |
||||||
|
} |
@ -0,0 +1,64 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/url" |
||||||
|
"regexp" |
||||||
|
"strings" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
) |
||||||
|
|
||||||
|
func matchHostSuffix(suffix string) func(*url.URL) bool { |
||||||
|
return func(uri *url.URL) bool { |
||||||
|
return strings.HasSuffix(strings.ToLower(uri.Host), suffix) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func matchPattern(pattern *regexp.Regexp) func(*url.URL) bool { |
||||||
|
return func(uri *url.URL) bool { |
||||||
|
return pattern.MatchString(strings.ToLower(uri.String())) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// CSRF returns the csrf middleware to prevent invalid cross site request.
|
||||||
|
// Only referer is checked currently.
|
||||||
|
func CSRF(allowHosts []string, allowPattern []string) HandlerFunc { |
||||||
|
validations := []func(*url.URL) bool{} |
||||||
|
|
||||||
|
addHostSuffix := func(suffix string) { |
||||||
|
validations = append(validations, matchHostSuffix(suffix)) |
||||||
|
} |
||||||
|
addPattern := func(pattern string) { |
||||||
|
validations = append(validations, matchPattern(regexp.MustCompile(pattern))) |
||||||
|
} |
||||||
|
|
||||||
|
for _, r := range allowHosts { |
||||||
|
addHostSuffix(r) |
||||||
|
} |
||||||
|
for _, p := range allowPattern { |
||||||
|
addPattern(p) |
||||||
|
} |
||||||
|
|
||||||
|
return func(c *Context) { |
||||||
|
referer := c.Request.Header.Get("Referer") |
||||||
|
if referer == "" { |
||||||
|
log.V(5).Info("The request's Referer or Origin header is empty.") |
||||||
|
c.AbortWithStatus(403) |
||||||
|
return |
||||||
|
} |
||||||
|
illegal := true |
||||||
|
if uri, err := url.Parse(referer); err == nil && uri.Host != "" { |
||||||
|
for _, validate := range validations { |
||||||
|
if validate(uri) { |
||||||
|
illegal = false |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
if illegal { |
||||||
|
log.V(5).Info("The request's Referer header `%s` does not match any of allowed referers.", referer) |
||||||
|
c.AbortWithStatus(403) |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,69 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"strconv" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/ecode" |
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
"github.com/bilibili/kratos/pkg/net/metadata" |
||||||
|
) |
||||||
|
|
||||||
|
// Logger is logger middleware
|
||||||
|
func Logger() HandlerFunc { |
||||||
|
const noUser = "no_user" |
||||||
|
return func(c *Context) { |
||||||
|
now := time.Now() |
||||||
|
ip := metadata.String(c, metadata.RemoteIP) |
||||||
|
req := c.Request |
||||||
|
path := req.URL.Path |
||||||
|
params := req.Form |
||||||
|
var quota float64 |
||||||
|
if deadline, ok := c.Context.Deadline(); ok { |
||||||
|
quota = time.Until(deadline).Seconds() |
||||||
|
} |
||||||
|
|
||||||
|
c.Next() |
||||||
|
|
||||||
|
err := c.Error |
||||||
|
cerr := ecode.Cause(err) |
||||||
|
dt := time.Since(now) |
||||||
|
caller := metadata.String(c, metadata.Caller) |
||||||
|
if caller == "" { |
||||||
|
caller = noUser |
||||||
|
} |
||||||
|
|
||||||
|
stats.Incr(caller, path[1:], strconv.FormatInt(int64(cerr.Code()), 10)) |
||||||
|
stats.Timing(caller, int64(dt/time.Millisecond), path[1:]) |
||||||
|
|
||||||
|
lf := log.Infov |
||||||
|
errmsg := "" |
||||||
|
isSlow := dt >= (time.Millisecond * 500) |
||||||
|
if err != nil { |
||||||
|
errmsg = err.Error() |
||||||
|
lf = log.Errorv |
||||||
|
if cerr.Code() > 0 { |
||||||
|
lf = log.Warnv |
||||||
|
} |
||||||
|
} else { |
||||||
|
if isSlow { |
||||||
|
lf = log.Warnv |
||||||
|
} |
||||||
|
} |
||||||
|
lf(c, |
||||||
|
log.KVString("method", req.Method), |
||||||
|
log.KVString("ip", ip), |
||||||
|
log.KVString("user", caller), |
||||||
|
log.KVString("path", path), |
||||||
|
log.KVString("params", params.Encode()), |
||||||
|
log.KVInt("ret", cerr.Code()), |
||||||
|
log.KVString("msg", cerr.Message()), |
||||||
|
log.KVString("stack", fmt.Sprintf("%+v", err)), |
||||||
|
log.KVString("err", errmsg), |
||||||
|
log.KVFloat64("timeout_quota", quota), |
||||||
|
log.KVFloat64("ts", dt.Seconds()), |
||||||
|
log.KVString("source", "http-access-log"), |
||||||
|
) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,46 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
"net/http" |
||||||
|
"net/http/pprof" |
||||||
|
"os" |
||||||
|
"sync" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/conf/dsn" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
_perfOnce sync.Once |
||||||
|
_perfDSN string |
||||||
|
) |
||||||
|
|
||||||
|
func init() { |
||||||
|
v := os.Getenv("HTTP_PERF") |
||||||
|
if v == "" { |
||||||
|
v = "tcp://0.0.0.0:2333" |
||||||
|
} |
||||||
|
flag.StringVar(&_perfDSN, "http.perf", v, "listen http perf dsn, or use HTTP_PERF env variable.") |
||||||
|
} |
||||||
|
|
||||||
|
func startPerf() { |
||||||
|
_perfOnce.Do(func() { |
||||||
|
mux := http.NewServeMux() |
||||||
|
mux.HandleFunc("/debug/pprof/", pprof.Index) |
||||||
|
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) |
||||||
|
mux.HandleFunc("/debug/pprof/profile", pprof.Profile) |
||||||
|
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) |
||||||
|
|
||||||
|
go func() { |
||||||
|
d, err := dsn.Parse(_perfDSN) |
||||||
|
if err != nil { |
||||||
|
panic(errors.Errorf("blademaster: http perf dsn must be tcp://$host:port, %s:error(%v)", _perfDSN, err)) |
||||||
|
} |
||||||
|
if err := http.ListenAndServe(d.Host, mux); err != nil { |
||||||
|
panic(errors.Errorf("blademaster: listen %s: error(%v)", d.Host, err)) |
||||||
|
} |
||||||
|
}() |
||||||
|
}) |
||||||
|
} |
@ -0,0 +1,12 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp" |
||||||
|
) |
||||||
|
|
||||||
|
func monitor() HandlerFunc { |
||||||
|
return func(c *Context) { |
||||||
|
h := promhttp.Handler() |
||||||
|
h.ServeHTTP(c.Writer, c.Request) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,32 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"net/http/httputil" |
||||||
|
"os" |
||||||
|
"runtime" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
) |
||||||
|
|
||||||
|
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
|
||||||
|
func Recovery() HandlerFunc { |
||||||
|
return func(c *Context) { |
||||||
|
defer func() { |
||||||
|
var rawReq []byte |
||||||
|
if err := recover(); err != nil { |
||||||
|
const size = 64 << 10 |
||||||
|
buf := make([]byte, size) |
||||||
|
buf = buf[:runtime.Stack(buf, false)] |
||||||
|
if c.Request != nil { |
||||||
|
rawReq, _ = httputil.DumpRequest(c.Request, false) |
||||||
|
} |
||||||
|
pl := fmt.Sprintf("http call panic: %s\n%v\n%s\n", string(rawReq), err, buf) |
||||||
|
fmt.Fprintf(os.Stderr, pl) |
||||||
|
log.Error(pl) |
||||||
|
c.AbortWithStatus(500) |
||||||
|
} |
||||||
|
}() |
||||||
|
c.Next() |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,30 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// Data common bytes struct.
|
||||||
|
type Data struct { |
||||||
|
ContentType string |
||||||
|
Data [][]byte |
||||||
|
} |
||||||
|
|
||||||
|
// Render (Data) writes data with custom ContentType.
|
||||||
|
func (r Data) Render(w http.ResponseWriter) (err error) { |
||||||
|
r.WriteContentType(w) |
||||||
|
for _, d := range r.Data { |
||||||
|
if _, err = w.Write(d); err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType writes data with custom ContentType.
|
||||||
|
func (r Data) WriteContentType(w http.ResponseWriter) { |
||||||
|
writeContentType(w, []string{r.ContentType}) |
||||||
|
} |
@ -0,0 +1,58 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var jsonContentType = []string{"application/json; charset=utf-8"} |
||||||
|
|
||||||
|
// JSON common json struct.
|
||||||
|
type JSON struct { |
||||||
|
Code int `json:"code"` |
||||||
|
Message string `json:"message"` |
||||||
|
TTL int `json:"ttl"` |
||||||
|
Data interface{} `json:"data,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
func writeJSON(w http.ResponseWriter, obj interface{}) (err error) { |
||||||
|
var jsonBytes []byte |
||||||
|
writeContentType(w, jsonContentType) |
||||||
|
if jsonBytes, err = json.Marshal(obj); err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
if _, err = w.Write(jsonBytes); err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Render (JSON) writes data with json ContentType.
|
||||||
|
func (r JSON) Render(w http.ResponseWriter) error { |
||||||
|
// FIXME(zhoujiahui): the TTL field will be configurable in the future
|
||||||
|
if r.TTL <= 0 { |
||||||
|
r.TTL = 1 |
||||||
|
} |
||||||
|
return writeJSON(w, r) |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType write json ContentType.
|
||||||
|
func (r JSON) WriteContentType(w http.ResponseWriter) { |
||||||
|
writeContentType(w, jsonContentType) |
||||||
|
} |
||||||
|
|
||||||
|
// MapJSON common map json struct.
|
||||||
|
type MapJSON map[string]interface{} |
||||||
|
|
||||||
|
// Render (MapJSON) writes data with json ContentType.
|
||||||
|
func (m MapJSON) Render(w http.ResponseWriter) error { |
||||||
|
return writeJSON(w, m) |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType write json ContentType.
|
||||||
|
func (m MapJSON) WriteContentType(w http.ResponseWriter) { |
||||||
|
writeContentType(w, jsonContentType) |
||||||
|
} |
@ -0,0 +1,38 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/gogo/protobuf/proto" |
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var pbContentType = []string{"application/x-protobuf"} |
||||||
|
|
||||||
|
// Render (PB) writes data with protobuf ContentType.
|
||||||
|
func (r PB) Render(w http.ResponseWriter) error { |
||||||
|
if r.TTL <= 0 { |
||||||
|
r.TTL = 1 |
||||||
|
} |
||||||
|
return writePB(w, r) |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType write protobuf ContentType.
|
||||||
|
func (r PB) WriteContentType(w http.ResponseWriter) { |
||||||
|
writeContentType(w, pbContentType) |
||||||
|
} |
||||||
|
|
||||||
|
func writePB(w http.ResponseWriter, obj PB) (err error) { |
||||||
|
var pbBytes []byte |
||||||
|
writeContentType(w, pbContentType) |
||||||
|
|
||||||
|
if pbBytes, err = proto.Marshal(&obj); err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
if _, err = w.Write(pbBytes); err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,26 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// Redirect render for redirect to specified location.
|
||||||
|
type Redirect struct { |
||||||
|
Code int |
||||||
|
Request *http.Request |
||||||
|
Location string |
||||||
|
} |
||||||
|
|
||||||
|
// Render (Redirect) redirect to specified location.
|
||||||
|
func (r Redirect) Render(w http.ResponseWriter) error { |
||||||
|
if (r.Code < 300 || r.Code > 308) && r.Code != 201 { |
||||||
|
return errors.Errorf("Cannot redirect with status code %d", r.Code) |
||||||
|
} |
||||||
|
http.Redirect(w, r.Request, r.Location, r.Code) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType noneContentType.
|
||||||
|
func (r Redirect) WriteContentType(http.ResponseWriter) {} |
@ -0,0 +1,30 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
) |
||||||
|
|
||||||
|
// Render http reponse render.
|
||||||
|
type Render interface { |
||||||
|
// Render render it to http response writer.
|
||||||
|
Render(http.ResponseWriter) error |
||||||
|
// WriteContentType write content-type to http response writer.
|
||||||
|
WriteContentType(w http.ResponseWriter) |
||||||
|
} |
||||||
|
|
||||||
|
var ( |
||||||
|
_ Render = JSON{} |
||||||
|
_ Render = MapJSON{} |
||||||
|
_ Render = XML{} |
||||||
|
_ Render = String{} |
||||||
|
_ Render = Redirect{} |
||||||
|
_ Render = Data{} |
||||||
|
_ Render = PB{} |
||||||
|
) |
||||||
|
|
||||||
|
func writeContentType(w http.ResponseWriter, value []string) { |
||||||
|
header := w.Header() |
||||||
|
if val := header["Content-Type"]; len(val) == 0 { |
||||||
|
header["Content-Type"] = value |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,89 @@ |
|||||||
|
// Code generated by protoc-gen-gogo. DO NOT EDIT.
|
||||||
|
// source: pb.proto
|
||||||
|
|
||||||
|
/* |
||||||
|
Package render is a generated protocol buffer package. |
||||||
|
|
||||||
|
It is generated from these files: |
||||||
|
pb.proto |
||||||
|
|
||||||
|
It has these top-level messages: |
||||||
|
PB |
||||||
|
*/ |
||||||
|
package render |
||||||
|
|
||||||
|
import proto "github.com/gogo/protobuf/proto" |
||||||
|
import fmt "fmt" |
||||||
|
import math "math" |
||||||
|
import google_protobuf "github.com/gogo/protobuf/types" |
||||||
|
|
||||||
|
// Reference imports to suppress errors if they are not otherwise used.
|
||||||
|
var _ = proto.Marshal |
||||||
|
var _ = fmt.Errorf |
||||||
|
var _ = math.Inf |
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
|
// is compatible with the proto package it is being compiled against.
|
||||||
|
// A compilation error at this line likely means your copy of the
|
||||||
|
// proto package needs to be updated.
|
||||||
|
const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package
|
||||||
|
|
||||||
|
type PB struct { |
||||||
|
Code int64 `protobuf:"varint,1,opt,name=Code,proto3" json:"Code,omitempty"` |
||||||
|
Message string `protobuf:"bytes,2,opt,name=Message,proto3" json:"Message,omitempty"` |
||||||
|
TTL uint64 `protobuf:"varint,3,opt,name=TTL,proto3" json:"TTL,omitempty"` |
||||||
|
Data *google_protobuf.Any `protobuf:"bytes,4,opt,name=Data" json:"Data,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
func (m *PB) Reset() { *m = PB{} } |
||||||
|
func (m *PB) String() string { return proto.CompactTextString(m) } |
||||||
|
func (*PB) ProtoMessage() {} |
||||||
|
func (*PB) Descriptor() ([]byte, []int) { return fileDescriptorPb, []int{0} } |
||||||
|
|
||||||
|
func (m *PB) GetCode() int64 { |
||||||
|
if m != nil { |
||||||
|
return m.Code |
||||||
|
} |
||||||
|
return 0 |
||||||
|
} |
||||||
|
|
||||||
|
func (m *PB) GetMessage() string { |
||||||
|
if m != nil { |
||||||
|
return m.Message |
||||||
|
} |
||||||
|
return "" |
||||||
|
} |
||||||
|
|
||||||
|
func (m *PB) GetTTL() uint64 { |
||||||
|
if m != nil { |
||||||
|
return m.TTL |
||||||
|
} |
||||||
|
return 0 |
||||||
|
} |
||||||
|
|
||||||
|
func (m *PB) GetData() *google_protobuf.Any { |
||||||
|
if m != nil { |
||||||
|
return m.Data |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func init() { |
||||||
|
proto.RegisterType((*PB)(nil), "render.PB") |
||||||
|
} |
||||||
|
|
||||||
|
func init() { proto.RegisterFile("pb.proto", fileDescriptorPb) } |
||||||
|
|
||||||
|
var fileDescriptorPb = []byte{ |
||||||
|
// 154 bytes of a gzipped FileDescriptorProto
|
||||||
|
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x28, 0x48, 0xd2, 0x2b, |
||||||
|
0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b, 0x4a, 0xcd, 0x4b, 0x49, 0x2d, 0x92, 0x92, 0x4c, 0xcf, |
||||||
|
0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07, 0x8b, 0x26, 0x95, 0xa6, 0xe9, 0x27, 0xe6, 0x55, 0x42, 0x94, |
||||||
|
0x28, 0xe5, 0x71, 0x31, 0x05, 0x38, 0x09, 0x09, 0x71, 0xb1, 0x38, 0xe7, 0xa7, 0xa4, 0x4a, 0x30, |
||||||
|
0x2a, 0x30, 0x6a, 0x30, 0x07, 0x81, 0xd9, 0x42, 0x12, 0x5c, 0xec, 0xbe, 0xa9, 0xc5, 0xc5, 0x89, |
||||||
|
0xe9, 0xa9, 0x12, 0x4c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 0xae, 0x90, 0x00, 0x17, 0x73, 0x48, |
||||||
|
0x88, 0x8f, 0x04, 0xb3, 0x02, 0xa3, 0x06, 0x4b, 0x10, 0x88, 0x29, 0xa4, 0xc1, 0xc5, 0xe2, 0x92, |
||||||
|
0x58, 0x92, 0x28, 0xc1, 0xa2, 0xc0, 0xa8, 0xc1, 0x6d, 0x24, 0xa2, 0x07, 0xb1, 0x4f, 0x0f, 0x66, |
||||||
|
0x9f, 0x9e, 0x63, 0x5e, 0x65, 0x10, 0x58, 0x45, 0x12, 0x1b, 0x58, 0xcc, 0x18, 0x10, 0x00, 0x00, |
||||||
|
0xff, 0xff, 0x7a, 0x92, 0x16, 0x71, 0xa5, 0x00, 0x00, 0x00, |
||||||
|
} |
@ -0,0 +1,14 @@ |
|||||||
|
// use under command to generate pb.pb.go |
||||||
|
// protoc --proto_path=.:$GOPATH/src/github.com/gogo/protobuf --gogo_out=Mgoogle/protobuf/any.proto=github.com/gogo/protobuf/types:. *.proto |
||||||
|
syntax = "proto3"; |
||||||
|
package render; |
||||||
|
|
||||||
|
import "google/protobuf/any.proto"; |
||||||
|
import "github.com/gogo/protobuf/gogoproto/gogo.proto"; |
||||||
|
|
||||||
|
message PB { |
||||||
|
int64 Code = 1; |
||||||
|
string Message = 2; |
||||||
|
uint64 TTL = 3; |
||||||
|
google.protobuf.Any Data = 4; |
||||||
|
} |
@ -0,0 +1,40 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var plainContentType = []string{"text/plain; charset=utf-8"} |
||||||
|
|
||||||
|
// String common string struct.
|
||||||
|
type String struct { |
||||||
|
Format string |
||||||
|
Data []interface{} |
||||||
|
} |
||||||
|
|
||||||
|
// Render (String) writes data with custom ContentType.
|
||||||
|
func (r String) Render(w http.ResponseWriter) error { |
||||||
|
return writeString(w, r.Format, r.Data) |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType writes string with text/plain ContentType.
|
||||||
|
func (r String) WriteContentType(w http.ResponseWriter) { |
||||||
|
writeContentType(w, plainContentType) |
||||||
|
} |
||||||
|
|
||||||
|
func writeString(w http.ResponseWriter, format string, data []interface{}) (err error) { |
||||||
|
writeContentType(w, plainContentType) |
||||||
|
if len(data) > 0 { |
||||||
|
_, err = fmt.Fprintf(w, format, data...) |
||||||
|
} else { |
||||||
|
_, err = io.WriteString(w, format) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,31 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/xml" |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// XML common xml struct.
|
||||||
|
type XML struct { |
||||||
|
Code int |
||||||
|
Message string |
||||||
|
Data interface{} |
||||||
|
} |
||||||
|
|
||||||
|
var xmlContentType = []string{"application/xml; charset=utf-8"} |
||||||
|
|
||||||
|
// Render (XML) writes data with xml ContentType.
|
||||||
|
func (r XML) Render(w http.ResponseWriter) (err error) { |
||||||
|
r.WriteContentType(w) |
||||||
|
if err = xml.NewEncoder(w).Encode(r.Data); err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType write xml ContentType.
|
||||||
|
func (r XML) WriteContentType(w http.ResponseWriter) { |
||||||
|
writeContentType(w, xmlContentType) |
||||||
|
} |
@ -0,0 +1,166 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"regexp" |
||||||
|
) |
||||||
|
|
||||||
|
// IRouter http router framework interface.
|
||||||
|
type IRouter interface { |
||||||
|
IRoutes |
||||||
|
Group(string, ...HandlerFunc) *RouterGroup |
||||||
|
} |
||||||
|
|
||||||
|
// IRoutes http router interface.
|
||||||
|
type IRoutes interface { |
||||||
|
UseFunc(...HandlerFunc) IRoutes |
||||||
|
Use(...Handler) IRoutes |
||||||
|
|
||||||
|
Handle(string, string, ...HandlerFunc) IRoutes |
||||||
|
HEAD(string, ...HandlerFunc) IRoutes |
||||||
|
GET(string, ...HandlerFunc) IRoutes |
||||||
|
POST(string, ...HandlerFunc) IRoutes |
||||||
|
PUT(string, ...HandlerFunc) IRoutes |
||||||
|
DELETE(string, ...HandlerFunc) IRoutes |
||||||
|
} |
||||||
|
|
||||||
|
// RouterGroup is used internally to configure router, a RouterGroup is associated with a prefix
|
||||||
|
// and an array of handlers (middleware).
|
||||||
|
type RouterGroup struct { |
||||||
|
Handlers []HandlerFunc |
||||||
|
basePath string |
||||||
|
engine *Engine |
||||||
|
root bool |
||||||
|
baseConfig *MethodConfig |
||||||
|
} |
||||||
|
|
||||||
|
var _ IRouter = &RouterGroup{} |
||||||
|
|
||||||
|
// Use adds middleware to the group, see example code in doc.
|
||||||
|
func (group *RouterGroup) Use(middleware ...Handler) IRoutes { |
||||||
|
for _, m := range middleware { |
||||||
|
group.Handlers = append(group.Handlers, m.ServeHTTP) |
||||||
|
} |
||||||
|
return group.returnObj() |
||||||
|
} |
||||||
|
|
||||||
|
// UseFunc adds middleware to the group, see example code in doc.
|
||||||
|
func (group *RouterGroup) UseFunc(middleware ...HandlerFunc) IRoutes { |
||||||
|
group.Handlers = append(group.Handlers, middleware...) |
||||||
|
return group.returnObj() |
||||||
|
} |
||||||
|
|
||||||
|
// Group creates a new router group. You should add all the routes that have common middlwares or the same path prefix.
|
||||||
|
// For example, all the routes that use a common middlware for authorization could be grouped.
|
||||||
|
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup { |
||||||
|
return &RouterGroup{ |
||||||
|
Handlers: group.combineHandlers(handlers), |
||||||
|
basePath: group.calculateAbsolutePath(relativePath), |
||||||
|
engine: group.engine, |
||||||
|
root: false, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// SetMethodConfig is used to set config on specified method
|
||||||
|
func (group *RouterGroup) SetMethodConfig(config *MethodConfig) *RouterGroup { |
||||||
|
group.baseConfig = config |
||||||
|
return group |
||||||
|
} |
||||||
|
|
||||||
|
// BasePath router group base path.
|
||||||
|
func (group *RouterGroup) BasePath() string { |
||||||
|
return group.basePath |
||||||
|
} |
||||||
|
|
||||||
|
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
absolutePath := group.calculateAbsolutePath(relativePath) |
||||||
|
injections := group.injections(relativePath) |
||||||
|
handlers = group.combineHandlers(injections, handlers) |
||||||
|
group.engine.addRoute(httpMethod, absolutePath, handlers...) |
||||||
|
if group.baseConfig != nil { |
||||||
|
group.engine.SetMethodConfig(absolutePath, group.baseConfig) |
||||||
|
} |
||||||
|
return group.returnObj() |
||||||
|
} |
||||||
|
|
||||||
|
// Handle registers a new request handle and middleware with the given path and method.
|
||||||
|
// The last handler should be the real handler, the other ones should be middleware that can and should be shared among different routes.
|
||||||
|
// See the example code in doc.
|
||||||
|
//
|
||||||
|
// For HEAD, GET, POST, PUT, and DELETE requests the respective shortcut
|
||||||
|
// functions can be used.
|
||||||
|
//
|
||||||
|
// This function is intended for bulk loading and to allow the usage of less
|
||||||
|
// frequently used, non-standardized or custom methods (e.g. for internal
|
||||||
|
// communication with a proxy).
|
||||||
|
func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
if matches, err := regexp.MatchString("^[A-Z]+$", httpMethod); !matches || err != nil { |
||||||
|
panic("http method " + httpMethod + " is not valid") |
||||||
|
} |
||||||
|
return group.handle(httpMethod, relativePath, handlers...) |
||||||
|
} |
||||||
|
|
||||||
|
// HEAD is a shortcut for router.Handle("HEAD", path, handle).
|
||||||
|
func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
return group.handle("HEAD", relativePath, handlers...) |
||||||
|
} |
||||||
|
|
||||||
|
// GET is a shortcut for router.Handle("GET", path, handle).
|
||||||
|
func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
return group.handle("GET", relativePath, handlers...) |
||||||
|
} |
||||||
|
|
||||||
|
// POST is a shortcut for router.Handle("POST", path, handle).
|
||||||
|
func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
return group.handle("POST", relativePath, handlers...) |
||||||
|
} |
||||||
|
|
||||||
|
// PUT is a shortcut for router.Handle("PUT", path, handle).
|
||||||
|
func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
return group.handle("PUT", relativePath, handlers...) |
||||||
|
} |
||||||
|
|
||||||
|
// DELETE is a shortcut for router.Handle("DELETE", path, handle).
|
||||||
|
func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
return group.handle("DELETE", relativePath, handlers...) |
||||||
|
} |
||||||
|
|
||||||
|
func (group *RouterGroup) combineHandlers(handlerGroups ...[]HandlerFunc) []HandlerFunc { |
||||||
|
finalSize := len(group.Handlers) |
||||||
|
for _, handlers := range handlerGroups { |
||||||
|
finalSize += len(handlers) |
||||||
|
} |
||||||
|
if finalSize >= int(_abortIndex) { |
||||||
|
panic("too many handlers") |
||||||
|
} |
||||||
|
mergedHandlers := make([]HandlerFunc, finalSize) |
||||||
|
copy(mergedHandlers, group.Handlers) |
||||||
|
position := len(group.Handlers) |
||||||
|
for _, handlers := range handlerGroups { |
||||||
|
copy(mergedHandlers[position:], handlers) |
||||||
|
position += len(handlers) |
||||||
|
} |
||||||
|
return mergedHandlers |
||||||
|
} |
||||||
|
|
||||||
|
func (group *RouterGroup) calculateAbsolutePath(relativePath string) string { |
||||||
|
return joinPaths(group.basePath, relativePath) |
||||||
|
} |
||||||
|
|
||||||
|
func (group *RouterGroup) returnObj() IRoutes { |
||||||
|
if group.root { |
||||||
|
return group.engine |
||||||
|
} |
||||||
|
return group |
||||||
|
} |
||||||
|
|
||||||
|
// injections is
|
||||||
|
func (group *RouterGroup) injections(relativePath string) []HandlerFunc { |
||||||
|
absPath := group.calculateAbsolutePath(relativePath) |
||||||
|
for _, injection := range group.engine.injections { |
||||||
|
if !injection.pattern.MatchString(absPath) { |
||||||
|
continue |
||||||
|
} |
||||||
|
return injection.handlers |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,405 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"flag" |
||||||
|
"fmt" |
||||||
|
"net" |
||||||
|
"net/http" |
||||||
|
"os" |
||||||
|
"regexp" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
"sync/atomic" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/kratos/pkg/conf/dsn" |
||||||
|
"github.com/bilibili/kratos/pkg/log" |
||||||
|
"github.com/bilibili/kratos/pkg/net/ip" |
||||||
|
"github.com/bilibili/kratos/pkg/net/metadata" |
||||||
|
"github.com/bilibili/kratos/pkg/stat" |
||||||
|
xtime "github.com/bilibili/kratos/pkg/time" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
defaultMaxMemory = 32 << 20 // 32 MB
|
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
_ IRouter = &Engine{} |
||||||
|
stats = stat.HTTPServer |
||||||
|
|
||||||
|
_httpDSN string |
||||||
|
) |
||||||
|
|
||||||
|
func init() { |
||||||
|
addFlag(flag.CommandLine) |
||||||
|
} |
||||||
|
|
||||||
|
func addFlag(fs *flag.FlagSet) { |
||||||
|
v := os.Getenv("HTTP") |
||||||
|
if v == "" { |
||||||
|
v = "tcp://0.0.0.0:8000/?timeout=1s" |
||||||
|
} |
||||||
|
fs.StringVar(&_httpDSN, "http", v, "listen http dsn, or use HTTP env variable.") |
||||||
|
} |
||||||
|
|
||||||
|
func parseDSN(rawdsn string) *ServerConfig { |
||||||
|
conf := new(ServerConfig) |
||||||
|
d, err := dsn.Parse(rawdsn) |
||||||
|
if err != nil { |
||||||
|
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn)) |
||||||
|
} |
||||||
|
if _, err = d.Bind(conf); err != nil { |
||||||
|
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn)) |
||||||
|
} |
||||||
|
return conf |
||||||
|
} |
||||||
|
|
||||||
|
// Handler responds to an HTTP request.
|
||||||
|
type Handler interface { |
||||||
|
ServeHTTP(c *Context) |
||||||
|
} |
||||||
|
|
||||||
|
// HandlerFunc http request handler function.
|
||||||
|
type HandlerFunc func(*Context) |
||||||
|
|
||||||
|
// ServeHTTP calls f(ctx).
|
||||||
|
func (f HandlerFunc) ServeHTTP(c *Context) { |
||||||
|
f(c) |
||||||
|
} |
||||||
|
|
||||||
|
// ServerConfig is the bm server config model
|
||||||
|
type ServerConfig struct { |
||||||
|
Network string `dsn:"network"` |
||||||
|
Addr string `dsn:"address"` |
||||||
|
Timeout xtime.Duration `dsn:"query.timeout"` |
||||||
|
ReadTimeout xtime.Duration `dsn:"query.readTimeout"` |
||||||
|
WriteTimeout xtime.Duration `dsn:"query.writeTimeout"` |
||||||
|
} |
||||||
|
|
||||||
|
// MethodConfig is
|
||||||
|
type MethodConfig struct { |
||||||
|
Timeout xtime.Duration |
||||||
|
} |
||||||
|
|
||||||
|
// Start listen and serve bm engine by given DSN.
|
||||||
|
func (engine *Engine) Start() error { |
||||||
|
conf := engine.conf |
||||||
|
l, err := net.Listen(conf.Network, conf.Addr) |
||||||
|
if err != nil { |
||||||
|
errors.Wrapf(err, "blademaster: listen tcp: %s", conf.Addr) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
log.Info("blademaster: start http listen addr: %s", conf.Addr) |
||||||
|
server := &http.Server{ |
||||||
|
ReadTimeout: time.Duration(conf.ReadTimeout), |
||||||
|
WriteTimeout: time.Duration(conf.WriteTimeout), |
||||||
|
} |
||||||
|
go func() { |
||||||
|
if err := engine.RunServer(server, l); err != nil { |
||||||
|
if errors.Cause(err) == http.ErrServerClosed { |
||||||
|
log.Info("blademaster: server closed") |
||||||
|
return |
||||||
|
} |
||||||
|
panic(errors.Wrapf(err, "blademaster: engine.ListenServer(%+v, %+v)", server, l)) |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Engine is the framework's instance, it contains the muxer, middleware and configuration settings.
|
||||||
|
// Create an instance of Engine, by using New() or Default()
|
||||||
|
type Engine struct { |
||||||
|
RouterGroup |
||||||
|
|
||||||
|
lock sync.RWMutex |
||||||
|
conf *ServerConfig |
||||||
|
|
||||||
|
address string |
||||||
|
|
||||||
|
mux *http.ServeMux // http mux router
|
||||||
|
server atomic.Value // store *http.Server
|
||||||
|
metastore map[string]map[string]interface{} // metastore is the path as key and the metadata of this path as value, it export via /metadata
|
||||||
|
|
||||||
|
pcLock sync.RWMutex |
||||||
|
methodConfigs map[string]*MethodConfig |
||||||
|
|
||||||
|
injections []injection |
||||||
|
} |
||||||
|
|
||||||
|
type injection struct { |
||||||
|
pattern *regexp.Regexp |
||||||
|
handlers []HandlerFunc |
||||||
|
} |
||||||
|
|
||||||
|
// NewServer returns a new blank Engine instance without any middleware attached.
|
||||||
|
func NewServer(conf *ServerConfig) *Engine { |
||||||
|
if conf == nil { |
||||||
|
if !flag.Parsed() { |
||||||
|
fmt.Fprint(os.Stderr, "[blademaster] please call flag.Parse() before Init warden server, some configure may not effect.\n") |
||||||
|
} |
||||||
|
conf = parseDSN(_httpDSN) |
||||||
|
} |
||||||
|
engine := &Engine{ |
||||||
|
RouterGroup: RouterGroup{ |
||||||
|
Handlers: nil, |
||||||
|
basePath: "/", |
||||||
|
root: true, |
||||||
|
}, |
||||||
|
address: ip.InternalIP(), |
||||||
|
mux: http.NewServeMux(), |
||||||
|
metastore: make(map[string]map[string]interface{}), |
||||||
|
methodConfigs: make(map[string]*MethodConfig), |
||||||
|
} |
||||||
|
if err := engine.SetConfig(conf); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
engine.RouterGroup.engine = engine |
||||||
|
// NOTE add prometheus monitor location
|
||||||
|
engine.addRoute("GET", "/metrics", monitor()) |
||||||
|
engine.addRoute("GET", "/metadata", engine.metadata()) |
||||||
|
startPerf() |
||||||
|
return engine |
||||||
|
} |
||||||
|
|
||||||
|
// SetMethodConfig is used to set config on specified path
|
||||||
|
func (engine *Engine) SetMethodConfig(path string, mc *MethodConfig) { |
||||||
|
engine.pcLock.Lock() |
||||||
|
engine.methodConfigs[path] = mc |
||||||
|
engine.pcLock.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
// DefaultServer returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
|
||||||
|
func DefaultServer(conf *ServerConfig) *Engine { |
||||||
|
engine := NewServer(conf) |
||||||
|
engine.Use(Recovery(), Trace(), Logger()) |
||||||
|
return engine |
||||||
|
} |
||||||
|
|
||||||
|
func (engine *Engine) addRoute(method, path string, handlers ...HandlerFunc) { |
||||||
|
if path[0] != '/' { |
||||||
|
panic("blademaster: path must begin with '/'") |
||||||
|
} |
||||||
|
if method == "" { |
||||||
|
panic("blademaster: HTTP method can not be empty") |
||||||
|
} |
||||||
|
if len(handlers) == 0 { |
||||||
|
panic("blademaster: there must be at least one handler") |
||||||
|
} |
||||||
|
if _, ok := engine.metastore[path]; !ok { |
||||||
|
engine.metastore[path] = make(map[string]interface{}) |
||||||
|
} |
||||||
|
engine.metastore[path]["method"] = method |
||||||
|
engine.mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) { |
||||||
|
c := &Context{ |
||||||
|
Context: nil, |
||||||
|
engine: engine, |
||||||
|
index: -1, |
||||||
|
handlers: nil, |
||||||
|
Keys: nil, |
||||||
|
method: "", |
||||||
|
Error: nil, |
||||||
|
} |
||||||
|
|
||||||
|
c.Request = req |
||||||
|
c.Writer = w |
||||||
|
c.handlers = handlers |
||||||
|
c.method = method |
||||||
|
|
||||||
|
engine.handleContext(c) |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// SetConfig is used to set the engine configuration.
|
||||||
|
// Only the valid config will be loaded.
|
||||||
|
func (engine *Engine) SetConfig(conf *ServerConfig) (err error) { |
||||||
|
if conf.Timeout <= 0 { |
||||||
|
return errors.New("blademaster: config timeout must greater than 0") |
||||||
|
} |
||||||
|
if conf.Network == "" { |
||||||
|
conf.Network = "tcp" |
||||||
|
} |
||||||
|
engine.lock.Lock() |
||||||
|
engine.conf = conf |
||||||
|
engine.lock.Unlock() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (engine *Engine) methodConfig(path string) *MethodConfig { |
||||||
|
engine.pcLock.RLock() |
||||||
|
mc := engine.methodConfigs[path] |
||||||
|
engine.pcLock.RUnlock() |
||||||
|
return mc |
||||||
|
} |
||||||
|
|
||||||
|
func (engine *Engine) handleContext(c *Context) { |
||||||
|
var cancel func() |
||||||
|
req := c.Request |
||||||
|
ctype := req.Header.Get("Content-Type") |
||||||
|
switch { |
||||||
|
case strings.Contains(ctype, "multipart/form-data"): |
||||||
|
req.ParseMultipartForm(defaultMaxMemory) |
||||||
|
default: |
||||||
|
req.ParseForm() |
||||||
|
} |
||||||
|
// get derived timeout from http request header,
|
||||||
|
// compare with the engine configured,
|
||||||
|
// and use the minimum one
|
||||||
|
engine.lock.RLock() |
||||||
|
tm := time.Duration(engine.conf.Timeout) |
||||||
|
engine.lock.RUnlock() |
||||||
|
// the method config is preferred
|
||||||
|
if pc := engine.methodConfig(c.Request.URL.Path); pc != nil { |
||||||
|
tm = time.Duration(pc.Timeout) |
||||||
|
} |
||||||
|
if ctm := timeout(req); ctm > 0 && tm > ctm { |
||||||
|
tm = ctm |
||||||
|
} |
||||||
|
md := metadata.MD{ |
||||||
|
metadata.Color: color(req), |
||||||
|
metadata.RemoteIP: remoteIP(req), |
||||||
|
metadata.RemotePort: remotePort(req), |
||||||
|
metadata.Caller: caller(req), |
||||||
|
metadata.Mirror: mirror(req), |
||||||
|
} |
||||||
|
ctx := metadata.NewContext(context.Background(), md) |
||||||
|
if tm > 0 { |
||||||
|
c.Context, cancel = context.WithTimeout(ctx, tm) |
||||||
|
} else { |
||||||
|
c.Context, cancel = context.WithCancel(ctx) |
||||||
|
} |
||||||
|
defer cancel() |
||||||
|
c.Next() |
||||||
|
} |
||||||
|
|
||||||
|
// Router return a http.Handler for using http.ListenAndServe() directly.
|
||||||
|
func (engine *Engine) Router() http.Handler { |
||||||
|
return engine.mux |
||||||
|
} |
||||||
|
|
||||||
|
// Server is used to load stored http server.
|
||||||
|
func (engine *Engine) Server() *http.Server { |
||||||
|
s, ok := engine.server.Load().(*http.Server) |
||||||
|
if !ok { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return s |
||||||
|
} |
||||||
|
|
||||||
|
// Shutdown the http server without interrupting active connections.
|
||||||
|
func (engine *Engine) Shutdown(ctx context.Context) error { |
||||||
|
server := engine.Server() |
||||||
|
if server == nil { |
||||||
|
return errors.New("blademaster: no server") |
||||||
|
} |
||||||
|
return errors.WithStack(server.Shutdown(ctx)) |
||||||
|
} |
||||||
|
|
||||||
|
// UseFunc attachs a global middleware to the router. ie. the middleware attached though UseFunc() will be
|
||||||
|
// included in the handlers chain for every single request. Even 404, 405, static files...
|
||||||
|
// For example, this is the right place for a logger or error management middleware.
|
||||||
|
func (engine *Engine) UseFunc(middleware ...HandlerFunc) IRoutes { |
||||||
|
engine.RouterGroup.UseFunc(middleware...) |
||||||
|
return engine |
||||||
|
} |
||||||
|
|
||||||
|
// Use attachs a global middleware to the router. ie. the middleware attached though Use() will be
|
||||||
|
// included in the handlers chain for every single request. Even 404, 405, static files...
|
||||||
|
// For example, this is the right place for a logger or error management middleware.
|
||||||
|
func (engine *Engine) Use(middleware ...Handler) IRoutes { |
||||||
|
engine.RouterGroup.Use(middleware...) |
||||||
|
return engine |
||||||
|
} |
||||||
|
|
||||||
|
// Ping is used to set the general HTTP ping handler.
|
||||||
|
func (engine *Engine) Ping(handler HandlerFunc) { |
||||||
|
engine.GET("/ping", handler) |
||||||
|
} |
||||||
|
|
||||||
|
// Register is used to export metadata to discovery.
|
||||||
|
func (engine *Engine) Register(handler HandlerFunc) { |
||||||
|
engine.GET("/register", handler) |
||||||
|
} |
||||||
|
|
||||||
|
// Run attaches the router to a http.Server and starts listening and serving HTTP requests.
|
||||||
|
// It is a shortcut for http.ListenAndServe(addr, router)
|
||||||
|
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||||
|
func (engine *Engine) Run(addr ...string) (err error) { |
||||||
|
address := resolveAddress(addr) |
||||||
|
server := &http.Server{ |
||||||
|
Addr: address, |
||||||
|
Handler: engine.mux, |
||||||
|
} |
||||||
|
engine.server.Store(server) |
||||||
|
if err = server.ListenAndServe(); err != nil { |
||||||
|
err = errors.Wrapf(err, "addrs: %v", addr) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests.
|
||||||
|
// It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router)
|
||||||
|
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||||
|
func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) { |
||||||
|
server := &http.Server{ |
||||||
|
Addr: addr, |
||||||
|
Handler: engine.mux, |
||||||
|
} |
||||||
|
engine.server.Store(server) |
||||||
|
if err = server.ListenAndServeTLS(certFile, keyFile); err != nil { |
||||||
|
err = errors.Wrapf(err, "tls: %s/%s:%s", addr, certFile, keyFile) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// RunUnix attaches the router to a http.Server and starts listening and serving HTTP requests
|
||||||
|
// through the specified unix socket (ie. a file).
|
||||||
|
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||||
|
func (engine *Engine) RunUnix(file string) (err error) { |
||||||
|
os.Remove(file) |
||||||
|
listener, err := net.Listen("unix", file) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "unix: %s", file) |
||||||
|
return |
||||||
|
} |
||||||
|
defer listener.Close() |
||||||
|
server := &http.Server{ |
||||||
|
Handler: engine.mux, |
||||||
|
} |
||||||
|
engine.server.Store(server) |
||||||
|
if err = server.Serve(listener); err != nil { |
||||||
|
err = errors.Wrapf(err, "unix: %s", file) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// RunServer will serve and start listening HTTP requests by given server and listener.
|
||||||
|
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||||
|
func (engine *Engine) RunServer(server *http.Server, l net.Listener) (err error) { |
||||||
|
server.Handler = engine.mux |
||||||
|
engine.server.Store(server) |
||||||
|
if err = server.Serve(l); err != nil { |
||||||
|
err = errors.Wrapf(err, "listen server: %+v/%+v", server, l) |
||||||
|
return |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (engine *Engine) metadata() HandlerFunc { |
||||||
|
return func(c *Context) { |
||||||
|
c.JSON(engine.metastore, nil) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Inject is
|
||||||
|
func (engine *Engine) Inject(pattern string, handlers ...HandlerFunc) { |
||||||
|
engine.injections = append(engine.injections, injection{ |
||||||
|
pattern: regexp.MustCompile(pattern), |
||||||
|
handlers: handlers, |
||||||
|
}) |
||||||
|
} |
@ -0,0 +1,40 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"os" |
||||||
|
"path" |
||||||
|
) |
||||||
|
|
||||||
|
func lastChar(str string) uint8 { |
||||||
|
if str == "" { |
||||||
|
panic("The length of the string can't be 0") |
||||||
|
} |
||||||
|
return str[len(str)-1] |
||||||
|
} |
||||||
|
|
||||||
|
func joinPaths(absolutePath, relativePath string) string { |
||||||
|
if relativePath == "" { |
||||||
|
return absolutePath |
||||||
|
} |
||||||
|
|
||||||
|
finalPath := path.Join(absolutePath, relativePath) |
||||||
|
appendSlash := lastChar(relativePath) == '/' && lastChar(finalPath) != '/' |
||||||
|
if appendSlash { |
||||||
|
return finalPath + "/" |
||||||
|
} |
||||||
|
return finalPath |
||||||
|
} |
||||||
|
|
||||||
|
func resolveAddress(addr []string) string { |
||||||
|
switch len(addr) { |
||||||
|
case 0: |
||||||
|
if port := os.Getenv("PORT"); port != "" { |
||||||
|
return ":" + port |
||||||
|
} |
||||||
|
return ":8080" |
||||||
|
case 1: |
||||||
|
return addr[0] |
||||||
|
default: |
||||||
|
panic("too much parameters") |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,84 @@ |
|||||||
|
package mocktrace |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/bilibili/kratos/pkg/net/trace" |
||||||
|
) |
||||||
|
|
||||||
|
// MockTrace .
|
||||||
|
type MockTrace struct { |
||||||
|
Spans []*MockSpan |
||||||
|
} |
||||||
|
|
||||||
|
// New .
|
||||||
|
func (m *MockTrace) New(operationName string, opts ...trace.Option) trace.Trace { |
||||||
|
span := &MockSpan{OperationName: operationName, MockTrace: m} |
||||||
|
m.Spans = append(m.Spans, span) |
||||||
|
return span |
||||||
|
} |
||||||
|
|
||||||
|
// Inject .
|
||||||
|
func (m *MockTrace) Inject(t trace.Trace, format interface{}, carrier interface{}) error { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Extract .
|
||||||
|
func (m *MockTrace) Extract(format interface{}, carrier interface{}) (trace.Trace, error) { |
||||||
|
return &MockSpan{}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// MockSpan .
|
||||||
|
type MockSpan struct { |
||||||
|
*MockTrace |
||||||
|
OperationName string |
||||||
|
FinishErr error |
||||||
|
Finished bool |
||||||
|
Tags []trace.Tag |
||||||
|
Logs []trace.LogField |
||||||
|
} |
||||||
|
|
||||||
|
// Fork .
|
||||||
|
func (m *MockSpan) Fork(serviceName string, operationName string) trace.Trace { |
||||||
|
span := &MockSpan{OperationName: operationName, MockTrace: m.MockTrace} |
||||||
|
m.Spans = append(m.Spans, span) |
||||||
|
return span |
||||||
|
} |
||||||
|
|
||||||
|
// Follow .
|
||||||
|
func (m *MockSpan) Follow(serviceName string, operationName string) trace.Trace { |
||||||
|
span := &MockSpan{OperationName: operationName, MockTrace: m.MockTrace} |
||||||
|
m.Spans = append(m.Spans, span) |
||||||
|
return span |
||||||
|
} |
||||||
|
|
||||||
|
// Finish .
|
||||||
|
func (m *MockSpan) Finish(perr *error) { |
||||||
|
if perr != nil { |
||||||
|
m.FinishErr = *perr |
||||||
|
} |
||||||
|
m.Finished = true |
||||||
|
} |
||||||
|
|
||||||
|
// SetTag .
|
||||||
|
func (m *MockSpan) SetTag(tags ...trace.Tag) trace.Trace { |
||||||
|
m.Tags = append(m.Tags, tags...) |
||||||
|
return m |
||||||
|
} |
||||||
|
|
||||||
|
// SetLog .
|
||||||
|
func (m *MockSpan) SetLog(logs ...trace.LogField) trace.Trace { |
||||||
|
m.Logs = append(m.Logs, logs...) |
||||||
|
return m |
||||||
|
} |
||||||
|
|
||||||
|
// Visit .
|
||||||
|
func (m *MockSpan) Visit(fn func(k, v string)) {} |
||||||
|
|
||||||
|
// SetTitle .
|
||||||
|
func (m *MockSpan) SetTitle(title string) { |
||||||
|
m.OperationName = title |
||||||
|
} |
||||||
|
|
||||||
|
// TraceID .
|
||||||
|
func (m *MockSpan) TraceID() string { |
||||||
|
return "" |
||||||
|
} |
@ -0,0 +1,24 @@ |
|||||||
|
// Package mocktrace this ut just make ci happay.
|
||||||
|
package mocktrace |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
func TestMockTrace(t *testing.T) { |
||||||
|
mocktrace := &MockTrace{} |
||||||
|
mocktrace.Inject(nil, nil, nil) |
||||||
|
mocktrace.Extract(nil, nil) |
||||||
|
|
||||||
|
root := mocktrace.New("test") |
||||||
|
root.Fork("", "") |
||||||
|
root.Follow("", "") |
||||||
|
root.Finish(nil) |
||||||
|
err := fmt.Errorf("test") |
||||||
|
root.Finish(&err) |
||||||
|
root.SetTag() |
||||||
|
root.SetLog() |
||||||
|
root.Visit(func(k, v string) {}) |
||||||
|
root.SetTitle("") |
||||||
|
} |
Loading…
Reference in new issue