commit
dbdfe47e5b
@ -0,0 +1,25 @@ |
||||
# cache/memcache |
||||
|
||||
##### 项目简介 |
||||
1. 提供protobuf,gob,json序列化方式,gzip的memcache接口 |
||||
|
||||
#### 使用方式 |
||||
```golang |
||||
// 初始化 注意这里只是示例 展示用法 不能每次都New 只需要初始化一次 |
||||
mc := memcache.New(&memcache.Config{}) |
||||
// 程序关闭的时候调用close方法 |
||||
defer mc.Close() |
||||
// 增加 key |
||||
err = mc.Set(c, &memcache.Item{}) |
||||
// 删除key |
||||
err := mc.Delete(c,key) |
||||
// 获得某个key的内容 |
||||
err := mc.Get(c,key).Scan(&v) |
||||
// 获取多个key的内容 |
||||
replies, err := mc.GetMulti(c, keys) |
||||
for _, key := range replies.Keys() { |
||||
if err = replies.Scan(key, &v); err != nil { |
||||
return |
||||
} |
||||
} |
||||
``` |
@ -0,0 +1,187 @@ |
||||
package memcache |
||||
|
||||
import ( |
||||
"context" |
||||
) |
||||
|
||||
// Memcache memcache client
|
||||
type Memcache struct { |
||||
pool *Pool |
||||
} |
||||
|
||||
// Reply is the result of Get
|
||||
type Reply struct { |
||||
err error |
||||
item *Item |
||||
conn Conn |
||||
closed bool |
||||
} |
||||
|
||||
// Replies is the result of GetMulti
|
||||
type Replies struct { |
||||
err error |
||||
items map[string]*Item |
||||
usedItems map[string]struct{} |
||||
conn Conn |
||||
closed bool |
||||
} |
||||
|
||||
// New get a memcache client
|
||||
func New(c *Config) *Memcache { |
||||
return &Memcache{pool: NewPool(c)} |
||||
} |
||||
|
||||
// Close close connection pool
|
||||
func (mc *Memcache) Close() error { |
||||
return mc.pool.Close() |
||||
} |
||||
|
||||
// Conn direct get a connection
|
||||
func (mc *Memcache) Conn(c context.Context) Conn { |
||||
return mc.pool.Get(c) |
||||
} |
||||
|
||||
// Set writes the given item, unconditionally.
|
||||
func (mc *Memcache) Set(c context.Context, item *Item) (err error) { |
||||
conn := mc.pool.Get(c) |
||||
err = conn.Set(item) |
||||
conn.Close() |
||||
return |
||||
} |
||||
|
||||
// Add writes the given item, if no value already exists for its key.
|
||||
// ErrNotStored is returned if that condition is not met.
|
||||
func (mc *Memcache) Add(c context.Context, item *Item) (err error) { |
||||
conn := mc.pool.Get(c) |
||||
err = conn.Add(item) |
||||
conn.Close() |
||||
return |
||||
} |
||||
|
||||
// Replace writes the given item, but only if the server *does* already hold data for this key.
|
||||
func (mc *Memcache) Replace(c context.Context, item *Item) (err error) { |
||||
conn := mc.pool.Get(c) |
||||
err = conn.Replace(item) |
||||
conn.Close() |
||||
return |
||||
} |
||||
|
||||
// CompareAndSwap writes the given item that was previously returned by Get
|
||||
func (mc *Memcache) CompareAndSwap(c context.Context, item *Item) (err error) { |
||||
conn := mc.pool.Get(c) |
||||
err = conn.CompareAndSwap(item) |
||||
conn.Close() |
||||
return |
||||
} |
||||
|
||||
// Get sends a command to the server for gets data.
|
||||
func (mc *Memcache) Get(c context.Context, key string) *Reply { |
||||
conn := mc.pool.Get(c) |
||||
item, err := conn.Get(key) |
||||
if err != nil { |
||||
conn.Close() |
||||
} |
||||
return &Reply{err: err, item: item, conn: conn} |
||||
} |
||||
|
||||
// Item get raw Item
|
||||
func (r *Reply) Item() *Item { |
||||
return r.item |
||||
} |
||||
|
||||
// Scan converts value, read from the memcache
|
||||
func (r *Reply) Scan(v interface{}) (err error) { |
||||
if r.err != nil { |
||||
return r.err |
||||
} |
||||
err = r.conn.Scan(r.item, v) |
||||
if !r.closed { |
||||
r.conn.Close() |
||||
r.closed = true |
||||
} |
||||
return |
||||
} |
||||
|
||||
// GetMulti is a batch version of Get
|
||||
func (mc *Memcache) GetMulti(c context.Context, keys []string) (*Replies, error) { |
||||
conn := mc.pool.Get(c) |
||||
items, err := conn.GetMulti(keys) |
||||
rs := &Replies{err: err, items: items, conn: conn, usedItems: make(map[string]struct{}, len(keys))} |
||||
if (err != nil) || (len(items) == 0) { |
||||
rs.Close() |
||||
} |
||||
return rs, err |
||||
} |
||||
|
||||
// Close close rows.
|
||||
func (rs *Replies) Close() (err error) { |
||||
if !rs.closed { |
||||
err = rs.conn.Close() |
||||
rs.closed = true |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Item get Item from rows
|
||||
func (rs *Replies) Item(key string) *Item { |
||||
return rs.items[key] |
||||
} |
||||
|
||||
// Scan converts value, read from key in rows
|
||||
func (rs *Replies) Scan(key string, v interface{}) (err error) { |
||||
if rs.err != nil { |
||||
return rs.err |
||||
} |
||||
item, ok := rs.items[key] |
||||
if !ok { |
||||
rs.Close() |
||||
return ErrNotFound |
||||
} |
||||
rs.usedItems[key] = struct{}{} |
||||
err = rs.conn.Scan(item, v) |
||||
if (err != nil) || (len(rs.items) == len(rs.usedItems)) { |
||||
rs.Close() |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Keys keys of result
|
||||
func (rs *Replies) Keys() (keys []string) { |
||||
keys = make([]string, 0, len(rs.items)) |
||||
for key := range rs.items { |
||||
keys = append(keys, key) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Touch updates the expiry for the given key.
|
||||
func (mc *Memcache) Touch(c context.Context, key string, timeout int32) (err error) { |
||||
conn := mc.pool.Get(c) |
||||
err = conn.Touch(key, timeout) |
||||
conn.Close() |
||||
return |
||||
} |
||||
|
||||
// Delete deletes the item with the provided key.
|
||||
func (mc *Memcache) Delete(c context.Context, key string) (err error) { |
||||
conn := mc.pool.Get(c) |
||||
err = conn.Delete(key) |
||||
conn.Close() |
||||
return |
||||
} |
||||
|
||||
// Increment atomically increments key by delta.
|
||||
func (mc *Memcache) Increment(c context.Context, key string, delta uint64) (newValue uint64, err error) { |
||||
conn := mc.pool.Get(c) |
||||
newValue, err = conn.Increment(key, delta) |
||||
conn.Close() |
||||
return |
||||
} |
||||
|
||||
// Decrement atomically decrements key by delta.
|
||||
func (mc *Memcache) Decrement(c context.Context, key string, delta uint64) (newValue uint64, err error) { |
||||
conn := mc.pool.Get(c) |
||||
newValue, err = conn.Decrement(key, delta) |
||||
conn.Close() |
||||
return |
||||
} |
@ -0,0 +1,685 @@ |
||||
package memcache |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"compress/gzip" |
||||
"context" |
||||
"encoding/gob" |
||||
"encoding/json" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/gogo/protobuf/proto" |
||||
pkgerr "github.com/pkg/errors" |
||||
) |
||||
|
||||
var ( |
||||
crlf = []byte("\r\n") |
||||
spaceStr = string(" ") |
||||
replyOK = []byte("OK\r\n") |
||||
replyStored = []byte("STORED\r\n") |
||||
replyNotStored = []byte("NOT_STORED\r\n") |
||||
replyExists = []byte("EXISTS\r\n") |
||||
replyNotFound = []byte("NOT_FOUND\r\n") |
||||
replyDeleted = []byte("DELETED\r\n") |
||||
replyEnd = []byte("END\r\n") |
||||
replyTouched = []byte("TOUCHED\r\n") |
||||
replyValueStr = "VALUE" |
||||
replyClientErrorPrefix = []byte("CLIENT_ERROR ") |
||||
replyServerErrorPrefix = []byte("SERVER_ERROR ") |
||||
) |
||||
|
||||
const ( |
||||
_encodeBuf = 4096 // 4kb
|
||||
// 1024*1024 - 1, set error???
|
||||
_largeValue = 1000 * 1000 // 1MB
|
||||
) |
||||
|
||||
type reader struct { |
||||
io.Reader |
||||
} |
||||
|
||||
func (r *reader) Reset(rd io.Reader) { |
||||
r.Reader = rd |
||||
} |
||||
|
||||
// conn is the low-level implementation of Conn
|
||||
type conn struct { |
||||
// Shared
|
||||
mu sync.Mutex |
||||
err error |
||||
conn net.Conn |
||||
// Read & Write
|
||||
readTimeout time.Duration |
||||
writeTimeout time.Duration |
||||
rw *bufio.ReadWriter |
||||
// Item Reader
|
||||
ir bytes.Reader |
||||
// Compress
|
||||
gr gzip.Reader |
||||
gw *gzip.Writer |
||||
cb bytes.Buffer |
||||
// Encoding
|
||||
edb bytes.Buffer |
||||
// json
|
||||
jr reader |
||||
jd *json.Decoder |
||||
je *json.Encoder |
||||
// protobuffer
|
||||
ped *proto.Buffer |
||||
} |
||||
|
||||
// DialOption specifies an option for dialing a Memcache server.
|
||||
type DialOption struct { |
||||
f func(*dialOptions) |
||||
} |
||||
|
||||
type dialOptions struct { |
||||
readTimeout time.Duration |
||||
writeTimeout time.Duration |
||||
dial func(network, addr string) (net.Conn, error) |
||||
} |
||||
|
||||
// DialReadTimeout specifies the timeout for reading a single command reply.
|
||||
func DialReadTimeout(d time.Duration) DialOption { |
||||
return DialOption{func(do *dialOptions) { |
||||
do.readTimeout = d |
||||
}} |
||||
} |
||||
|
||||
// DialWriteTimeout specifies the timeout for writing a single command.
|
||||
func DialWriteTimeout(d time.Duration) DialOption { |
||||
return DialOption{func(do *dialOptions) { |
||||
do.writeTimeout = d |
||||
}} |
||||
} |
||||
|
||||
// DialConnectTimeout specifies the timeout for connecting to the Memcache server.
|
||||
func DialConnectTimeout(d time.Duration) DialOption { |
||||
return DialOption{func(do *dialOptions) { |
||||
dialer := net.Dialer{Timeout: d} |
||||
do.dial = dialer.Dial |
||||
}} |
||||
} |
||||
|
||||
// DialNetDial specifies a custom dial function for creating TCP
|
||||
// connections. If this option is left out, then net.Dial is
|
||||
// used. DialNetDial overrides DialConnectTimeout.
|
||||
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption { |
||||
return DialOption{func(do *dialOptions) { |
||||
do.dial = dial |
||||
}} |
||||
} |
||||
|
||||
// Dial connects to the Memcache server at the given network and
|
||||
// address using the specified options.
|
||||
func Dial(network, address string, options ...DialOption) (Conn, error) { |
||||
do := dialOptions{ |
||||
dial: net.Dial, |
||||
} |
||||
for _, option := range options { |
||||
option.f(&do) |
||||
} |
||||
netConn, err := do.dial(network, address) |
||||
if err != nil { |
||||
return nil, pkgerr.WithStack(err) |
||||
} |
||||
return NewConn(netConn, do.readTimeout, do.writeTimeout), nil |
||||
} |
||||
|
||||
// NewConn returns a new memcache connection for the given net connection.
|
||||
func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn { |
||||
if writeTimeout <= 0 || readTimeout <= 0 { |
||||
panic("must config memcache timeout") |
||||
} |
||||
c := &conn{ |
||||
conn: netConn, |
||||
rw: bufio.NewReadWriter(bufio.NewReader(netConn), |
||||
bufio.NewWriter(netConn)), |
||||
readTimeout: readTimeout, |
||||
writeTimeout: writeTimeout, |
||||
} |
||||
c.jd = json.NewDecoder(&c.jr) |
||||
c.je = json.NewEncoder(&c.edb) |
||||
c.gw = gzip.NewWriter(&c.cb) |
||||
c.edb.Grow(_encodeBuf) |
||||
// NOTE reuse bytes.Buffer internal buf
|
||||
// DON'T concurrency call Scan
|
||||
c.ped = proto.NewBuffer(c.edb.Bytes()) |
||||
return c |
||||
} |
||||
|
||||
func (c *conn) Close() error { |
||||
c.mu.Lock() |
||||
err := c.err |
||||
if c.err == nil { |
||||
c.err = pkgerr.New("memcache: closed") |
||||
err = c.conn.Close() |
||||
} |
||||
c.mu.Unlock() |
||||
return err |
||||
} |
||||
|
||||
func (c *conn) fatal(err error) error { |
||||
c.mu.Lock() |
||||
if c.err == nil { |
||||
c.err = pkgerr.WithStack(err) |
||||
// Close connection to force errors on subsequent calls and to unblock
|
||||
// other reader or writer.
|
||||
c.conn.Close() |
||||
} |
||||
c.mu.Unlock() |
||||
return c.err |
||||
} |
||||
|
||||
func (c *conn) Err() error { |
||||
c.mu.Lock() |
||||
err := c.err |
||||
c.mu.Unlock() |
||||
return err |
||||
} |
||||
|
||||
func (c *conn) Add(item *Item) error { |
||||
return c.populate("add", item) |
||||
} |
||||
|
||||
func (c *conn) Set(item *Item) error { |
||||
return c.populate("set", item) |
||||
} |
||||
|
||||
func (c *conn) Replace(item *Item) error { |
||||
return c.populate("replace", item) |
||||
} |
||||
|
||||
func (c *conn) CompareAndSwap(item *Item) error { |
||||
return c.populate("cas", item) |
||||
} |
||||
|
||||
func (c *conn) populate(cmd string, item *Item) (err error) { |
||||
if !legalKey(item.Key) { |
||||
return pkgerr.WithStack(ErrMalformedKey) |
||||
} |
||||
var res []byte |
||||
if res, err = c.encode(item); err != nil { |
||||
return |
||||
} |
||||
l := len(res) |
||||
count := l/(_largeValue) + 1 |
||||
if count == 1 { |
||||
item.Value = res |
||||
return c.populateOne(cmd, item) |
||||
} |
||||
nItem := &Item{ |
||||
Key: item.Key, |
||||
Value: []byte(strconv.Itoa(l)), |
||||
Expiration: item.Expiration, |
||||
Flags: item.Flags | flagLargeValue, |
||||
} |
||||
err = c.populateOne(cmd, nItem) |
||||
if err != nil { |
||||
return |
||||
} |
||||
k := item.Key |
||||
nItem.Flags = item.Flags |
||||
for i := 1; i <= count; i++ { |
||||
if i == count { |
||||
nItem.Value = res[_largeValue*(count-1):] |
||||
} else { |
||||
nItem.Value = res[_largeValue*(i-1) : _largeValue*i] |
||||
} |
||||
nItem.Key = fmt.Sprintf("%s%d", k, i) |
||||
if err = c.populateOne(cmd, nItem); err != nil { |
||||
return |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (c *conn) populateOne(cmd string, item *Item) (err error) { |
||||
if c.writeTimeout != 0 { |
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||
} |
||||
|
||||
// <command name> <key> <flags> <exptime> <bytes> [noreply]\r\n
|
||||
if cmd == "cas" { |
||||
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d %d\r\n", |
||||
cmd, item.Key, item.Flags, item.Expiration, len(item.Value), item.cas) |
||||
} else { |
||||
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d\r\n", |
||||
cmd, item.Key, item.Flags, item.Expiration, len(item.Value)) |
||||
} |
||||
if err != nil { |
||||
return c.fatal(err) |
||||
} |
||||
c.rw.Write(item.Value) |
||||
c.rw.Write(crlf) |
||||
if err = c.rw.Flush(); err != nil { |
||||
return c.fatal(err) |
||||
} |
||||
if c.readTimeout != 0 { |
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) |
||||
} |
||||
line, err := c.rw.ReadSlice('\n') |
||||
if err != nil { |
||||
return c.fatal(err) |
||||
} |
||||
switch { |
||||
case bytes.Equal(line, replyStored): |
||||
return nil |
||||
case bytes.Equal(line, replyNotStored): |
||||
return ErrNotStored |
||||
case bytes.Equal(line, replyExists): |
||||
return ErrCASConflict |
||||
case bytes.Equal(line, replyNotFound): |
||||
return ErrNotFound |
||||
} |
||||
return pkgerr.WithStack(protocolError(string(line))) |
||||
} |
||||
|
||||
func (c *conn) Get(key string) (r *Item, err error) { |
||||
if !legalKey(key) { |
||||
return nil, pkgerr.WithStack(ErrMalformedKey) |
||||
} |
||||
if c.writeTimeout != 0 { |
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||
} |
||||
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", key); err != nil { |
||||
return nil, c.fatal(err) |
||||
} |
||||
if err = c.rw.Flush(); err != nil { |
||||
return nil, c.fatal(err) |
||||
} |
||||
if err = c.parseGetReply(func(it *Item) { |
||||
r = it |
||||
}); err != nil { |
||||
return |
||||
} |
||||
if r == nil { |
||||
err = ErrNotFound |
||||
return |
||||
} |
||||
if r.Flags&flagLargeValue != flagLargeValue { |
||||
return |
||||
} |
||||
if r, err = c.getLargeValue(r); err != nil { |
||||
return |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (c *conn) GetMulti(keys []string) (res map[string]*Item, err error) { |
||||
for _, key := range keys { |
||||
if !legalKey(key) { |
||||
return nil, pkgerr.WithStack(ErrMalformedKey) |
||||
} |
||||
} |
||||
if c.writeTimeout != 0 { |
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||
} |
||||
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil { |
||||
return nil, c.fatal(err) |
||||
} |
||||
if err = c.rw.Flush(); err != nil { |
||||
return nil, c.fatal(err) |
||||
} |
||||
res = make(map[string]*Item, len(keys)) |
||||
if err = c.parseGetReply(func(it *Item) { |
||||
res[it.Key] = it |
||||
}); err != nil { |
||||
return |
||||
} |
||||
for k, v := range res { |
||||
if v.Flags&flagLargeValue != flagLargeValue { |
||||
continue |
||||
} |
||||
r, err := c.getLargeValue(v) |
||||
if err != nil { |
||||
return res, err |
||||
} |
||||
res[k] = r |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (c *conn) getMulti(keys []string) (res map[string]*Item, err error) { |
||||
for _, key := range keys { |
||||
if !legalKey(key) { |
||||
return nil, pkgerr.WithStack(ErrMalformedKey) |
||||
} |
||||
} |
||||
if c.writeTimeout != 0 { |
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||
} |
||||
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil { |
||||
return nil, c.fatal(err) |
||||
} |
||||
if err = c.rw.Flush(); err != nil { |
||||
return nil, c.fatal(err) |
||||
} |
||||
res = make(map[string]*Item, len(keys)) |
||||
err = c.parseGetReply(func(it *Item) { |
||||
res[it.Key] = it |
||||
}) |
||||
return |
||||
} |
||||
|
||||
func (c *conn) getLargeValue(it *Item) (r *Item, err error) { |
||||
l, err := strconv.Atoi(string(it.Value)) |
||||
if err != nil { |
||||
return |
||||
} |
||||
count := l/_largeValue + 1 |
||||
keys := make([]string, 0, count) |
||||
for i := 1; i <= count; i++ { |
||||
keys = append(keys, fmt.Sprintf("%s%d", it.Key, i)) |
||||
} |
||||
items, err := c.getMulti(keys) |
||||
if err != nil { |
||||
return |
||||
} |
||||
if len(items) < count { |
||||
err = ErrNotFound |
||||
return |
||||
} |
||||
v := make([]byte, 0, l) |
||||
for _, k := range keys { |
||||
if items[k] == nil || items[k].Value == nil { |
||||
err = ErrNotFound |
||||
return |
||||
} |
||||
v = append(v, items[k].Value...) |
||||
} |
||||
it.Value = v |
||||
it.Flags = it.Flags ^ flagLargeValue |
||||
r = it |
||||
return |
||||
} |
||||
|
||||
func (c *conn) parseGetReply(f func(*Item)) error { |
||||
if c.readTimeout != 0 { |
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) |
||||
} |
||||
for { |
||||
line, err := c.rw.ReadSlice('\n') |
||||
if err != nil { |
||||
return c.fatal(err) |
||||
} |
||||
if bytes.Equal(line, replyEnd) { |
||||
return nil |
||||
} |
||||
if bytes.HasPrefix(line, replyServerErrorPrefix) { |
||||
errMsg := line[len(replyServerErrorPrefix):] |
||||
return c.fatal(protocolError(errMsg)) |
||||
} |
||||
it := new(Item) |
||||
size, err := scanGetReply(line, it) |
||||
if err != nil { |
||||
return c.fatal(err) |
||||
} |
||||
it.Value = make([]byte, size+2) |
||||
if _, err = io.ReadFull(c.rw, it.Value); err != nil { |
||||
return c.fatal(err) |
||||
} |
||||
if !bytes.HasSuffix(it.Value, crlf) { |
||||
return c.fatal(protocolError("corrupt get reply, no except CRLF")) |
||||
} |
||||
it.Value = it.Value[:size] |
||||
f(it) |
||||
} |
||||
} |
||||
|
||||
func scanGetReply(line []byte, item *Item) (size int, err error) { |
||||
if !bytes.HasSuffix(line, crlf) { |
||||
return 0, protocolError("corrupt get reply, no except CRLF") |
||||
} |
||||
// VALUE <key> <flags> <bytes> [<cas unique>]
|
||||
chunks := strings.Split(string(line[:len(line)-2]), spaceStr) |
||||
if len(chunks) < 4 { |
||||
return 0, protocolError("corrupt get reply") |
||||
} |
||||
if chunks[0] != replyValueStr { |
||||
return 0, protocolError("corrupt get reply, no except VALUE") |
||||
} |
||||
item.Key = chunks[1] |
||||
flags64, err := strconv.ParseUint(chunks[2], 10, 32) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
item.Flags = uint32(flags64) |
||||
if size, err = strconv.Atoi(chunks[3]); err != nil { |
||||
return |
||||
} |
||||
if len(chunks) > 4 { |
||||
item.cas, err = strconv.ParseUint(chunks[4], 10, 64) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (c *conn) Touch(key string, expire int32) (err error) { |
||||
if !legalKey(key) { |
||||
return pkgerr.WithStack(ErrMalformedKey) |
||||
} |
||||
line, err := c.writeReadLine("touch %s %d\r\n", key, expire) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
switch { |
||||
case bytes.Equal(line, replyTouched): |
||||
return nil |
||||
case bytes.Equal(line, replyNotFound): |
||||
return ErrNotFound |
||||
default: |
||||
return pkgerr.WithStack(protocolError(string(line))) |
||||
} |
||||
} |
||||
|
||||
func (c *conn) Increment(key string, delta uint64) (uint64, error) { |
||||
return c.incrDecr("incr", key, delta) |
||||
} |
||||
|
||||
func (c *conn) Decrement(key string, delta uint64) (newValue uint64, err error) { |
||||
return c.incrDecr("decr", key, delta) |
||||
} |
||||
|
||||
func (c *conn) incrDecr(cmd, key string, delta uint64) (uint64, error) { |
||||
if !legalKey(key) { |
||||
return 0, pkgerr.WithStack(ErrMalformedKey) |
||||
} |
||||
line, err := c.writeReadLine("%s %s %d\r\n", cmd, key, delta) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
switch { |
||||
case bytes.Equal(line, replyNotFound): |
||||
return 0, ErrNotFound |
||||
case bytes.HasPrefix(line, replyClientErrorPrefix): |
||||
errMsg := line[len(replyClientErrorPrefix):] |
||||
return 0, pkgerr.WithStack(protocolError(errMsg)) |
||||
} |
||||
val, err := strconv.ParseUint(string(line[:len(line)-2]), 10, 64) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
return val, nil |
||||
} |
||||
|
||||
func (c *conn) Delete(key string) (err error) { |
||||
if !legalKey(key) { |
||||
return pkgerr.WithStack(ErrMalformedKey) |
||||
} |
||||
line, err := c.writeReadLine("delete %s\r\n", key) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
switch { |
||||
case bytes.Equal(line, replyOK): |
||||
return nil |
||||
case bytes.Equal(line, replyDeleted): |
||||
return nil |
||||
case bytes.Equal(line, replyNotStored): |
||||
return ErrNotStored |
||||
case bytes.Equal(line, replyExists): |
||||
return ErrCASConflict |
||||
case bytes.Equal(line, replyNotFound): |
||||
return ErrNotFound |
||||
} |
||||
return pkgerr.WithStack(protocolError(string(line))) |
||||
} |
||||
|
||||
func (c *conn) writeReadLine(format string, args ...interface{}) ([]byte, error) { |
||||
if c.writeTimeout != 0 { |
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||
} |
||||
_, err := fmt.Fprintf(c.rw, format, args...) |
||||
if err != nil { |
||||
return nil, c.fatal(pkgerr.WithStack(err)) |
||||
} |
||||
if err = c.rw.Flush(); err != nil { |
||||
return nil, c.fatal(pkgerr.WithStack(err)) |
||||
} |
||||
if c.readTimeout != 0 { |
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) |
||||
} |
||||
line, err := c.rw.ReadSlice('\n') |
||||
if err != nil { |
||||
return line, c.fatal(pkgerr.WithStack(err)) |
||||
} |
||||
return line, nil |
||||
} |
||||
|
||||
func (c *conn) Scan(item *Item, v interface{}) (err error) { |
||||
c.ir.Reset(item.Value) |
||||
if item.Flags&FlagGzip == FlagGzip { |
||||
if err = c.gr.Reset(&c.ir); err != nil { |
||||
return |
||||
} |
||||
if err = c.decode(&c.gr, item, v); err != nil { |
||||
err = pkgerr.WithStack(err) |
||||
return |
||||
} |
||||
err = c.gr.Close() |
||||
} else { |
||||
err = c.decode(&c.ir, item, v) |
||||
} |
||||
err = pkgerr.WithStack(err) |
||||
return |
||||
} |
||||
|
||||
func (c *conn) WithContext(ctx context.Context) Conn { |
||||
// FIXME: implement WithContext
|
||||
return c |
||||
} |
||||
|
||||
func (c *conn) encode(item *Item) (data []byte, err error) { |
||||
if (item.Flags | _flagEncoding) == _flagEncoding { |
||||
if item.Value == nil { |
||||
return nil, ErrItem |
||||
} |
||||
} else if item.Object == nil { |
||||
return nil, ErrItem |
||||
} |
||||
// encoding
|
||||
switch { |
||||
case item.Flags&FlagGOB == FlagGOB: |
||||
c.edb.Reset() |
||||
if err = gob.NewEncoder(&c.edb).Encode(item.Object); err != nil { |
||||
return |
||||
} |
||||
data = c.edb.Bytes() |
||||
case item.Flags&FlagProtobuf == FlagProtobuf: |
||||
c.edb.Reset() |
||||
c.ped.SetBuf(c.edb.Bytes()) |
||||
pb, ok := item.Object.(proto.Message) |
||||
if !ok { |
||||
err = ErrItemObject |
||||
return |
||||
} |
||||
if err = c.ped.Marshal(pb); err != nil { |
||||
return |
||||
} |
||||
data = c.ped.Bytes() |
||||
case item.Flags&FlagJSON == FlagJSON: |
||||
c.edb.Reset() |
||||
if err = c.je.Encode(item.Object); err != nil { |
||||
return |
||||
} |
||||
data = c.edb.Bytes() |
||||
default: |
||||
data = item.Value |
||||
} |
||||
// compress
|
||||
if item.Flags&FlagGzip == FlagGzip { |
||||
c.cb.Reset() |
||||
c.gw.Reset(&c.cb) |
||||
if _, err = c.gw.Write(data); err != nil { |
||||
return |
||||
} |
||||
if err = c.gw.Close(); err != nil { |
||||
return |
||||
} |
||||
data = c.cb.Bytes() |
||||
} |
||||
if len(data) > 8000000 { |
||||
err = ErrValueSize |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (c *conn) decode(rd io.Reader, item *Item, v interface{}) (err error) { |
||||
var data []byte |
||||
switch { |
||||
case item.Flags&FlagGOB == FlagGOB: |
||||
err = gob.NewDecoder(rd).Decode(v) |
||||
case item.Flags&FlagJSON == FlagJSON: |
||||
c.jr.Reset(rd) |
||||
err = c.jd.Decode(v) |
||||
default: |
||||
data = item.Value |
||||
if item.Flags&FlagGzip == FlagGzip { |
||||
c.edb.Reset() |
||||
if _, err = io.Copy(&c.edb, rd); err != nil { |
||||
return |
||||
} |
||||
data = c.edb.Bytes() |
||||
} |
||||
if item.Flags&FlagProtobuf == FlagProtobuf { |
||||
m, ok := v.(proto.Message) |
||||
if !ok { |
||||
err = ErrItemObject |
||||
return |
||||
} |
||||
c.ped.SetBuf(data) |
||||
err = c.ped.Unmarshal(m) |
||||
} else { |
||||
switch v.(type) { |
||||
case *[]byte: |
||||
d := v.(*[]byte) |
||||
*d = data |
||||
case *string: |
||||
d := v.(*string) |
||||
*d = string(data) |
||||
case interface{}: |
||||
err = json.Unmarshal(data, v) |
||||
} |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
func legalKey(key string) bool { |
||||
if len(key) > 250 || len(key) == 0 { |
||||
return false |
||||
} |
||||
for i := 0; i < len(key); i++ { |
||||
if key[i] <= ' ' || key[i] == 0x7f { |
||||
return false |
||||
} |
||||
} |
||||
return true |
||||
} |
@ -0,0 +1,76 @@ |
||||
package memcache |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"strings" |
||||
|
||||
pkgerr "github.com/pkg/errors" |
||||
) |
||||
|
||||
var ( |
||||
// ErrNotFound not found
|
||||
ErrNotFound = errors.New("memcache: key not found") |
||||
// ErrExists exists
|
||||
ErrExists = errors.New("memcache: key exists") |
||||
// ErrNotStored not stored
|
||||
ErrNotStored = errors.New("memcache: key not stored") |
||||
// ErrCASConflict means that a CompareAndSwap call failed due to the
|
||||
// cached value being modified between the Get and the CompareAndSwap.
|
||||
// If the cached value was simply evicted rather than replaced,
|
||||
// ErrNotStored will be returned instead.
|
||||
ErrCASConflict = errors.New("memcache: compare-and-swap conflict") |
||||
|
||||
// ErrPoolExhausted is returned from a pool connection method (Store, Get,
|
||||
// Delete, IncrDecr, Err) when the maximum number of database connections
|
||||
// in the pool has been reached.
|
||||
ErrPoolExhausted = errors.New("memcache: connection pool exhausted") |
||||
// ErrPoolClosed pool closed
|
||||
ErrPoolClosed = errors.New("memcache: connection pool closed") |
||||
// ErrConnClosed conn closed
|
||||
ErrConnClosed = errors.New("memcache: connection closed") |
||||
// ErrMalformedKey is returned when an invalid key is used.
|
||||
// Keys must be at maximum 250 bytes long and not
|
||||
// contain whitespace or control characters.
|
||||
ErrMalformedKey = errors.New("memcache: malformed key is too long or contains invalid characters") |
||||
// ErrValueSize item value size must less than 1mb
|
||||
ErrValueSize = errors.New("memcache: item value size must not greater than 1mb") |
||||
// ErrStat stat error for monitor
|
||||
ErrStat = errors.New("memcache unexpected errors") |
||||
// ErrItem item nil.
|
||||
ErrItem = errors.New("memcache: item object nil") |
||||
// ErrItemObject object type Assertion failed
|
||||
ErrItemObject = errors.New("memcache: item object protobuf type assertion failed") |
||||
) |
||||
|
||||
type protocolError string |
||||
|
||||
func (pe protocolError) Error() string { |
||||
return fmt.Sprintf("memcache: %s (possible server error or unsupported concurrent read by application)", string(pe)) |
||||
} |
||||
|
||||
func formatErr(err error) string { |
||||
e := pkgerr.Cause(err) |
||||
switch e { |
||||
case ErrNotFound, ErrExists, ErrNotStored, nil: |
||||
return "" |
||||
default: |
||||
es := e.Error() |
||||
switch { |
||||
case strings.HasPrefix(es, "read"): |
||||
return "read timeout" |
||||
case strings.HasPrefix(es, "dial"): |
||||
return "dial timeout" |
||||
case strings.HasPrefix(es, "write"): |
||||
return "write timeout" |
||||
case strings.Contains(es, "EOF"): |
||||
return "eof" |
||||
case strings.Contains(es, "reset"): |
||||
return "reset" |
||||
case strings.Contains(es, "broken"): |
||||
return "broken pipe" |
||||
default: |
||||
return "unexpected err" |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,136 @@ |
||||
package memcache |
||||
|
||||
import ( |
||||
"context" |
||||
) |
||||
|
||||
// Error represents an error returned in a command reply.
|
||||
type Error string |
||||
|
||||
func (err Error) Error() string { return string(err) } |
||||
|
||||
const ( |
||||
// Flag, 15(encoding) bit+ 17(compress) bit
|
||||
|
||||
// FlagRAW default flag.
|
||||
FlagRAW = uint32(0) |
||||
// FlagGOB gob encoding.
|
||||
FlagGOB = uint32(1) << 0 |
||||
// FlagJSON json encoding.
|
||||
FlagJSON = uint32(1) << 1 |
||||
// FlagProtobuf protobuf
|
||||
FlagProtobuf = uint32(1) << 2 |
||||
|
||||
_flagEncoding = uint32(0xFFFF8000) |
||||
|
||||
// FlagGzip gzip compress.
|
||||
FlagGzip = uint32(1) << 15 |
||||
|
||||
// left mv 31??? not work!!!
|
||||
flagLargeValue = uint32(1) << 30 |
||||
) |
||||
|
||||
// Item is an reply to be got or stored in a memcached server.
|
||||
type Item struct { |
||||
// Key is the Item's key (250 bytes maximum).
|
||||
Key string |
||||
|
||||
// Value is the Item's value.
|
||||
Value []byte |
||||
|
||||
// Object is the Item's object for use codec.
|
||||
Object interface{} |
||||
|
||||
// Flags are server-opaque flags whose semantics are entirely
|
||||
// up to the app.
|
||||
Flags uint32 |
||||
|
||||
// Expiration is the cache expiration time, in seconds: either a relative
|
||||
// time from now (up to 1 month), or an absolute Unix epoch time.
|
||||
// Zero means the Item has no expiration time.
|
||||
Expiration int32 |
||||
|
||||
// Compare and swap ID.
|
||||
cas uint64 |
||||
} |
||||
|
||||
// Conn represents a connection to a Memcache server.
|
||||
// Command Reference: https://github.com/memcached/memcached/wiki/Commands
|
||||
type Conn interface { |
||||
// Close closes the connection.
|
||||
Close() error |
||||
|
||||
// Err returns a non-nil value if the connection is broken. The returned
|
||||
// value is either the first non-nil value returned from the underlying
|
||||
// network connection or a protocol parsing error. Applications should
|
||||
// close broken connections.
|
||||
Err() error |
||||
|
||||
// Add writes the given item, if no value already exists for its key.
|
||||
// ErrNotStored is returned if that condition is not met.
|
||||
Add(item *Item) error |
||||
|
||||
// Set writes the given item, unconditionally.
|
||||
Set(item *Item) error |
||||
|
||||
// Replace writes the given item, but only if the server *does* already
|
||||
// hold data for this key.
|
||||
Replace(item *Item) error |
||||
|
||||
// Get sends a command to the server for gets data.
|
||||
Get(key string) (*Item, error) |
||||
|
||||
// GetMulti is a batch version of Get. The returned map from keys to items
|
||||
// may have fewer elements than the input slice, due to memcache cache
|
||||
// misses. Each key must be at most 250 bytes in length.
|
||||
// If no error is returned, the returned map will also be non-nil.
|
||||
GetMulti(keys []string) (map[string]*Item, error) |
||||
|
||||
// Delete deletes the item with the provided key.
|
||||
// The error ErrCacheMiss is returned if the item didn't already exist in
|
||||
// the cache.
|
||||
Delete(key string) error |
||||
|
||||
// Increment atomically increments key by delta. The return value is the
|
||||
// new value after being incremented or an error. If the value didn't exist
|
||||
// in memcached the error is ErrCacheMiss. The value in memcached must be
|
||||
// an decimal number, or an error will be returned.
|
||||
// On 64-bit overflow, the new value wraps around.
|
||||
Increment(key string, delta uint64) (newValue uint64, err error) |
||||
|
||||
// Decrement atomically decrements key by delta. The return value is the
|
||||
// new value after being decremented or an error. If the value didn't exist
|
||||
// in memcached the error is ErrCacheMiss. The value in memcached must be
|
||||
// an decimal number, or an error will be returned. On underflow, the new
|
||||
// value is capped at zero and does not wrap around.
|
||||
Decrement(key string, delta uint64) (newValue uint64, err error) |
||||
|
||||
// CompareAndSwap writes the given item that was previously returned by
|
||||
// Get, if the value was neither modified or evicted between the Get and
|
||||
// the CompareAndSwap calls. The item's Key should not change between calls
|
||||
// but all other item fields may differ. ErrCASConflict is returned if the
|
||||
// value was modified in between the calls.
|
||||
// ErrNotStored is returned if the value was evicted in between the calls.
|
||||
CompareAndSwap(item *Item) error |
||||
|
||||
// Touch updates the expiry for the given key. The seconds parameter is
|
||||
// either a Unix timestamp or, if seconds is less than 1 month, the number
|
||||
// of seconds into the future at which time the item will expire.
|
||||
//ErrCacheMiss is returned if the key is not in the cache. The key must be
|
||||
// at most 250 bytes in length.
|
||||
Touch(key string, seconds int32) (err error) |
||||
|
||||
// Scan converts value read from the memcache into the following
|
||||
// common Go types and special types:
|
||||
//
|
||||
// *string
|
||||
// *[]byte
|
||||
// *interface{}
|
||||
//
|
||||
Scan(item *Item, v interface{}) (err error) |
||||
|
||||
// WithContext return a Conn with its context changed to ctx
|
||||
// the context controls the entire lifetime of Conn before you change it
|
||||
// NOTE: this method is not thread-safe
|
||||
WithContext(ctx context.Context) Conn |
||||
} |
@ -0,0 +1,59 @@ |
||||
package memcache |
||||
|
||||
import ( |
||||
"context" |
||||
) |
||||
|
||||
// MockErr for unit test.
|
||||
type MockErr struct { |
||||
Error error |
||||
} |
||||
|
||||
var _ Conn = MockErr{} |
||||
|
||||
// MockWith return a mock conn.
|
||||
func MockWith(err error) MockErr { |
||||
return MockErr{Error: err} |
||||
} |
||||
|
||||
// Err .
|
||||
func (m MockErr) Err() error { return m.Error } |
||||
|
||||
// Close .
|
||||
func (m MockErr) Close() error { return m.Error } |
||||
|
||||
// Add .
|
||||
func (m MockErr) Add(item *Item) error { return m.Error } |
||||
|
||||
// Set .
|
||||
func (m MockErr) Set(item *Item) error { return m.Error } |
||||
|
||||
// Replace .
|
||||
func (m MockErr) Replace(item *Item) error { return m.Error } |
||||
|
||||
// CompareAndSwap .
|
||||
func (m MockErr) CompareAndSwap(item *Item) error { return m.Error } |
||||
|
||||
// Get .
|
||||
func (m MockErr) Get(key string) (*Item, error) { return nil, m.Error } |
||||
|
||||
// GetMulti .
|
||||
func (m MockErr) GetMulti(keys []string) (map[string]*Item, error) { return nil, m.Error } |
||||
|
||||
// Touch .
|
||||
func (m MockErr) Touch(key string, timeout int32) error { return m.Error } |
||||
|
||||
// Delete .
|
||||
func (m MockErr) Delete(key string) error { return m.Error } |
||||
|
||||
// Increment .
|
||||
func (m MockErr) Increment(key string, delta uint64) (uint64, error) { return 0, m.Error } |
||||
|
||||
// Decrement .
|
||||
func (m MockErr) Decrement(key string, delta uint64) (uint64, error) { return 0, m.Error } |
||||
|
||||
// Scan .
|
||||
func (m MockErr) Scan(item *Item, v interface{}) error { return m.Error } |
||||
|
||||
// WithContext .
|
||||
func (m MockErr) WithContext(ctx context.Context) Conn { return m } |
@ -0,0 +1,197 @@ |
||||
package memcache |
||||
|
||||
import ( |
||||
"context" |
||||
"io" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/container/pool" |
||||
"github.com/bilibili/kratos/pkg/stat" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
) |
||||
|
||||
var stats = stat.Cache |
||||
|
||||
// Config memcache config.
|
||||
type Config struct { |
||||
*pool.Config |
||||
|
||||
Name string // memcache name, for trace
|
||||
Proto string |
||||
Addr string |
||||
DialTimeout xtime.Duration |
||||
ReadTimeout xtime.Duration |
||||
WriteTimeout xtime.Duration |
||||
} |
||||
|
||||
// Pool memcache connection pool struct.
|
||||
type Pool struct { |
||||
p pool.Pool |
||||
c *Config |
||||
} |
||||
|
||||
// NewPool new a memcache conn pool.
|
||||
func NewPool(c *Config) (p *Pool) { |
||||
if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 { |
||||
panic("must config memcache timeout") |
||||
} |
||||
p1 := pool.NewList(c.Config) |
||||
cnop := DialConnectTimeout(time.Duration(c.DialTimeout)) |
||||
rdop := DialReadTimeout(time.Duration(c.ReadTimeout)) |
||||
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout)) |
||||
p1.New = func(ctx context.Context) (io.Closer, error) { |
||||
conn, err := Dial(c.Proto, c.Addr, cnop, rdop, wrop) |
||||
return &traceConn{Conn: conn, address: c.Addr}, err |
||||
} |
||||
p = &Pool{p: p1, c: c} |
||||
return |
||||
} |
||||
|
||||
// Get gets a connection. The application must close the returned connection.
|
||||
// This method always returns a valid connection so that applications can defer
|
||||
// error handling to the first use of the connection. If there is an error
|
||||
// getting an underlying connection, then the connection Err, Do, Send, Flush
|
||||
// and Receive methods return that error.
|
||||
func (p *Pool) Get(ctx context.Context) Conn { |
||||
c, err := p.p.Get(ctx) |
||||
if err != nil { |
||||
return errorConnection{err} |
||||
} |
||||
c1, _ := c.(Conn) |
||||
return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx} |
||||
} |
||||
|
||||
// Close release the resources used by the pool.
|
||||
func (p *Pool) Close() error { |
||||
return p.p.Close() |
||||
} |
||||
|
||||
type pooledConnection struct { |
||||
p *Pool |
||||
c Conn |
||||
ctx context.Context |
||||
} |
||||
|
||||
func pstat(key string, t time.Time, err error) { |
||||
stats.Timing(key, int64(time.Since(t)/time.Millisecond)) |
||||
if err != nil { |
||||
if msg := formatErr(err); msg != "" { |
||||
stats.Incr("memcache", msg) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (pc *pooledConnection) Close() error { |
||||
c := pc.c |
||||
if _, ok := c.(errorConnection); ok { |
||||
return nil |
||||
} |
||||
pc.c = errorConnection{ErrConnClosed} |
||||
pc.p.p.Put(context.Background(), c, c.Err() != nil) |
||||
return nil |
||||
} |
||||
|
||||
func (pc *pooledConnection) Err() error { |
||||
return pc.c.Err() |
||||
} |
||||
|
||||
func (pc *pooledConnection) Set(item *Item) (err error) { |
||||
now := time.Now() |
||||
err = pc.c.Set(item) |
||||
pstat("memcache:set", now, err) |
||||
return |
||||
} |
||||
|
||||
func (pc *pooledConnection) Add(item *Item) (err error) { |
||||
now := time.Now() |
||||
err = pc.c.Add(item) |
||||
pstat("memcache:add", now, err) |
||||
return |
||||
} |
||||
|
||||
func (pc *pooledConnection) Replace(item *Item) (err error) { |
||||
now := time.Now() |
||||
err = pc.c.Replace(item) |
||||
pstat("memcache:replace", now, err) |
||||
return |
||||
} |
||||
|
||||
func (pc *pooledConnection) CompareAndSwap(item *Item) (err error) { |
||||
now := time.Now() |
||||
err = pc.c.CompareAndSwap(item) |
||||
pstat("memcache:cas", now, err) |
||||
return |
||||
} |
||||
|
||||
func (pc *pooledConnection) Get(key string) (r *Item, err error) { |
||||
now := time.Now() |
||||
r, err = pc.c.Get(key) |
||||
pstat("memcache:get", now, err) |
||||
return |
||||
} |
||||
|
||||
func (pc *pooledConnection) GetMulti(keys []string) (res map[string]*Item, err error) { |
||||
// if keys is empty slice returns empty map direct
|
||||
if len(keys) == 0 { |
||||
return make(map[string]*Item), nil |
||||
} |
||||
now := time.Now() |
||||
res, err = pc.c.GetMulti(keys) |
||||
pstat("memcache:gets", now, err) |
||||
return |
||||
} |
||||
|
||||
func (pc *pooledConnection) Touch(key string, timeout int32) (err error) { |
||||
now := time.Now() |
||||
err = pc.c.Touch(key, timeout) |
||||
pstat("memcache:touch", now, err) |
||||
return |
||||
} |
||||
|
||||
func (pc *pooledConnection) Scan(item *Item, v interface{}) error { |
||||
return pc.c.Scan(item, v) |
||||
} |
||||
|
||||
func (pc *pooledConnection) WithContext(ctx context.Context) Conn { |
||||
// TODO: set context
|
||||
pc.ctx = ctx |
||||
return pc |
||||
} |
||||
|
||||
func (pc *pooledConnection) Delete(key string) (err error) { |
||||
now := time.Now() |
||||
err = pc.c.Delete(key) |
||||
pstat("memcache:delete", now, err) |
||||
return |
||||
} |
||||
|
||||
func (pc *pooledConnection) Increment(key string, delta uint64) (newValue uint64, err error) { |
||||
now := time.Now() |
||||
newValue, err = pc.c.Increment(key, delta) |
||||
pstat("memcache:increment", now, err) |
||||
return |
||||
} |
||||
|
||||
func (pc *pooledConnection) Decrement(key string, delta uint64) (newValue uint64, err error) { |
||||
now := time.Now() |
||||
newValue, err = pc.c.Decrement(key, delta) |
||||
pstat("memcache:decrement", now, err) |
||||
return |
||||
} |
||||
|
||||
type errorConnection struct{ err error } |
||||
|
||||
func (ec errorConnection) Err() error { return ec.err } |
||||
func (ec errorConnection) Close() error { return ec.err } |
||||
func (ec errorConnection) Add(item *Item) error { return ec.err } |
||||
func (ec errorConnection) Set(item *Item) error { return ec.err } |
||||
func (ec errorConnection) Replace(item *Item) error { return ec.err } |
||||
func (ec errorConnection) CompareAndSwap(item *Item) error { return ec.err } |
||||
func (ec errorConnection) Get(key string) (*Item, error) { return nil, ec.err } |
||||
func (ec errorConnection) GetMulti(keys []string) (map[string]*Item, error) { return nil, ec.err } |
||||
func (ec errorConnection) Touch(key string, timeout int32) error { return ec.err } |
||||
func (ec errorConnection) Delete(key string) error { return ec.err } |
||||
func (ec errorConnection) Increment(key string, delta uint64) (uint64, error) { return 0, ec.err } |
||||
func (ec errorConnection) Decrement(key string, delta uint64) (uint64, error) { return 0, ec.err } |
||||
func (ec errorConnection) Scan(item *Item, v interface{}) error { return ec.err } |
||||
func (ec errorConnection) WithContext(ctx context.Context) Conn { return ec } |
@ -0,0 +1,109 @@ |
||||
package memcache |
||||
|
||||
import ( |
||||
"context" |
||||
"strconv" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/net/trace" |
||||
) |
||||
|
||||
const ( |
||||
_traceFamily = "memcache" |
||||
_traceSpanKind = "client" |
||||
_traceComponentName = "library/cache/memcache" |
||||
_tracePeerService = "memcache" |
||||
_slowLogDuration = time.Millisecond * 250 |
||||
) |
||||
|
||||
type traceConn struct { |
||||
Conn |
||||
ctx context.Context |
||||
address string |
||||
} |
||||
|
||||
func (t *traceConn) setTrace(action, statement string) func(error) error { |
||||
now := time.Now() |
||||
parent, ok := trace.FromContext(t.ctx) |
||||
if !ok { |
||||
return func(err error) error { return err } |
||||
} |
||||
span := parent.Fork(_traceFamily, "Memcache:"+action) |
||||
span.SetTag( |
||||
trace.String(trace.TagSpanKind, _traceSpanKind), |
||||
trace.String(trace.TagComponent, _traceComponentName), |
||||
trace.String(trace.TagPeerService, _tracePeerService), |
||||
trace.String(trace.TagPeerAddress, t.address), |
||||
trace.String(trace.TagDBStatement, action+" "+statement), |
||||
) |
||||
return func(err error) error { |
||||
span.Finish(&err) |
||||
t := time.Since(now) |
||||
if t > _slowLogDuration { |
||||
log.Warn("%s slow log action: %s key: %s time: %v", _traceFamily, action, statement, t) |
||||
} |
||||
return err |
||||
} |
||||
} |
||||
|
||||
func (t *traceConn) WithContext(ctx context.Context) Conn { |
||||
t.ctx = ctx |
||||
t.Conn = t.Conn.WithContext(ctx) |
||||
return t |
||||
} |
||||
|
||||
func (t *traceConn) Add(item *Item) error { |
||||
finishFn := t.setTrace("Add", item.Key) |
||||
return finishFn(t.Conn.Add(item)) |
||||
} |
||||
|
||||
func (t *traceConn) Set(item *Item) error { |
||||
finishFn := t.setTrace("Set", item.Key) |
||||
return finishFn(t.Conn.Set(item)) |
||||
} |
||||
|
||||
func (t *traceConn) Replace(item *Item) error { |
||||
finishFn := t.setTrace("Replace", item.Key) |
||||
return finishFn(t.Conn.Replace(item)) |
||||
} |
||||
|
||||
func (t *traceConn) Get(key string) (*Item, error) { |
||||
finishFn := t.setTrace("Get", key) |
||||
item, err := t.Conn.Get(key) |
||||
return item, finishFn(err) |
||||
} |
||||
|
||||
func (t *traceConn) GetMulti(keys []string) (map[string]*Item, error) { |
||||
finishFn := t.setTrace("GetMulti", strings.Join(keys, " ")) |
||||
items, err := t.Conn.GetMulti(keys) |
||||
return items, finishFn(err) |
||||
} |
||||
|
||||
func (t *traceConn) Delete(key string) error { |
||||
finishFn := t.setTrace("Delete", key) |
||||
return finishFn(t.Conn.Delete(key)) |
||||
} |
||||
|
||||
func (t *traceConn) Increment(key string, delta uint64) (newValue uint64, err error) { |
||||
finishFn := t.setTrace("Increment", key+" "+strconv.FormatUint(delta, 10)) |
||||
newValue, err = t.Conn.Increment(key, delta) |
||||
return newValue, finishFn(err) |
||||
} |
||||
|
||||
func (t *traceConn) Decrement(key string, delta uint64) (newValue uint64, err error) { |
||||
finishFn := t.setTrace("Decrement", key+" "+strconv.FormatUint(delta, 10)) |
||||
newValue, err = t.Conn.Decrement(key, delta) |
||||
return newValue, finishFn(err) |
||||
} |
||||
|
||||
func (t *traceConn) CompareAndSwap(item *Item) error { |
||||
finishFn := t.setTrace("CompareAndSwap", item.Key) |
||||
return finishFn(t.Conn.CompareAndSwap(item)) |
||||
} |
||||
|
||||
func (t *traceConn) Touch(key string, seconds int32) (err error) { |
||||
finishFn := t.setTrace("Touch", key+" "+strconv.Itoa(int(seconds))) |
||||
return finishFn(t.Conn.Touch(key, seconds)) |
||||
} |
@ -0,0 +1,32 @@ |
||||
package memcache |
||||
|
||||
import ( |
||||
"github.com/gogo/protobuf/proto" |
||||
) |
||||
|
||||
// RawItem item with FlagRAW flag.
|
||||
//
|
||||
// Expiration is the cache expiration time, in seconds: either a relative
|
||||
// time from now (up to 1 month), or an absolute Unix epoch time.
|
||||
// Zero means the Item has no expiration time.
|
||||
func RawItem(key string, data []byte, flags uint32, expiration int32) *Item { |
||||
return &Item{Key: key, Flags: flags | FlagRAW, Value: data, Expiration: expiration} |
||||
} |
||||
|
||||
// JSONItem item with FlagJSON flag.
|
||||
//
|
||||
// Expiration is the cache expiration time, in seconds: either a relative
|
||||
// time from now (up to 1 month), or an absolute Unix epoch time.
|
||||
// Zero means the Item has no expiration time.
|
||||
func JSONItem(key string, v interface{}, flags uint32, expiration int32) *Item { |
||||
return &Item{Key: key, Flags: flags | FlagJSON, Object: v, Expiration: expiration} |
||||
} |
||||
|
||||
// ProtobufItem item with FlagProtobuf flag.
|
||||
//
|
||||
// Expiration is the cache expiration time, in seconds: either a relative
|
||||
// time from now (up to 1 month), or an absolute Unix epoch time.
|
||||
// Zero means the Item has no expiration time.
|
||||
func ProtobufItem(key string, message proto.Message, flags uint32, expiration int32) *Item { |
||||
return &Item{Key: key, Flags: flags | FlagProtobuf, Object: message, Expiration: expiration} |
||||
} |
@ -0,0 +1,7 @@ |
||||
# cache/redis |
||||
|
||||
##### 项目简介 |
||||
1. 提供redis接口 |
||||
|
||||
#### 使用方式 |
||||
请参考doc.go |
@ -0,0 +1,57 @@ |
||||
// Copyright 2014 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis |
||||
|
||||
import ( |
||||
"strings" |
||||
) |
||||
|
||||
// redis state
|
||||
const ( |
||||
WatchState = 1 << iota |
||||
MultiState |
||||
SubscribeState |
||||
MonitorState |
||||
) |
||||
|
||||
// CommandInfo command info.
|
||||
type CommandInfo struct { |
||||
Set, Clear int |
||||
} |
||||
|
||||
var commandInfos = map[string]CommandInfo{ |
||||
"WATCH": {Set: WatchState}, |
||||
"UNWATCH": {Clear: WatchState}, |
||||
"MULTI": {Set: MultiState}, |
||||
"EXEC": {Clear: WatchState | MultiState}, |
||||
"DISCARD": {Clear: WatchState | MultiState}, |
||||
"PSUBSCRIBE": {Set: SubscribeState}, |
||||
"SUBSCRIBE": {Set: SubscribeState}, |
||||
"MONITOR": {Set: MonitorState}, |
||||
} |
||||
|
||||
func init() { |
||||
for n, ci := range commandInfos { |
||||
commandInfos[strings.ToLower(n)] = ci |
||||
} |
||||
} |
||||
|
||||
// LookupCommandInfo get command info.
|
||||
func LookupCommandInfo(commandName string) CommandInfo { |
||||
if ci, ok := commandInfos[commandName]; ok { |
||||
return ci |
||||
} |
||||
return commandInfos[strings.ToUpper(commandName)] |
||||
} |
@ -0,0 +1,597 @@ |
||||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"context" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"net/url" |
||||
"regexp" |
||||
"strconv" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/stat" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
var stats = stat.Cache |
||||
|
||||
// conn is the low-level implementation of Conn
|
||||
type conn struct { |
||||
|
||||
// Shared
|
||||
mu sync.Mutex |
||||
pending int |
||||
err error |
||||
conn net.Conn |
||||
|
||||
// Read
|
||||
readTimeout time.Duration |
||||
br *bufio.Reader |
||||
|
||||
// Write
|
||||
writeTimeout time.Duration |
||||
bw *bufio.Writer |
||||
|
||||
// Scratch space for formatting argument length.
|
||||
// '*' or '$', length, "\r\n"
|
||||
lenScratch [32]byte |
||||
|
||||
// Scratch space for formatting integers and floats.
|
||||
numScratch [40]byte |
||||
// stat func,default prom
|
||||
stat func(string, *error) func() |
||||
} |
||||
|
||||
func statfunc(cmd string, err *error) func() { |
||||
now := time.Now() |
||||
return func() { |
||||
stats.Timing(fmt.Sprintf("redis:%s", cmd), int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
if msg := formatErr(*err); msg != "" { |
||||
stats.Incr("redis", msg) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
// DialTimeout acts like Dial but takes timeouts for establishing the
|
||||
// connection to the server, writing a command and reading a reply.
|
||||
//
|
||||
// Deprecated: Use Dial with options instead.
|
||||
func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) { |
||||
return Dial(network, address, |
||||
DialConnectTimeout(connectTimeout), |
||||
DialReadTimeout(readTimeout), |
||||
DialWriteTimeout(writeTimeout)) |
||||
} |
||||
|
||||
// DialOption specifies an option for dialing a Redis server.
|
||||
type DialOption struct { |
||||
f func(*dialOptions) |
||||
} |
||||
|
||||
type dialOptions struct { |
||||
readTimeout time.Duration |
||||
writeTimeout time.Duration |
||||
dial func(network, addr string) (net.Conn, error) |
||||
db int |
||||
password string |
||||
stat func(string, *error) func() |
||||
} |
||||
|
||||
// DialStats specifies stat func for stats.default statfunc.
|
||||
func DialStats(fn func(string, *error) func()) DialOption { |
||||
return DialOption{func(do *dialOptions) { |
||||
do.stat = fn |
||||
}} |
||||
} |
||||
|
||||
// DialReadTimeout specifies the timeout for reading a single command reply.
|
||||
func DialReadTimeout(d time.Duration) DialOption { |
||||
return DialOption{func(do *dialOptions) { |
||||
do.readTimeout = d |
||||
}} |
||||
} |
||||
|
||||
// DialWriteTimeout specifies the timeout for writing a single command.
|
||||
func DialWriteTimeout(d time.Duration) DialOption { |
||||
return DialOption{func(do *dialOptions) { |
||||
do.writeTimeout = d |
||||
}} |
||||
} |
||||
|
||||
// DialConnectTimeout specifies the timeout for connecting to the Redis server.
|
||||
func DialConnectTimeout(d time.Duration) DialOption { |
||||
return DialOption{func(do *dialOptions) { |
||||
dialer := net.Dialer{Timeout: d} |
||||
do.dial = dialer.Dial |
||||
}} |
||||
} |
||||
|
||||
// DialNetDial specifies a custom dial function for creating TCP
|
||||
// connections. If this option is left out, then net.Dial is
|
||||
// used. DialNetDial overrides DialConnectTimeout.
|
||||
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption { |
||||
return DialOption{func(do *dialOptions) { |
||||
do.dial = dial |
||||
}} |
||||
} |
||||
|
||||
// DialDatabase specifies the database to select when dialing a connection.
|
||||
func DialDatabase(db int) DialOption { |
||||
return DialOption{func(do *dialOptions) { |
||||
do.db = db |
||||
}} |
||||
} |
||||
|
||||
// DialPassword specifies the password to use when connecting to
|
||||
// the Redis server.
|
||||
func DialPassword(password string) DialOption { |
||||
return DialOption{func(do *dialOptions) { |
||||
do.password = password |
||||
}} |
||||
} |
||||
|
||||
// Dial connects to the Redis server at the given network and
|
||||
// address using the specified options.
|
||||
func Dial(network, address string, options ...DialOption) (Conn, error) { |
||||
do := dialOptions{ |
||||
dial: net.Dial, |
||||
} |
||||
for _, option := range options { |
||||
option.f(&do) |
||||
} |
||||
|
||||
netConn, err := do.dial(network, address) |
||||
if err != nil { |
||||
return nil, errors.WithStack(err) |
||||
} |
||||
c := &conn{ |
||||
conn: netConn, |
||||
bw: bufio.NewWriter(netConn), |
||||
br: bufio.NewReader(netConn), |
||||
readTimeout: do.readTimeout, |
||||
writeTimeout: do.writeTimeout, |
||||
stat: statfunc, |
||||
} |
||||
|
||||
if do.password != "" { |
||||
if _, err := c.Do("AUTH", do.password); err != nil { |
||||
netConn.Close() |
||||
return nil, errors.WithStack(err) |
||||
} |
||||
} |
||||
|
||||
if do.db != 0 { |
||||
if _, err := c.Do("SELECT", do.db); err != nil { |
||||
netConn.Close() |
||||
return nil, errors.WithStack(err) |
||||
} |
||||
} |
||||
if do.stat != nil { |
||||
c.stat = do.stat |
||||
} |
||||
return c, nil |
||||
} |
||||
|
||||
var pathDBRegexp = regexp.MustCompile(`/(\d+)\z`) |
||||
|
||||
// DialURL connects to a Redis server at the given URL using the Redis
|
||||
// URI scheme. URLs should follow the draft IANA specification for the
|
||||
// scheme (https://www.iana.org/assignments/uri-schemes/prov/redis).
|
||||
func DialURL(rawurl string, options ...DialOption) (Conn, error) { |
||||
u, err := url.Parse(rawurl) |
||||
if err != nil { |
||||
return nil, errors.WithStack(err) |
||||
} |
||||
|
||||
if u.Scheme != "redis" { |
||||
return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme) |
||||
} |
||||
|
||||
// As per the IANA draft spec, the host defaults to localhost and
|
||||
// the port defaults to 6379.
|
||||
host, port, err := net.SplitHostPort(u.Host) |
||||
if err != nil { |
||||
// assume port is missing
|
||||
host = u.Host |
||||
port = "6379" |
||||
} |
||||
if host == "" { |
||||
host = "localhost" |
||||
} |
||||
address := net.JoinHostPort(host, port) |
||||
|
||||
if u.User != nil { |
||||
password, isSet := u.User.Password() |
||||
if isSet { |
||||
options = append(options, DialPassword(password)) |
||||
} |
||||
} |
||||
|
||||
match := pathDBRegexp.FindStringSubmatch(u.Path) |
||||
if len(match) == 2 { |
||||
db, err := strconv.Atoi(match[1]) |
||||
if err != nil { |
||||
return nil, errors.Errorf("invalid database: %s", u.Path[1:]) |
||||
} |
||||
if db != 0 { |
||||
options = append(options, DialDatabase(db)) |
||||
} |
||||
} else if u.Path != "" { |
||||
return nil, errors.Errorf("invalid database: %s", u.Path[1:]) |
||||
} |
||||
|
||||
return Dial("tcp", address, options...) |
||||
} |
||||
|
||||
// NewConn new a redis conn.
|
||||
func NewConn(c *Config) (cn Conn, err error) { |
||||
cnop := DialConnectTimeout(time.Duration(c.DialTimeout)) |
||||
rdop := DialReadTimeout(time.Duration(c.ReadTimeout)) |
||||
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout)) |
||||
auop := DialPassword(c.Auth) |
||||
// new conn
|
||||
cn, err = Dial(c.Proto, c.Addr, cnop, rdop, wrop, auop) |
||||
return |
||||
} |
||||
|
||||
func (c *conn) Close() error { |
||||
c.mu.Lock() |
||||
err := c.err |
||||
if c.err == nil { |
||||
c.err = errors.New("redigo: closed") |
||||
err = c.conn.Close() |
||||
} |
||||
c.mu.Unlock() |
||||
return err |
||||
} |
||||
|
||||
func (c *conn) fatal(err error) error { |
||||
c.mu.Lock() |
||||
if c.err == nil { |
||||
c.err = err |
||||
// Close connection to force errors on subsequent calls and to unblock
|
||||
// other reader or writer.
|
||||
c.conn.Close() |
||||
} |
||||
c.mu.Unlock() |
||||
return errors.WithStack(c.err) |
||||
} |
||||
|
||||
func (c *conn) Err() error { |
||||
c.mu.Lock() |
||||
err := c.err |
||||
c.mu.Unlock() |
||||
return err |
||||
} |
||||
|
||||
func (c *conn) writeLen(prefix byte, n int) error { |
||||
c.lenScratch[len(c.lenScratch)-1] = '\n' |
||||
c.lenScratch[len(c.lenScratch)-2] = '\r' |
||||
i := len(c.lenScratch) - 3 |
||||
for { |
||||
c.lenScratch[i] = byte('0' + n%10) |
||||
i-- |
||||
n = n / 10 |
||||
if n == 0 { |
||||
break |
||||
} |
||||
} |
||||
c.lenScratch[i] = prefix |
||||
_, err := c.bw.Write(c.lenScratch[i:]) |
||||
return errors.WithStack(err) |
||||
} |
||||
|
||||
func (c *conn) writeString(s string) error { |
||||
c.writeLen('$', len(s)) |
||||
c.bw.WriteString(s) |
||||
_, err := c.bw.WriteString("\r\n") |
||||
return errors.WithStack(err) |
||||
} |
||||
|
||||
func (c *conn) writeBytes(p []byte) error { |
||||
c.writeLen('$', len(p)) |
||||
c.bw.Write(p) |
||||
_, err := c.bw.WriteString("\r\n") |
||||
return errors.WithStack(err) |
||||
} |
||||
|
||||
func (c *conn) writeInt64(n int64) error { |
||||
return errors.WithStack(c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))) |
||||
} |
||||
|
||||
func (c *conn) writeFloat64(n float64) error { |
||||
return errors.WithStack(c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))) |
||||
} |
||||
|
||||
func (c *conn) writeCommand(cmd string, args []interface{}) (err error) { |
||||
if c.writeTimeout != 0 { |
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||
} |
||||
c.writeLen('*', 1+len(args)) |
||||
err = c.writeString(cmd) |
||||
for _, arg := range args { |
||||
if err != nil { |
||||
break |
||||
} |
||||
switch arg := arg.(type) { |
||||
case string: |
||||
err = c.writeString(arg) |
||||
case []byte: |
||||
err = c.writeBytes(arg) |
||||
case int: |
||||
err = c.writeInt64(int64(arg)) |
||||
case int64: |
||||
err = c.writeInt64(arg) |
||||
case float64: |
||||
err = c.writeFloat64(arg) |
||||
case bool: |
||||
if arg { |
||||
err = c.writeString("1") |
||||
} else { |
||||
err = c.writeString("0") |
||||
} |
||||
case nil: |
||||
err = c.writeString("") |
||||
default: |
||||
var buf bytes.Buffer |
||||
fmt.Fprint(&buf, arg) |
||||
err = errors.WithStack(c.writeBytes(buf.Bytes())) |
||||
} |
||||
} |
||||
return err |
||||
} |
||||
|
||||
type protocolError string |
||||
|
||||
func (pe protocolError) Error() string { |
||||
return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", string(pe)) |
||||
} |
||||
|
||||
func (c *conn) readLine() ([]byte, error) { |
||||
p, err := c.br.ReadSlice('\n') |
||||
if err == bufio.ErrBufferFull { |
||||
return nil, errors.WithStack(protocolError("long response line")) |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
i := len(p) - 2 |
||||
if i < 0 || p[i] != '\r' { |
||||
return nil, errors.WithStack(protocolError("bad response line terminator")) |
||||
} |
||||
return p[:i], nil |
||||
} |
||||
|
||||
// parseLen parses bulk string and array lengths.
|
||||
func parseLen(p []byte) (int, error) { |
||||
if len(p) == 0 { |
||||
return -1, errors.WithStack(protocolError("malformed length")) |
||||
} |
||||
|
||||
if p[0] == '-' && len(p) == 2 && p[1] == '1' { |
||||
// handle $-1 and $-1 null replies.
|
||||
return -1, nil |
||||
} |
||||
|
||||
var n int |
||||
for _, b := range p { |
||||
n *= 10 |
||||
if b < '0' || b > '9' { |
||||
return -1, errors.WithStack(protocolError("illegal bytes in length")) |
||||
} |
||||
n += int(b - '0') |
||||
} |
||||
|
||||
return n, nil |
||||
} |
||||
|
||||
// parseInt parses an integer reply.
|
||||
func parseInt(p []byte) (interface{}, error) { |
||||
if len(p) == 0 { |
||||
return 0, errors.WithStack(protocolError("malformed integer")) |
||||
} |
||||
|
||||
var negate bool |
||||
if p[0] == '-' { |
||||
negate = true |
||||
p = p[1:] |
||||
if len(p) == 0 { |
||||
return 0, errors.WithStack(protocolError("malformed integer")) |
||||
} |
||||
} |
||||
|
||||
var n int64 |
||||
for _, b := range p { |
||||
n *= 10 |
||||
if b < '0' || b > '9' { |
||||
return 0, errors.WithStack(protocolError("illegal bytes in length")) |
||||
} |
||||
n += int64(b - '0') |
||||
} |
||||
|
||||
if negate { |
||||
n = -n |
||||
} |
||||
return n, nil |
||||
} |
||||
|
||||
var ( |
||||
okReply interface{} = "OK" |
||||
pongReply interface{} = "PONG" |
||||
) |
||||
|
||||
func (c *conn) readReply() (interface{}, error) { |
||||
line, err := c.readLine() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(line) == 0 { |
||||
return nil, errors.WithStack(protocolError("short response line")) |
||||
} |
||||
switch line[0] { |
||||
case '+': |
||||
switch { |
||||
case len(line) == 3 && line[1] == 'O' && line[2] == 'K': |
||||
// Avoid allocation for frequent "+OK" response.
|
||||
return okReply, nil |
||||
case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G': |
||||
// Avoid allocation in PING command benchmarks :)
|
||||
return pongReply, nil |
||||
default: |
||||
return string(line[1:]), nil |
||||
} |
||||
case '-': |
||||
return Error(string(line[1:])), nil |
||||
case ':': |
||||
return parseInt(line[1:]) |
||||
case '$': |
||||
n, err := parseLen(line[1:]) |
||||
if n < 0 || err != nil { |
||||
return nil, err |
||||
} |
||||
p := make([]byte, n) |
||||
_, err = io.ReadFull(c.br, p) |
||||
if err != nil { |
||||
return nil, errors.WithStack(err) |
||||
} |
||||
if line1, err := c.readLine(); err != nil { |
||||
return nil, err |
||||
} else if len(line1) != 0 { |
||||
return nil, errors.WithStack(protocolError("bad bulk string format")) |
||||
} |
||||
return p, nil |
||||
case '*': |
||||
n, err := parseLen(line[1:]) |
||||
if n < 0 || err != nil { |
||||
return nil, err |
||||
} |
||||
r := make([]interface{}, n) |
||||
for i := range r { |
||||
r[i], err = c.readReply() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
return r, nil |
||||
} |
||||
return nil, errors.WithStack(protocolError("unexpected response line")) |
||||
} |
||||
func (c *conn) Send(cmd string, args ...interface{}) (err error) { |
||||
c.mu.Lock() |
||||
c.pending++ |
||||
c.mu.Unlock() |
||||
if err = c.writeCommand(cmd, args); err != nil { |
||||
c.fatal(err) |
||||
} |
||||
return err |
||||
} |
||||
|
||||
func (c *conn) Flush() (err error) { |
||||
if c.writeTimeout != 0 { |
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
||||
} |
||||
if err = c.bw.Flush(); err != nil { |
||||
c.fatal(err) |
||||
} |
||||
return err |
||||
} |
||||
|
||||
func (c *conn) Receive() (reply interface{}, err error) { |
||||
if c.readTimeout != 0 { |
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) |
||||
} |
||||
if reply, err = c.readReply(); err != nil { |
||||
return nil, c.fatal(err) |
||||
} |
||||
// When using pub/sub, the number of receives can be greater than the
|
||||
// number of sends. To enable normal use of the connection after
|
||||
// unsubscribing from all channels, we do not decrement pending to a
|
||||
// negative value.
|
||||
//
|
||||
// The pending field is decremented after the reply is read to handle the
|
||||
// case where Receive is called before Send.
|
||||
c.mu.Lock() |
||||
if c.pending > 0 { |
||||
c.pending-- |
||||
} |
||||
c.mu.Unlock() |
||||
if err, ok := reply.(Error); ok { |
||||
return nil, err |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { |
||||
c.mu.Lock() |
||||
pending := c.pending |
||||
c.pending = 0 |
||||
c.mu.Unlock() |
||||
if cmd == "" && pending == 0 { |
||||
return nil, nil |
||||
} |
||||
var err error |
||||
defer c.stat(cmd, &err)() |
||||
if cmd != "" { |
||||
err = c.writeCommand(cmd, args) |
||||
} |
||||
if err == nil { |
||||
err = errors.WithStack(c.bw.Flush()) |
||||
} |
||||
if err != nil { |
||||
return nil, c.fatal(err) |
||||
} |
||||
if c.readTimeout != 0 { |
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) |
||||
} |
||||
if cmd == "" { |
||||
reply := make([]interface{}, pending) |
||||
for i := range reply { |
||||
var r interface{} |
||||
r, err = c.readReply() |
||||
if err != nil { |
||||
break |
||||
} |
||||
reply[i] = r |
||||
} |
||||
if err != nil { |
||||
return nil, c.fatal(err) |
||||
} |
||||
return reply, nil |
||||
} |
||||
|
||||
var reply interface{} |
||||
for i := 0; i <= pending; i++ { |
||||
var e error |
||||
if reply, e = c.readReply(); e != nil { |
||||
return nil, c.fatal(e) |
||||
} |
||||
if e, ok := reply.(Error); ok && err == nil { |
||||
err = e |
||||
} |
||||
} |
||||
return reply, err |
||||
} |
||||
|
||||
// WithContext FIXME: implement WithContext
|
||||
func (c *conn) WithContext(ctx context.Context) Conn { return c } |
@ -0,0 +1,169 @@ |
||||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
// Package redis is a client for the Redis database.
|
||||
//
|
||||
// The Redigo FAQ (https://github.com/garyburd/redigo/wiki/FAQ) contains more
|
||||
// documentation about this package.
|
||||
//
|
||||
// Connections
|
||||
//
|
||||
// The Conn interface is the primary interface for working with Redis.
|
||||
// Applications create connections by calling the Dial, DialWithTimeout or
|
||||
// NewConn functions. In the future, functions will be added for creating
|
||||
// sharded and other types of connections.
|
||||
//
|
||||
// The application must call the connection Close method when the application
|
||||
// is done with the connection.
|
||||
//
|
||||
// Executing Commands
|
||||
//
|
||||
// The Conn interface has a generic method for executing Redis commands:
|
||||
//
|
||||
// Do(commandName string, args ...interface{}) (reply interface{}, err error)
|
||||
//
|
||||
// The Redis command reference (http://redis.io/commands) lists the available
|
||||
// commands. An example of using the Redis APPEND command is:
|
||||
//
|
||||
// n, err := conn.Do("APPEND", "key", "value")
|
||||
//
|
||||
// The Do method converts command arguments to binary strings for transmission
|
||||
// to the server as follows:
|
||||
//
|
||||
// Go Type Conversion
|
||||
// []byte Sent as is
|
||||
// string Sent as is
|
||||
// int, int64 strconv.FormatInt(v)
|
||||
// float64 strconv.FormatFloat(v, 'g', -1, 64)
|
||||
// bool true -> "1", false -> "0"
|
||||
// nil ""
|
||||
// all other types fmt.Print(v)
|
||||
//
|
||||
// Redis command reply types are represented using the following Go types:
|
||||
//
|
||||
// Redis type Go type
|
||||
// error redis.Error
|
||||
// integer int64
|
||||
// simple string string
|
||||
// bulk string []byte or nil if value not present.
|
||||
// array []interface{} or nil if value not present.
|
||||
//
|
||||
// Use type assertions or the reply helper functions to convert from
|
||||
// interface{} to the specific Go type for the command result.
|
||||
//
|
||||
// Pipelining
|
||||
//
|
||||
// Connections support pipelining using the Send, Flush and Receive methods.
|
||||
//
|
||||
// Send(commandName string, args ...interface{}) error
|
||||
// Flush() error
|
||||
// Receive() (reply interface{}, err error)
|
||||
//
|
||||
// Send writes the command to the connection's output buffer. Flush flushes the
|
||||
// connection's output buffer to the server. Receive reads a single reply from
|
||||
// the server. The following example shows a simple pipeline.
|
||||
//
|
||||
// c.Send("SET", "foo", "bar")
|
||||
// c.Send("GET", "foo")
|
||||
// c.Flush()
|
||||
// c.Receive() // reply from SET
|
||||
// v, err = c.Receive() // reply from GET
|
||||
//
|
||||
// The Do method combines the functionality of the Send, Flush and Receive
|
||||
// methods. The Do method starts by writing the command and flushing the output
|
||||
// buffer. Next, the Do method receives all pending replies including the reply
|
||||
// for the command just sent by Do. If any of the received replies is an error,
|
||||
// then Do returns the error. If there are no errors, then Do returns the last
|
||||
// reply. If the command argument to the Do method is "", then the Do method
|
||||
// will flush the output buffer and receive pending replies without sending a
|
||||
// command.
|
||||
//
|
||||
// Use the Send and Do methods to implement pipelined transactions.
|
||||
//
|
||||
// c.Send("MULTI")
|
||||
// c.Send("INCR", "foo")
|
||||
// c.Send("INCR", "bar")
|
||||
// r, err := c.Do("EXEC")
|
||||
// fmt.Println(r) // prints [1, 1]
|
||||
//
|
||||
// Concurrency
|
||||
//
|
||||
// Connections do not support concurrent calls to the write methods (Send,
|
||||
// Flush) or concurrent calls to the read method (Receive). Connections do
|
||||
// allow a concurrent reader and writer.
|
||||
//
|
||||
// Because the Do method combines the functionality of Send, Flush and Receive,
|
||||
// the Do method cannot be called concurrently with the other methods.
|
||||
//
|
||||
// For full concurrent access to Redis, use the thread-safe Pool to get and
|
||||
// release connections from within a goroutine.
|
||||
//
|
||||
// Publish and Subscribe
|
||||
//
|
||||
// Use the Send, Flush and Receive methods to implement Pub/Sub subscribers.
|
||||
//
|
||||
// c.Send("SUBSCRIBE", "example")
|
||||
// c.Flush()
|
||||
// for {
|
||||
// reply, err := c.Receive()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// // process pushed message
|
||||
// }
|
||||
//
|
||||
// The PubSubConn type wraps a Conn with convenience methods for implementing
|
||||
// subscribers. The Subscribe, PSubscribe, Unsubscribe and PUnsubscribe methods
|
||||
// send and flush a subscription management command. The receive method
|
||||
// converts a pushed message to convenient types for use in a type switch.
|
||||
//
|
||||
// psc := redis.PubSubConn{c}
|
||||
// psc.Subscribe("example")
|
||||
// for {
|
||||
// switch v := psc.Receive().(type) {
|
||||
// case redis.Message:
|
||||
// fmt.Printf("%s: message: %s\n", v.Channel, v.Data)
|
||||
// case redis.Subscription:
|
||||
// fmt.Printf("%s: %s %d\n", v.Channel, v.Kind, v.Count)
|
||||
// case error:
|
||||
// return v
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Reply Helpers
|
||||
//
|
||||
// The Bool, Int, Bytes, String, Strings and Values functions convert a reply
|
||||
// to a value of a specific type. To allow convenient wrapping of calls to the
|
||||
// connection Do and Receive methods, the functions take a second argument of
|
||||
// type error. If the error is non-nil, then the helper function returns the
|
||||
// error. If the error is nil, the function converts the reply to the specified
|
||||
// type:
|
||||
//
|
||||
// exists, err := redis.Bool(c.Do("EXISTS", "foo"))
|
||||
// if err != nil {
|
||||
// // handle error return from c.Do or type conversion error.
|
||||
// }
|
||||
//
|
||||
// The Scan function converts elements of a array reply to Go types:
|
||||
//
|
||||
// var value1 int
|
||||
// var value2 string
|
||||
// reply, err := redis.Values(c.Do("MGET", "key1", "key2"))
|
||||
// if err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
// if _, err := redis.Scan(reply, &value1, &value2); err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
package redis |
@ -0,0 +1,33 @@ |
||||
package redis |
||||
|
||||
import ( |
||||
"strings" |
||||
|
||||
pkgerr "github.com/pkg/errors" |
||||
) |
||||
|
||||
func formatErr(err error) string { |
||||
e := pkgerr.Cause(err) |
||||
switch e { |
||||
case ErrNil, nil: |
||||
return "" |
||||
default: |
||||
es := e.Error() |
||||
switch { |
||||
case strings.HasPrefix(es, "read"): |
||||
return "read timeout" |
||||
case strings.HasPrefix(es, "dial"): |
||||
return "dial timeout" |
||||
case strings.HasPrefix(es, "write"): |
||||
return "write timeout" |
||||
case strings.Contains(es, "EOF"): |
||||
return "eof" |
||||
case strings.Contains(es, "reset"): |
||||
return "reset" |
||||
case strings.Contains(es, "broken"): |
||||
return "broken pipe" |
||||
default: |
||||
return "unexpected err" |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,117 @@ |
||||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"log" |
||||
) |
||||
|
||||
// NewLoggingConn returns a logging wrapper around a connection.
|
||||
func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn { |
||||
if prefix != "" { |
||||
prefix = prefix + "." |
||||
} |
||||
return &loggingConn{conn, logger, prefix} |
||||
} |
||||
|
||||
type loggingConn struct { |
||||
Conn |
||||
logger *log.Logger |
||||
prefix string |
||||
} |
||||
|
||||
func (c *loggingConn) Close() error { |
||||
err := c.Conn.Close() |
||||
var buf bytes.Buffer |
||||
fmt.Fprintf(&buf, "%sClose() -> (%v)", c.prefix, err) |
||||
c.logger.Output(2, buf.String()) |
||||
return err |
||||
} |
||||
|
||||
func (c *loggingConn) printValue(buf *bytes.Buffer, v interface{}) { |
||||
const chop = 32 |
||||
switch v := v.(type) { |
||||
case []byte: |
||||
if len(v) > chop { |
||||
fmt.Fprintf(buf, "%q...", v[:chop]) |
||||
} else { |
||||
fmt.Fprintf(buf, "%q", v) |
||||
} |
||||
case string: |
||||
if len(v) > chop { |
||||
fmt.Fprintf(buf, "%q...", v[:chop]) |
||||
} else { |
||||
fmt.Fprintf(buf, "%q", v) |
||||
} |
||||
case []interface{}: |
||||
if len(v) == 0 { |
||||
buf.WriteString("[]") |
||||
} else { |
||||
sep := "[" |
||||
fin := "]" |
||||
if len(v) > chop { |
||||
v = v[:chop] |
||||
fin = "...]" |
||||
} |
||||
for _, vv := range v { |
||||
buf.WriteString(sep) |
||||
c.printValue(buf, vv) |
||||
sep = ", " |
||||
} |
||||
buf.WriteString(fin) |
||||
} |
||||
default: |
||||
fmt.Fprint(buf, v) |
||||
} |
||||
} |
||||
|
||||
func (c *loggingConn) print(method, commandName string, args []interface{}, reply interface{}, err error) { |
||||
var buf bytes.Buffer |
||||
fmt.Fprintf(&buf, "%s%s(", c.prefix, method) |
||||
if method != "Receive" { |
||||
buf.WriteString(commandName) |
||||
for _, arg := range args { |
||||
buf.WriteString(", ") |
||||
c.printValue(&buf, arg) |
||||
} |
||||
} |
||||
buf.WriteString(") -> (") |
||||
if method != "Send" { |
||||
c.printValue(&buf, reply) |
||||
buf.WriteString(", ") |
||||
} |
||||
fmt.Fprintf(&buf, "%v)", err) |
||||
c.logger.Output(3, buf.String()) |
||||
} |
||||
|
||||
func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, error) { |
||||
reply, err := c.Conn.Do(commandName, args...) |
||||
c.print("Do", commandName, args, reply, err) |
||||
return reply, err |
||||
} |
||||
|
||||
func (c *loggingConn) Send(commandName string, args ...interface{}) error { |
||||
err := c.Conn.Send(commandName, args...) |
||||
c.print("Send", commandName, args, nil, err) |
||||
return err |
||||
} |
||||
|
||||
func (c *loggingConn) Receive() (interface{}, error) { |
||||
reply, err := c.Conn.Receive() |
||||
c.print("Receive", "", nil, reply, err) |
||||
return reply, err |
||||
} |
@ -0,0 +1,36 @@ |
||||
package redis |
||||
|
||||
import ( |
||||
"context" |
||||
) |
||||
|
||||
// MockErr for unit test.
|
||||
type MockErr struct { |
||||
Error error |
||||
} |
||||
|
||||
// MockWith return a mock conn.
|
||||
func MockWith(err error) MockErr { |
||||
return MockErr{Error: err} |
||||
} |
||||
|
||||
// Err .
|
||||
func (m MockErr) Err() error { return m.Error } |
||||
|
||||
// Close .
|
||||
func (m MockErr) Close() error { return m.Error } |
||||
|
||||
// Do .
|
||||
func (m MockErr) Do(commandName string, args ...interface{}) (interface{}, error) { return nil, m.Error } |
||||
|
||||
// Send .
|
||||
func (m MockErr) Send(commandName string, args ...interface{}) error { return m.Error } |
||||
|
||||
// Flush .
|
||||
func (m MockErr) Flush() error { return m.Error } |
||||
|
||||
// Receive .
|
||||
func (m MockErr) Receive() (interface{}, error) { return nil, m.Error } |
||||
|
||||
// WithContext .
|
||||
func (m MockErr) WithContext(context.Context) Conn { return m } |
@ -0,0 +1,218 @@ |
||||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis |
||||
|
||||
import ( |
||||
"bytes" |
||||
"context" |
||||
"crypto/rand" |
||||
"crypto/sha1" |
||||
"errors" |
||||
"io" |
||||
"strconv" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/container/pool" |
||||
"github.com/bilibili/kratos/pkg/net/trace" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
) |
||||
|
||||
var beginTime, _ = time.Parse("2006-01-02 15:04:05", "2006-01-02 15:04:05") |
||||
|
||||
var ( |
||||
errConnClosed = errors.New("redigo: connection closed") |
||||
) |
||||
|
||||
// Pool .
|
||||
type Pool struct { |
||||
*pool.Slice |
||||
// config
|
||||
c *Config |
||||
} |
||||
|
||||
// Config client settings.
|
||||
type Config struct { |
||||
*pool.Config |
||||
|
||||
Name string // redis name, for trace
|
||||
Proto string |
||||
Addr string |
||||
Auth string |
||||
DialTimeout xtime.Duration |
||||
ReadTimeout xtime.Duration |
||||
WriteTimeout xtime.Duration |
||||
} |
||||
|
||||
// NewPool creates a new pool.
|
||||
func NewPool(c *Config, options ...DialOption) (p *Pool) { |
||||
if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 { |
||||
panic("must config redis timeout") |
||||
} |
||||
p1 := pool.NewSlice(c.Config) |
||||
cnop := DialConnectTimeout(time.Duration(c.DialTimeout)) |
||||
options = append(options, cnop) |
||||
rdop := DialReadTimeout(time.Duration(c.ReadTimeout)) |
||||
options = append(options, rdop) |
||||
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout)) |
||||
options = append(options, wrop) |
||||
auop := DialPassword(c.Auth) |
||||
options = append(options, auop) |
||||
// new pool
|
||||
p1.New = func(ctx context.Context) (io.Closer, error) { |
||||
conn, err := Dial(c.Proto, c.Addr, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return &traceConn{Conn: conn, connTags: []trace.Tag{trace.TagString(trace.TagPeerAddress, c.Addr)}}, nil |
||||
} |
||||
p = &Pool{Slice: p1, c: c} |
||||
return |
||||
} |
||||
|
||||
// Get gets a connection. The application must close the returned connection.
|
||||
// This method always returns a valid connection so that applications can defer
|
||||
// error handling to the first use of the connection. If there is an error
|
||||
// getting an underlying connection, then the connection Err, Do, Send, Flush
|
||||
// and Receive methods return that error.
|
||||
func (p *Pool) Get(ctx context.Context) Conn { |
||||
c, err := p.Slice.Get(ctx) |
||||
if err != nil { |
||||
return errorConnection{err} |
||||
} |
||||
c1, _ := c.(Conn) |
||||
return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx, now: beginTime} |
||||
} |
||||
|
||||
// Close releases the resources used by the pool.
|
||||
func (p *Pool) Close() error { |
||||
return p.Slice.Close() |
||||
} |
||||
|
||||
type pooledConnection struct { |
||||
p *Pool |
||||
c Conn |
||||
state int |
||||
|
||||
now time.Time |
||||
cmds []string |
||||
ctx context.Context |
||||
} |
||||
|
||||
var ( |
||||
sentinel []byte |
||||
sentinelOnce sync.Once |
||||
) |
||||
|
||||
func initSentinel() { |
||||
p := make([]byte, 64) |
||||
if _, err := rand.Read(p); err == nil { |
||||
sentinel = p |
||||
} else { |
||||
h := sha1.New() |
||||
io.WriteString(h, "Oops, rand failed. Use time instead.") |
||||
io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10)) |
||||
sentinel = h.Sum(nil) |
||||
} |
||||
} |
||||
|
||||
func (pc *pooledConnection) Close() error { |
||||
c := pc.c |
||||
if _, ok := c.(errorConnection); ok { |
||||
return nil |
||||
} |
||||
pc.c = errorConnection{errConnClosed} |
||||
|
||||
if pc.state&MultiState != 0 { |
||||
c.Send("DISCARD") |
||||
pc.state &^= (MultiState | WatchState) |
||||
} else if pc.state&WatchState != 0 { |
||||
c.Send("UNWATCH") |
||||
pc.state &^= WatchState |
||||
} |
||||
if pc.state&SubscribeState != 0 { |
||||
c.Send("UNSUBSCRIBE") |
||||
c.Send("PUNSUBSCRIBE") |
||||
// To detect the end of the message stream, ask the server to echo
|
||||
// a sentinel value and read until we see that value.
|
||||
sentinelOnce.Do(initSentinel) |
||||
c.Send("ECHO", sentinel) |
||||
c.Flush() |
||||
for { |
||||
p, err := c.Receive() |
||||
if err != nil { |
||||
break |
||||
} |
||||
if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) { |
||||
pc.state &^= SubscribeState |
||||
break |
||||
} |
||||
} |
||||
} |
||||
_, err := c.Do("") |
||||
pc.p.Slice.Put(context.Background(), c, pc.state != 0 || c.Err() != nil) |
||||
return err |
||||
} |
||||
|
||||
func (pc *pooledConnection) Err() error { |
||||
return pc.c.Err() |
||||
} |
||||
|
||||
func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) { |
||||
ci := LookupCommandInfo(commandName) |
||||
pc.state = (pc.state | ci.Set) &^ ci.Clear |
||||
reply, err = pc.c.Do(commandName, args...) |
||||
return |
||||
} |
||||
|
||||
func (pc *pooledConnection) Send(commandName string, args ...interface{}) (err error) { |
||||
ci := LookupCommandInfo(commandName) |
||||
pc.state = (pc.state | ci.Set) &^ ci.Clear |
||||
if pc.now.Equal(beginTime) { |
||||
// mark first send time
|
||||
pc.now = time.Now() |
||||
} |
||||
pc.cmds = append(pc.cmds, commandName) |
||||
return pc.c.Send(commandName, args...) |
||||
} |
||||
|
||||
func (pc *pooledConnection) Flush() error { |
||||
return pc.c.Flush() |
||||
} |
||||
|
||||
func (pc *pooledConnection) Receive() (reply interface{}, err error) { |
||||
reply, err = pc.c.Receive() |
||||
if len(pc.cmds) > 0 { |
||||
pc.cmds = pc.cmds[1:] |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (pc *pooledConnection) WithContext(ctx context.Context) Conn { |
||||
pc.ctx = ctx |
||||
return pc |
||||
} |
||||
|
||||
type errorConnection struct{ err error } |
||||
|
||||
func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) { |
||||
return nil, ec.err |
||||
} |
||||
func (ec errorConnection) Send(string, ...interface{}) error { return ec.err } |
||||
func (ec errorConnection) Err() error { return ec.err } |
||||
func (ec errorConnection) Close() error { return ec.err } |
||||
func (ec errorConnection) Flush() error { return ec.err } |
||||
func (ec errorConnection) Receive() (interface{}, error) { return nil, ec.err } |
||||
func (ec errorConnection) WithContext(context.Context) Conn { return ec } |
@ -0,0 +1,152 @@ |
||||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis |
||||
|
||||
import ( |
||||
"errors" |
||||
|
||||
pkgerr "github.com/pkg/errors" |
||||
) |
||||
|
||||
var ( |
||||
errPubSub = errors.New("redigo: unknown pubsub notification") |
||||
) |
||||
|
||||
// Subscription represents a subscribe or unsubscribe notification.
|
||||
type Subscription struct { |
||||
|
||||
// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe"
|
||||
Kind string |
||||
|
||||
// The channel that was changed.
|
||||
Channel string |
||||
|
||||
// The current number of subscriptions for connection.
|
||||
Count int |
||||
} |
||||
|
||||
// Message represents a message notification.
|
||||
type Message struct { |
||||
|
||||
// The originating channel.
|
||||
Channel string |
||||
|
||||
// The message data.
|
||||
Data []byte |
||||
} |
||||
|
||||
// PMessage represents a pmessage notification.
|
||||
type PMessage struct { |
||||
|
||||
// The matched pattern.
|
||||
Pattern string |
||||
|
||||
// The originating channel.
|
||||
Channel string |
||||
|
||||
// The message data.
|
||||
Data []byte |
||||
} |
||||
|
||||
// Pong represents a pubsub pong notification.
|
||||
type Pong struct { |
||||
Data string |
||||
} |
||||
|
||||
// PubSubConn wraps a Conn with convenience methods for subscribers.
|
||||
type PubSubConn struct { |
||||
Conn Conn |
||||
} |
||||
|
||||
// Close closes the connection.
|
||||
func (c PubSubConn) Close() error { |
||||
return c.Conn.Close() |
||||
} |
||||
|
||||
// Subscribe subscribes the connection to the specified channels.
|
||||
func (c PubSubConn) Subscribe(channel ...interface{}) error { |
||||
c.Conn.Send("SUBSCRIBE", channel...) |
||||
return c.Conn.Flush() |
||||
} |
||||
|
||||
// PSubscribe subscribes the connection to the given patterns.
|
||||
func (c PubSubConn) PSubscribe(channel ...interface{}) error { |
||||
c.Conn.Send("PSUBSCRIBE", channel...) |
||||
return c.Conn.Flush() |
||||
} |
||||
|
||||
// Unsubscribe unsubscribes the connection from the given channels, or from all
|
||||
// of them if none is given.
|
||||
func (c PubSubConn) Unsubscribe(channel ...interface{}) error { |
||||
c.Conn.Send("UNSUBSCRIBE", channel...) |
||||
return c.Conn.Flush() |
||||
} |
||||
|
||||
// PUnsubscribe unsubscribes the connection from the given patterns, or from all
|
||||
// of them if none is given.
|
||||
func (c PubSubConn) PUnsubscribe(channel ...interface{}) error { |
||||
c.Conn.Send("PUNSUBSCRIBE", channel...) |
||||
return c.Conn.Flush() |
||||
} |
||||
|
||||
// Ping sends a PING to the server with the specified data.
|
||||
func (c PubSubConn) Ping(data string) error { |
||||
c.Conn.Send("PING", data) |
||||
return c.Conn.Flush() |
||||
} |
||||
|
||||
// Receive returns a pushed message as a Subscription, Message, PMessage, Pong
|
||||
// or error. The return value is intended to be used directly in a type switch
|
||||
// as illustrated in the PubSubConn example.
|
||||
func (c PubSubConn) Receive() interface{} { |
||||
reply, err := Values(c.Conn.Receive()) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
var kind string |
||||
reply, err = Scan(reply, &kind) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
switch kind { |
||||
case "message": |
||||
var m Message |
||||
if _, err := Scan(reply, &m.Channel, &m.Data); err != nil { |
||||
return err |
||||
} |
||||
return m |
||||
case "pmessage": |
||||
var pm PMessage |
||||
if _, err := Scan(reply, &pm.Pattern, &pm.Channel, &pm.Data); err != nil { |
||||
return err |
||||
} |
||||
return pm |
||||
case "subscribe", "psubscribe", "unsubscribe", "punsubscribe": |
||||
s := Subscription{Kind: kind} |
||||
if _, err := Scan(reply, &s.Channel, &s.Count); err != nil { |
||||
return err |
||||
} |
||||
return s |
||||
case "pong": |
||||
var p Pong |
||||
if _, err := Scan(reply, &p.Data); err != nil { |
||||
return err |
||||
} |
||||
return p |
||||
} |
||||
return pkgerr.WithStack(errPubSub) |
||||
} |
@ -0,0 +1,51 @@ |
||||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis |
||||
|
||||
import ( |
||||
"context" |
||||
) |
||||
|
||||
// Error represents an error returned in a command reply.
|
||||
type Error string |
||||
|
||||
func (err Error) Error() string { return string(err) } |
||||
|
||||
// Conn represents a connection to a Redis server.
|
||||
type Conn interface { |
||||
// Close closes the connection.
|
||||
Close() error |
||||
|
||||
// Err returns a non-nil value if the connection is broken. The returned
|
||||
// value is either the first non-nil value returned from the underlying
|
||||
// network connection or a protocol parsing error. Applications should
|
||||
// close broken connections.
|
||||
Err() error |
||||
|
||||
// Do sends a command to the server and returns the received reply.
|
||||
Do(commandName string, args ...interface{}) (reply interface{}, err error) |
||||
|
||||
// Send writes the command to the client's output buffer.
|
||||
Send(commandName string, args ...interface{}) error |
||||
|
||||
// Flush flushes the output buffer to the Redis server.
|
||||
Flush() error |
||||
|
||||
// Receive receives a single reply from the Redis server
|
||||
Receive() (reply interface{}, err error) |
||||
|
||||
// WithContext
|
||||
WithContext(ctx context.Context) Conn |
||||
} |
@ -0,0 +1,409 @@ |
||||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis |
||||
|
||||
import ( |
||||
"errors" |
||||
"strconv" |
||||
|
||||
pkgerr "github.com/pkg/errors" |
||||
) |
||||
|
||||
// ErrNil indicates that a reply value is nil.
|
||||
var ErrNil = errors.New("redigo: nil returned") |
||||
|
||||
// Int is a helper that converts a command reply to an integer. If err is not
|
||||
// equal to nil, then Int returns 0, err. Otherwise, Int converts the
|
||||
// reply to an int as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// integer int(reply), nil
|
||||
// bulk string parsed reply, nil
|
||||
// nil 0, ErrNil
|
||||
// other 0, error
|
||||
func Int(reply interface{}, err error) (int, error) { |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
switch reply := reply.(type) { |
||||
case int64: |
||||
x := int(reply) |
||||
if int64(x) != reply { |
||||
return 0, pkgerr.WithStack(strconv.ErrRange) |
||||
} |
||||
return x, nil |
||||
case []byte: |
||||
n, err := strconv.ParseInt(string(reply), 10, 0) |
||||
return int(n), pkgerr.WithStack(err) |
||||
case nil: |
||||
return 0, ErrNil |
||||
case Error: |
||||
return 0, reply |
||||
} |
||||
return 0, pkgerr.Errorf("redigo: unexpected type for Int, got type %T", reply) |
||||
} |
||||
|
||||
// Int64 is a helper that converts a command reply to 64 bit integer. If err is
|
||||
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
|
||||
// reply to an int64 as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// integer reply, nil
|
||||
// bulk string parsed reply, nil
|
||||
// nil 0, ErrNil
|
||||
// other 0, error
|
||||
func Int64(reply interface{}, err error) (int64, error) { |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
switch reply := reply.(type) { |
||||
case int64: |
||||
return reply, nil |
||||
case []byte: |
||||
n, err := strconv.ParseInt(string(reply), 10, 64) |
||||
return n, pkgerr.WithStack(err) |
||||
case nil: |
||||
return 0, ErrNil |
||||
case Error: |
||||
return 0, reply |
||||
} |
||||
return 0, pkgerr.Errorf("redigo: unexpected type for Int64, got type %T", reply) |
||||
} |
||||
|
||||
var errNegativeInt = errors.New("redigo: unexpected value for Uint64") |
||||
|
||||
// Uint64 is a helper that converts a command reply to 64 bit integer. If err is
|
||||
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
|
||||
// reply to an int64 as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// integer reply, nil
|
||||
// bulk string parsed reply, nil
|
||||
// nil 0, ErrNil
|
||||
// other 0, error
|
||||
func Uint64(reply interface{}, err error) (uint64, error) { |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
switch reply := reply.(type) { |
||||
case int64: |
||||
if reply < 0 { |
||||
return 0, pkgerr.WithStack(errNegativeInt) |
||||
} |
||||
return uint64(reply), nil |
||||
case []byte: |
||||
n, err := strconv.ParseUint(string(reply), 10, 64) |
||||
return n, err |
||||
case nil: |
||||
return 0, ErrNil |
||||
case Error: |
||||
return 0, reply |
||||
} |
||||
return 0, pkgerr.Errorf("redigo: unexpected type for Uint64, got type %T", reply) |
||||
} |
||||
|
||||
// Float64 is a helper that converts a command reply to 64 bit float. If err is
|
||||
// not equal to nil, then Float64 returns 0, err. Otherwise, Float64 converts
|
||||
// the reply to an int as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// bulk string parsed reply, nil
|
||||
// nil 0, ErrNil
|
||||
// other 0, error
|
||||
func Float64(reply interface{}, err error) (float64, error) { |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
switch reply := reply.(type) { |
||||
case []byte: |
||||
n, err := strconv.ParseFloat(string(reply), 64) |
||||
return n, pkgerr.WithStack(err) |
||||
case nil: |
||||
return 0, ErrNil |
||||
case Error: |
||||
return 0, reply |
||||
} |
||||
return 0, pkgerr.Errorf("redigo: unexpected type for Float64, got type %T", reply) |
||||
} |
||||
|
||||
// String is a helper that converts a command reply to a string. If err is not
|
||||
// equal to nil, then String returns "", err. Otherwise String converts the
|
||||
// reply to a string as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// bulk string string(reply), nil
|
||||
// simple string reply, nil
|
||||
// nil "", ErrNil
|
||||
// other "", error
|
||||
func String(reply interface{}, err error) (string, error) { |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
switch reply := reply.(type) { |
||||
case []byte: |
||||
return string(reply), nil |
||||
case string: |
||||
return reply, nil |
||||
case nil: |
||||
return "", ErrNil |
||||
case Error: |
||||
return "", reply |
||||
} |
||||
return "", pkgerr.Errorf("redigo: unexpected type for String, got type %T", reply) |
||||
} |
||||
|
||||
// Bytes is a helper that converts a command reply to a slice of bytes. If err
|
||||
// is not equal to nil, then Bytes returns nil, err. Otherwise Bytes converts
|
||||
// the reply to a slice of bytes as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// bulk string reply, nil
|
||||
// simple string []byte(reply), nil
|
||||
// nil nil, ErrNil
|
||||
// other nil, error
|
||||
func Bytes(reply interface{}, err error) ([]byte, error) { |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
switch reply := reply.(type) { |
||||
case []byte: |
||||
return reply, nil |
||||
case string: |
||||
return []byte(reply), nil |
||||
case nil: |
||||
return nil, ErrNil |
||||
case Error: |
||||
return nil, reply |
||||
} |
||||
return nil, pkgerr.Errorf("redigo: unexpected type for Bytes, got type %T", reply) |
||||
} |
||||
|
||||
// Bool is a helper that converts a command reply to a boolean. If err is not
|
||||
// equal to nil, then Bool returns false, err. Otherwise Bool converts the
|
||||
// reply to boolean as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// integer value != 0, nil
|
||||
// bulk string strconv.ParseBool(reply)
|
||||
// nil false, ErrNil
|
||||
// other false, error
|
||||
func Bool(reply interface{}, err error) (bool, error) { |
||||
if err != nil { |
||||
return false, err |
||||
} |
||||
switch reply := reply.(type) { |
||||
case int64: |
||||
return reply != 0, nil |
||||
case []byte: |
||||
b, e := strconv.ParseBool(string(reply)) |
||||
return b, pkgerr.WithStack(e) |
||||
case nil: |
||||
return false, ErrNil |
||||
case Error: |
||||
return false, reply |
||||
} |
||||
return false, pkgerr.Errorf("redigo: unexpected type for Bool, got type %T", reply) |
||||
} |
||||
|
||||
// MultiBulk is a helper that converts an array command reply to a []interface{}.
|
||||
//
|
||||
// Deprecated: Use Values instead.
|
||||
func MultiBulk(reply interface{}, err error) ([]interface{}, error) { return Values(reply, err) } |
||||
|
||||
// Values is a helper that converts an array command reply to a []interface{}.
|
||||
// If err is not equal to nil, then Values returns nil, err. Otherwise, Values
|
||||
// converts the reply as follows:
|
||||
//
|
||||
// Reply type Result
|
||||
// array reply, nil
|
||||
// nil nil, ErrNil
|
||||
// other nil, error
|
||||
func Values(reply interface{}, err error) ([]interface{}, error) { |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
switch reply := reply.(type) { |
||||
case []interface{}: |
||||
return reply, nil |
||||
case nil: |
||||
return nil, ErrNil |
||||
case Error: |
||||
return nil, reply |
||||
} |
||||
return nil, pkgerr.Errorf("redigo: unexpected type for Values, got type %T", reply) |
||||
} |
||||
|
||||
// Strings is a helper that converts an array command reply to a []string. If
|
||||
// err is not equal to nil, then Strings returns nil, err. Nil array items are
|
||||
// converted to "" in the output slice. Strings returns an error if an array
|
||||
// item is not a bulk string or nil.
|
||||
func Strings(reply interface{}, err error) ([]string, error) { |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
switch reply := reply.(type) { |
||||
case []interface{}: |
||||
result := make([]string, len(reply)) |
||||
for i := range reply { |
||||
if reply[i] == nil { |
||||
continue |
||||
} |
||||
p, ok := reply[i].([]byte) |
||||
if !ok { |
||||
return nil, pkgerr.Errorf("redigo: unexpected element type for Strings, got type %T", reply[i]) |
||||
} |
||||
result[i] = string(p) |
||||
} |
||||
return result, nil |
||||
case nil: |
||||
return nil, ErrNil |
||||
case Error: |
||||
return nil, reply |
||||
} |
||||
return nil, pkgerr.Errorf("redigo: unexpected type for Strings, got type %T", reply) |
||||
} |
||||
|
||||
// ByteSlices is a helper that converts an array command reply to a [][]byte.
|
||||
// If err is not equal to nil, then ByteSlices returns nil, err. Nil array
|
||||
// items are stay nil. ByteSlices returns an error if an array item is not a
|
||||
// bulk string or nil.
|
||||
func ByteSlices(reply interface{}, err error) ([][]byte, error) { |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
switch reply := reply.(type) { |
||||
case []interface{}: |
||||
result := make([][]byte, len(reply)) |
||||
for i := range reply { |
||||
if reply[i] == nil { |
||||
continue |
||||
} |
||||
p, ok := reply[i].([]byte) |
||||
if !ok { |
||||
return nil, pkgerr.Errorf("redigo: unexpected element type for ByteSlices, got type %T", reply[i]) |
||||
} |
||||
result[i] = p |
||||
} |
||||
return result, nil |
||||
case nil: |
||||
return nil, ErrNil |
||||
case Error: |
||||
return nil, reply |
||||
} |
||||
return nil, pkgerr.Errorf("redigo: unexpected type for ByteSlices, got type %T", reply) |
||||
} |
||||
|
||||
// Ints is a helper that converts an array command reply to a []int. If
|
||||
// err is not equal to nil, then Ints returns nil, err.
|
||||
func Ints(reply interface{}, err error) ([]int, error) { |
||||
var ints []int |
||||
values, err := Values(reply, err) |
||||
if err != nil { |
||||
return ints, err |
||||
} |
||||
if err := ScanSlice(values, &ints); err != nil { |
||||
return ints, err |
||||
} |
||||
return ints, nil |
||||
} |
||||
|
||||
// Int64s is a helper that converts an array command reply to a []int64. If
|
||||
// err is not equal to nil, then Int64s returns nil, err.
|
||||
func Int64s(reply interface{}, err error) ([]int64, error) { |
||||
var int64s []int64 |
||||
values, err := Values(reply, err) |
||||
if err != nil { |
||||
return int64s, err |
||||
} |
||||
if err := ScanSlice(values, &int64s); err != nil { |
||||
return int64s, err |
||||
} |
||||
return int64s, nil |
||||
} |
||||
|
||||
// StringMap is a helper that converts an array of strings (alternating key, value)
|
||||
// into a map[string]string. The HGETALL and CONFIG GET commands return replies in this format.
|
||||
// Requires an even number of values in result.
|
||||
func StringMap(result interface{}, err error) (map[string]string, error) { |
||||
values, err := Values(result, err) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(values)%2 != 0 { |
||||
return nil, pkgerr.New("redigo: StringMap expects even number of values result") |
||||
} |
||||
m := make(map[string]string, len(values)/2) |
||||
for i := 0; i < len(values); i += 2 { |
||||
key, okKey := values[i].([]byte) |
||||
value, okValue := values[i+1].([]byte) |
||||
if !okKey || !okValue { |
||||
return nil, pkgerr.New("redigo: ScanMap key not a bulk string value") |
||||
} |
||||
m[string(key)] = string(value) |
||||
} |
||||
return m, nil |
||||
} |
||||
|
||||
// IntMap is a helper that converts an array of strings (alternating key, value)
|
||||
// into a map[string]int. The HGETALL commands return replies in this format.
|
||||
// Requires an even number of values in result.
|
||||
func IntMap(result interface{}, err error) (map[string]int, error) { |
||||
values, err := Values(result, err) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(values)%2 != 0 { |
||||
return nil, pkgerr.New("redigo: IntMap expects even number of values result") |
||||
} |
||||
m := make(map[string]int, len(values)/2) |
||||
for i := 0; i < len(values); i += 2 { |
||||
key, ok := values[i].([]byte) |
||||
if !ok { |
||||
return nil, pkgerr.New("redigo: ScanMap key not a bulk string value") |
||||
} |
||||
value, err := Int(values[i+1], nil) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
m[string(key)] = value |
||||
} |
||||
return m, nil |
||||
} |
||||
|
||||
// Int64Map is a helper that converts an array of strings (alternating key, value)
|
||||
// into a map[string]int64. The HGETALL commands return replies in this format.
|
||||
// Requires an even number of values in result.
|
||||
func Int64Map(result interface{}, err error) (map[string]int64, error) { |
||||
values, err := Values(result, err) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(values)%2 != 0 { |
||||
return nil, pkgerr.New("redigo: Int64Map expects even number of values result") |
||||
} |
||||
m := make(map[string]int64, len(values)/2) |
||||
for i := 0; i < len(values); i += 2 { |
||||
key, ok := values[i].([]byte) |
||||
if !ok { |
||||
return nil, pkgerr.New("redigo: ScanMap key not a bulk string value") |
||||
} |
||||
value, err := Int64(values[i+1], nil) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
m[string(key)] = value |
||||
} |
||||
return m, nil |
||||
} |
@ -0,0 +1,559 @@ |
||||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"reflect" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
|
||||
pkgerr "github.com/pkg/errors" |
||||
) |
||||
|
||||
func ensureLen(d reflect.Value, n int) { |
||||
if n > d.Cap() { |
||||
d.Set(reflect.MakeSlice(d.Type(), n, n)) |
||||
} else { |
||||
d.SetLen(n) |
||||
} |
||||
} |
||||
|
||||
func cannotConvert(d reflect.Value, s interface{}) error { |
||||
var sname string |
||||
switch s.(type) { |
||||
case string: |
||||
sname = "Redis simple string" |
||||
case Error: |
||||
sname = "Redis error" |
||||
case int64: |
||||
sname = "Redis integer" |
||||
case []byte: |
||||
sname = "Redis bulk string" |
||||
case []interface{}: |
||||
sname = "Redis array" |
||||
default: |
||||
sname = reflect.TypeOf(s).String() |
||||
} |
||||
return pkgerr.Errorf("cannot convert from %s to %s", sname, d.Type()) |
||||
} |
||||
|
||||
func convertAssignBulkString(d reflect.Value, s []byte) (err error) { |
||||
switch d.Type().Kind() { |
||||
case reflect.Float32, reflect.Float64: |
||||
var x float64 |
||||
x, err = strconv.ParseFloat(string(s), d.Type().Bits()) |
||||
d.SetFloat(x) |
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
||||
var x int64 |
||||
x, err = strconv.ParseInt(string(s), 10, d.Type().Bits()) |
||||
d.SetInt(x) |
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
||||
var x uint64 |
||||
x, err = strconv.ParseUint(string(s), 10, d.Type().Bits()) |
||||
d.SetUint(x) |
||||
case reflect.Bool: |
||||
var x bool |
||||
x, err = strconv.ParseBool(string(s)) |
||||
d.SetBool(x) |
||||
case reflect.String: |
||||
d.SetString(string(s)) |
||||
case reflect.Slice: |
||||
if d.Type().Elem().Kind() != reflect.Uint8 { |
||||
err = cannotConvert(d, s) |
||||
} else { |
||||
d.SetBytes(s) |
||||
} |
||||
default: |
||||
err = cannotConvert(d, s) |
||||
} |
||||
err = pkgerr.WithStack(err) |
||||
return |
||||
} |
||||
|
||||
func convertAssignInt(d reflect.Value, s int64) (err error) { |
||||
switch d.Type().Kind() { |
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
||||
d.SetInt(s) |
||||
if d.Int() != s { |
||||
err = strconv.ErrRange |
||||
d.SetInt(0) |
||||
} |
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
||||
if s < 0 { |
||||
err = strconv.ErrRange |
||||
} else { |
||||
x := uint64(s) |
||||
d.SetUint(x) |
||||
if d.Uint() != x { |
||||
err = strconv.ErrRange |
||||
d.SetUint(0) |
||||
} |
||||
} |
||||
case reflect.Bool: |
||||
d.SetBool(s != 0) |
||||
default: |
||||
err = cannotConvert(d, s) |
||||
} |
||||
err = pkgerr.WithStack(err) |
||||
return |
||||
} |
||||
|
||||
func convertAssignValue(d reflect.Value, s interface{}) (err error) { |
||||
switch s := s.(type) { |
||||
case []byte: |
||||
err = convertAssignBulkString(d, s) |
||||
case int64: |
||||
err = convertAssignInt(d, s) |
||||
default: |
||||
err = cannotConvert(d, s) |
||||
} |
||||
return err |
||||
} |
||||
|
||||
func convertAssignArray(d reflect.Value, s []interface{}) error { |
||||
if d.Type().Kind() != reflect.Slice { |
||||
return cannotConvert(d, s) |
||||
} |
||||
ensureLen(d, len(s)) |
||||
for i := 0; i < len(s); i++ { |
||||
if err := convertAssignValue(d.Index(i), s[i]); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func convertAssign(d interface{}, s interface{}) (err error) { |
||||
// Handle the most common destination types using type switches and
|
||||
// fall back to reflection for all other types.
|
||||
switch s := s.(type) { |
||||
case nil: |
||||
// ingore
|
||||
case []byte: |
||||
switch d := d.(type) { |
||||
case *string: |
||||
*d = string(s) |
||||
case *int: |
||||
*d, err = strconv.Atoi(string(s)) |
||||
case *bool: |
||||
*d, err = strconv.ParseBool(string(s)) |
||||
case *[]byte: |
||||
*d = s |
||||
case *interface{}: |
||||
*d = s |
||||
case nil: |
||||
// skip value
|
||||
default: |
||||
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr { |
||||
err = cannotConvert(d, s) |
||||
} else { |
||||
err = convertAssignBulkString(d.Elem(), s) |
||||
} |
||||
} |
||||
case int64: |
||||
switch d := d.(type) { |
||||
case *int: |
||||
x := int(s) |
||||
if int64(x) != s { |
||||
err = strconv.ErrRange |
||||
x = 0 |
||||
} |
||||
*d = x |
||||
case *bool: |
||||
*d = s != 0 |
||||
case *interface{}: |
||||
*d = s |
||||
case nil: |
||||
// skip value
|
||||
default: |
||||
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr { |
||||
err = cannotConvert(d, s) |
||||
} else { |
||||
err = convertAssignInt(d.Elem(), s) |
||||
} |
||||
} |
||||
case string: |
||||
switch d := d.(type) { |
||||
case *string: |
||||
*d = string(s) |
||||
default: |
||||
err = cannotConvert(reflect.ValueOf(d), s) |
||||
} |
||||
case []interface{}: |
||||
switch d := d.(type) { |
||||
case *[]interface{}: |
||||
*d = s |
||||
case *interface{}: |
||||
*d = s |
||||
case nil: |
||||
// skip value
|
||||
default: |
||||
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr { |
||||
err = cannotConvert(d, s) |
||||
} else { |
||||
err = convertAssignArray(d.Elem(), s) |
||||
} |
||||
} |
||||
case Error: |
||||
err = s |
||||
default: |
||||
err = cannotConvert(reflect.ValueOf(d), s) |
||||
} |
||||
err = pkgerr.WithStack(err) |
||||
return |
||||
} |
||||
|
||||
// Scan copies from src to the values pointed at by dest.
|
||||
//
|
||||
// The values pointed at by dest must be an integer, float, boolean, string,
|
||||
// []byte, interface{} or slices of these types. Scan uses the standard strconv
|
||||
// package to convert bulk strings to numeric and boolean types.
|
||||
//
|
||||
// If a dest value is nil, then the corresponding src value is skipped.
|
||||
//
|
||||
// If a src element is nil, then the corresponding dest value is not modified.
|
||||
//
|
||||
// To enable easy use of Scan in a loop, Scan returns the slice of src
|
||||
// following the copied values.
|
||||
func Scan(src []interface{}, dest ...interface{}) ([]interface{}, error) { |
||||
if len(src) < len(dest) { |
||||
return nil, pkgerr.New("redigo.Scan: array short") |
||||
} |
||||
var err error |
||||
for i, d := range dest { |
||||
err = convertAssign(d, src[i]) |
||||
if err != nil { |
||||
err = fmt.Errorf("redigo.Scan: cannot assign to dest %d: %v", i, err) |
||||
break |
||||
} |
||||
} |
||||
return src[len(dest):], err |
||||
} |
||||
|
||||
type fieldSpec struct { |
||||
name string |
||||
index []int |
||||
omitEmpty bool |
||||
} |
||||
|
||||
type structSpec struct { |
||||
m map[string]*fieldSpec |
||||
l []*fieldSpec |
||||
} |
||||
|
||||
func (ss *structSpec) fieldSpec(name []byte) *fieldSpec { |
||||
return ss.m[string(name)] |
||||
} |
||||
|
||||
func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) { |
||||
for i := 0; i < t.NumField(); i++ { |
||||
f := t.Field(i) |
||||
switch { |
||||
case f.PkgPath != "" && !f.Anonymous: |
||||
// Ignore unexported fields.
|
||||
case f.Anonymous: |
||||
// TODO: Handle pointers. Requires change to decoder and
|
||||
// protection against infinite recursion.
|
||||
if f.Type.Kind() == reflect.Struct { |
||||
compileStructSpec(f.Type, depth, append(index, i), ss) |
||||
} |
||||
default: |
||||
fs := &fieldSpec{name: f.Name} |
||||
tag := f.Tag.Get("redis") |
||||
p := strings.Split(tag, ",") |
||||
if len(p) > 0 { |
||||
if p[0] == "-" { |
||||
continue |
||||
} |
||||
if len(p[0]) > 0 { |
||||
fs.name = p[0] |
||||
} |
||||
for _, s := range p[1:] { |
||||
switch s { |
||||
case "omitempty": |
||||
fs.omitEmpty = true |
||||
default: |
||||
panic(fmt.Errorf("redigo: unknown field tag %s for type %s", s, t.Name())) |
||||
} |
||||
} |
||||
} |
||||
d, found := depth[fs.name] |
||||
if !found { |
||||
d = 1 << 30 |
||||
} |
||||
switch { |
||||
case len(index) == d: |
||||
// At same depth, remove from result.
|
||||
delete(ss.m, fs.name) |
||||
j := 0 |
||||
for i1 := 0; i1 < len(ss.l); i1++ { |
||||
if fs.name != ss.l[i1].name { |
||||
ss.l[j] = ss.l[i1] |
||||
j++ |
||||
} |
||||
} |
||||
ss.l = ss.l[:j] |
||||
case len(index) < d: |
||||
fs.index = make([]int, len(index)+1) |
||||
copy(fs.index, index) |
||||
fs.index[len(index)] = i |
||||
depth[fs.name] = len(index) |
||||
ss.m[fs.name] = fs |
||||
ss.l = append(ss.l, fs) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
var ( |
||||
structSpecMutex sync.RWMutex |
||||
structSpecCache = make(map[reflect.Type]*structSpec) |
||||
) |
||||
|
||||
func structSpecForType(t reflect.Type) *structSpec { |
||||
|
||||
structSpecMutex.RLock() |
||||
ss, found := structSpecCache[t] |
||||
structSpecMutex.RUnlock() |
||||
if found { |
||||
return ss |
||||
} |
||||
|
||||
structSpecMutex.Lock() |
||||
defer structSpecMutex.Unlock() |
||||
ss, found = structSpecCache[t] |
||||
if found { |
||||
return ss |
||||
} |
||||
|
||||
ss = &structSpec{m: make(map[string]*fieldSpec)} |
||||
compileStructSpec(t, make(map[string]int), nil, ss) |
||||
structSpecCache[t] = ss |
||||
return ss |
||||
} |
||||
|
||||
var errScanStructValue = errors.New("redigo.ScanStruct: value must be non-nil pointer to a struct") |
||||
|
||||
// ScanStruct scans alternating names and values from src to a struct. The
|
||||
// HGETALL and CONFIG GET commands return replies in this format.
|
||||
//
|
||||
// ScanStruct uses exported field names to match values in the response. Use
|
||||
// 'redis' field tag to override the name:
|
||||
//
|
||||
// Field int `redis:"myName"`
|
||||
//
|
||||
// Fields with the tag redis:"-" are ignored.
|
||||
//
|
||||
// Integer, float, boolean, string and []byte fields are supported. Scan uses the
|
||||
// standard strconv package to convert bulk string values to numeric and
|
||||
// boolean types.
|
||||
//
|
||||
// If a src element is nil, then the corresponding field is not modified.
|
||||
func ScanStruct(src []interface{}, dest interface{}) error { |
||||
d := reflect.ValueOf(dest) |
||||
if d.Kind() != reflect.Ptr || d.IsNil() { |
||||
return pkgerr.WithStack(errScanStructValue) |
||||
} |
||||
d = d.Elem() |
||||
if d.Kind() != reflect.Struct { |
||||
return pkgerr.WithStack(errScanStructValue) |
||||
} |
||||
ss := structSpecForType(d.Type()) |
||||
|
||||
if len(src)%2 != 0 { |
||||
return pkgerr.New("redigo.ScanStruct: number of values not a multiple of 2") |
||||
} |
||||
|
||||
for i := 0; i < len(src); i += 2 { |
||||
s := src[i+1] |
||||
if s == nil { |
||||
continue |
||||
} |
||||
name, ok := src[i].([]byte) |
||||
if !ok { |
||||
return pkgerr.Errorf("redigo.ScanStruct: key %d not a bulk string value", i) |
||||
} |
||||
fs := ss.fieldSpec(name) |
||||
if fs == nil { |
||||
continue |
||||
} |
||||
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil { |
||||
return pkgerr.Errorf("redigo.ScanStruct: cannot assign field %s: %v", fs.name, err) |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
var ( |
||||
errScanSliceValue = errors.New("redigo.ScanSlice: dest must be non-nil pointer to a struct") |
||||
) |
||||
|
||||
// ScanSlice scans src to the slice pointed to by dest. The elements the dest
|
||||
// slice must be integer, float, boolean, string, struct or pointer to struct
|
||||
// values.
|
||||
//
|
||||
// Struct fields must be integer, float, boolean or string values. All struct
|
||||
// fields are used unless a subset is specified using fieldNames.
|
||||
func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error { |
||||
d := reflect.ValueOf(dest) |
||||
if d.Kind() != reflect.Ptr || d.IsNil() { |
||||
return pkgerr.WithStack(errScanSliceValue) |
||||
} |
||||
d = d.Elem() |
||||
if d.Kind() != reflect.Slice { |
||||
return pkgerr.WithStack(errScanSliceValue) |
||||
} |
||||
|
||||
isPtr := false |
||||
t := d.Type().Elem() |
||||
if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct { |
||||
isPtr = true |
||||
t = t.Elem() |
||||
} |
||||
|
||||
if t.Kind() != reflect.Struct { |
||||
ensureLen(d, len(src)) |
||||
for i, s := range src { |
||||
if s == nil { |
||||
continue |
||||
} |
||||
if err := convertAssignValue(d.Index(i), s); err != nil { |
||||
return pkgerr.Errorf("redigo.ScanSlice: cannot assign element %d: %v", i, err) |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
ss := structSpecForType(t) |
||||
fss := ss.l |
||||
if len(fieldNames) > 0 { |
||||
fss = make([]*fieldSpec, len(fieldNames)) |
||||
for i, name := range fieldNames { |
||||
fss[i] = ss.m[name] |
||||
if fss[i] == nil { |
||||
return pkgerr.Errorf("redigo.ScanSlice: ScanSlice bad field name %s", name) |
||||
} |
||||
} |
||||
} |
||||
|
||||
if len(fss) == 0 { |
||||
return pkgerr.New("redigo.ScanSlice: no struct fields") |
||||
} |
||||
|
||||
n := len(src) / len(fss) |
||||
if n*len(fss) != len(src) { |
||||
return pkgerr.New("redigo.ScanSlice: length not a multiple of struct field count") |
||||
} |
||||
|
||||
ensureLen(d, n) |
||||
for i := 0; i < n; i++ { |
||||
d1 := d.Index(i) |
||||
if isPtr { |
||||
if d1.IsNil() { |
||||
d1.Set(reflect.New(t)) |
||||
} |
||||
d1 = d1.Elem() |
||||
} |
||||
for j, fs := range fss { |
||||
s := src[i*len(fss)+j] |
||||
if s == nil { |
||||
continue |
||||
} |
||||
if err := convertAssignValue(d1.FieldByIndex(fs.index), s); err != nil { |
||||
return pkgerr.Errorf("redigo.ScanSlice: cannot assign element %d to field %s: %v", i*len(fss)+j, fs.name, err) |
||||
} |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Args is a helper for constructing command arguments from structured values.
|
||||
type Args []interface{} |
||||
|
||||
// Add returns the result of appending value to args.
|
||||
func (args Args) Add(value ...interface{}) Args { |
||||
return append(args, value...) |
||||
} |
||||
|
||||
// AddFlat returns the result of appending the flattened value of v to args.
|
||||
//
|
||||
// Maps are flattened by appending the alternating keys and map values to args.
|
||||
//
|
||||
// Slices are flattened by appending the slice elements to args.
|
||||
//
|
||||
// Structs are flattened by appending the alternating names and values of
|
||||
// exported fields to args. If v is a nil struct pointer, then nothing is
|
||||
// appended. The 'redis' field tag overrides struct field names. See ScanStruct
|
||||
// for more information on the use of the 'redis' field tag.
|
||||
//
|
||||
// Other types are appended to args as is.
|
||||
func (args Args) AddFlat(v interface{}) Args { |
||||
rv := reflect.ValueOf(v) |
||||
switch rv.Kind() { |
||||
case reflect.Struct: |
||||
args = flattenStruct(args, rv) |
||||
case reflect.Slice: |
||||
for i := 0; i < rv.Len(); i++ { |
||||
args = append(args, rv.Index(i).Interface()) |
||||
} |
||||
case reflect.Map: |
||||
for _, k := range rv.MapKeys() { |
||||
args = append(args, k.Interface(), rv.MapIndex(k).Interface()) |
||||
} |
||||
case reflect.Ptr: |
||||
if rv.Type().Elem().Kind() == reflect.Struct { |
||||
if !rv.IsNil() { |
||||
args = flattenStruct(args, rv.Elem()) |
||||
} |
||||
} else { |
||||
args = append(args, v) |
||||
} |
||||
default: |
||||
args = append(args, v) |
||||
} |
||||
return args |
||||
} |
||||
|
||||
func flattenStruct(args Args, v reflect.Value) Args { |
||||
ss := structSpecForType(v.Type()) |
||||
for _, fs := range ss.l { |
||||
fv := v.FieldByIndex(fs.index) |
||||
if fs.omitEmpty { |
||||
var empty = false |
||||
switch fv.Kind() { |
||||
case reflect.Array, reflect.Map, reflect.Slice, reflect.String: |
||||
empty = fv.Len() == 0 |
||||
case reflect.Bool: |
||||
empty = !fv.Bool() |
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
||||
empty = fv.Int() == 0 |
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: |
||||
empty = fv.Uint() == 0 |
||||
case reflect.Float32, reflect.Float64: |
||||
empty = fv.Float() == 0 |
||||
case reflect.Interface, reflect.Ptr: |
||||
empty = fv.IsNil() |
||||
} |
||||
if empty { |
||||
continue |
||||
} |
||||
} |
||||
args = append(args, fs.name, fv.Interface()) |
||||
} |
||||
return args |
||||
} |
@ -0,0 +1,86 @@ |
||||
// Copyright 2012 Gary Burd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package redis |
||||
|
||||
import ( |
||||
"crypto/sha1" |
||||
"encoding/hex" |
||||
"io" |
||||
"strings" |
||||
) |
||||
|
||||
// Script encapsulates the source, hash and key count for a Lua script. See
|
||||
// http://redis.io/commands/eval for information on scripts in Redis.
|
||||
type Script struct { |
||||
keyCount int |
||||
src string |
||||
hash string |
||||
} |
||||
|
||||
// NewScript returns a new script object. If keyCount is greater than or equal
|
||||
// to zero, then the count is automatically inserted in the EVAL command
|
||||
// argument list. If keyCount is less than zero, then the application supplies
|
||||
// the count as the first value in the keysAndArgs argument to the Do, Send and
|
||||
// SendHash methods.
|
||||
func NewScript(keyCount int, src string) *Script { |
||||
h := sha1.New() |
||||
io.WriteString(h, src) |
||||
return &Script{keyCount, src, hex.EncodeToString(h.Sum(nil))} |
||||
} |
||||
|
||||
func (s *Script) args(spec string, keysAndArgs []interface{}) []interface{} { |
||||
var args []interface{} |
||||
if s.keyCount < 0 { |
||||
args = make([]interface{}, 1+len(keysAndArgs)) |
||||
args[0] = spec |
||||
copy(args[1:], keysAndArgs) |
||||
} else { |
||||
args = make([]interface{}, 2+len(keysAndArgs)) |
||||
args[0] = spec |
||||
args[1] = s.keyCount |
||||
copy(args[2:], keysAndArgs) |
||||
} |
||||
return args |
||||
} |
||||
|
||||
// Do evaluates the script. Under the covers, Do optimistically evaluates the
|
||||
// script using the EVALSHA command. If the command fails because the script is
|
||||
// not loaded, then Do evaluates the script using the EVAL command (thus
|
||||
// causing the script to load).
|
||||
func (s *Script) Do(c Conn, keysAndArgs ...interface{}) (interface{}, error) { |
||||
v, err := c.Do("EVALSHA", s.args(s.hash, keysAndArgs)...) |
||||
if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") { |
||||
v, err = c.Do("EVAL", s.args(s.src, keysAndArgs)...) |
||||
} |
||||
return v, err |
||||
} |
||||
|
||||
// SendHash evaluates the script without waiting for the reply. The script is
|
||||
// evaluated with the EVALSHA command. The application must ensure that the
|
||||
// script is loaded by a previous call to Send, Do or Load methods.
|
||||
func (s *Script) SendHash(c Conn, keysAndArgs ...interface{}) error { |
||||
return c.Send("EVALSHA", s.args(s.hash, keysAndArgs)...) |
||||
} |
||||
|
||||
// Send evaluates the script without waiting for the reply.
|
||||
func (s *Script) Send(c Conn, keysAndArgs ...interface{}) error { |
||||
return c.Send("EVAL", s.args(s.src, keysAndArgs)...) |
||||
} |
||||
|
||||
// Load loads the script without evaluating it.
|
||||
func (s *Script) Load(c Conn) error { |
||||
_, err := c.Do("SCRIPT", "LOAD", s.src) |
||||
return err |
||||
} |
@ -0,0 +1,142 @@ |
||||
package redis |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/net/trace" |
||||
) |
||||
|
||||
const ( |
||||
_traceComponentName = "pkg/cache/redis" |
||||
_tracePeerService = "redis" |
||||
_traceSpanKind = "client" |
||||
_slowLogDuration = time.Millisecond * 250 |
||||
) |
||||
|
||||
var _internalTags = []trace.Tag{ |
||||
trace.TagString(trace.TagSpanKind, _traceSpanKind), |
||||
trace.TagString(trace.TagComponent, _traceComponentName), |
||||
trace.TagString(trace.TagPeerService, _tracePeerService), |
||||
} |
||||
|
||||
type traceConn struct { |
||||
// tr for pipeline, if tr != nil meaning on pipeline
|
||||
tr trace.Trace |
||||
ctx context.Context |
||||
// connTag include e.g. ip,port
|
||||
connTags []trace.Tag |
||||
|
||||
// origin redis conn
|
||||
Conn |
||||
pending int |
||||
} |
||||
|
||||
func (t *traceConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) { |
||||
statement := getStatement(commandName, args...) |
||||
defer slowLog(statement, time.Now()) |
||||
root, ok := trace.FromContext(t.ctx) |
||||
// NOTE: ignored empty commandName
|
||||
// current sdk will Do empty command after pipeline finished
|
||||
if !ok || commandName == "" { |
||||
return t.Conn.Do(commandName, args...) |
||||
} |
||||
tr := root.Fork("", "Redis:"+commandName) |
||||
tr.SetTag(_internalTags...) |
||||
tr.SetTag(t.connTags...) |
||||
tr.SetTag(trace.TagString(trace.TagDBStatement, statement)) |
||||
reply, err = t.Conn.Do(commandName, args...) |
||||
tr.Finish(&err) |
||||
return |
||||
} |
||||
|
||||
func (t *traceConn) Send(commandName string, args ...interface{}) error { |
||||
statement := getStatement(commandName, args...) |
||||
defer slowLog(statement, time.Now()) |
||||
t.pending++ |
||||
root, ok := trace.FromContext(t.ctx) |
||||
if !ok { |
||||
return t.Conn.Send(commandName, args...) |
||||
} |
||||
if t.tr == nil { |
||||
t.tr = root.Fork("", "Redis:Pipeline") |
||||
t.tr.SetTag(_internalTags...) |
||||
t.tr.SetTag(t.connTags...) |
||||
} |
||||
t.tr.SetLog( |
||||
trace.Log(trace.LogEvent, "Send"), |
||||
trace.Log("db.statement", statement), |
||||
) |
||||
err := t.Conn.Send(commandName, args...) |
||||
if err != nil { |
||||
t.tr.SetTag(trace.TagBool(trace.TagError, true)) |
||||
t.tr.SetLog( |
||||
trace.Log(trace.LogEvent, "Send Fail"), |
||||
trace.Log(trace.LogMessage, err.Error()), |
||||
) |
||||
} |
||||
return err |
||||
} |
||||
|
||||
func (t *traceConn) Flush() error { |
||||
defer slowLog("Flush", time.Now()) |
||||
if t.tr == nil { |
||||
return t.Conn.Flush() |
||||
} |
||||
t.tr.SetLog(trace.Log(trace.LogEvent, "Flush")) |
||||
err := t.Conn.Flush() |
||||
if err != nil { |
||||
t.tr.SetTag(trace.TagBool(trace.TagError, true)) |
||||
t.tr.SetLog( |
||||
trace.Log(trace.LogEvent, "Flush Fail"), |
||||
trace.Log(trace.LogMessage, err.Error()), |
||||
) |
||||
} |
||||
return err |
||||
} |
||||
|
||||
func (t *traceConn) Receive() (reply interface{}, err error) { |
||||
defer slowLog("Receive", time.Now()) |
||||
if t.tr == nil { |
||||
return t.Conn.Receive() |
||||
} |
||||
t.tr.SetLog(trace.Log(trace.LogEvent, "Receive")) |
||||
reply, err = t.Conn.Receive() |
||||
if err != nil { |
||||
t.tr.SetTag(trace.TagBool(trace.TagError, true)) |
||||
t.tr.SetLog( |
||||
trace.Log(trace.LogEvent, "Receive Fail"), |
||||
trace.Log(trace.LogMessage, err.Error()), |
||||
) |
||||
} |
||||
if t.pending > 0 { |
||||
t.pending-- |
||||
} |
||||
if t.pending == 0 { |
||||
t.tr.Finish(nil) |
||||
t.tr = nil |
||||
} |
||||
return reply, err |
||||
} |
||||
|
||||
func (t *traceConn) WithContext(ctx context.Context) Conn { |
||||
t.ctx = ctx |
||||
return t |
||||
} |
||||
|
||||
func slowLog(statement string, now time.Time) { |
||||
du := time.Since(now) |
||||
if du > _slowLogDuration { |
||||
log.Warn("%s slow log statement: %s time: %v", _tracePeerService, statement, du) |
||||
} |
||||
} |
||||
|
||||
func getStatement(commandName string, args ...interface{}) (res string) { |
||||
res = commandName |
||||
if len(args) > 0 { |
||||
res = fmt.Sprintf("%s %v", commandName, args[0]) |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,40 @@ |
||||
### database/hbase |
||||
|
||||
### 项目简介 |
||||
|
||||
Hbase Client,进行封装加入了链路追踪和统计。 |
||||
|
||||
### usage |
||||
```go |
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
|
||||
"github.com/bilibili/kratos/pkg/database/hbase" |
||||
) |
||||
|
||||
func main() { |
||||
config := &hbase.Config{Zookeeper: &hbase.ZKConfig{Addrs: []string{"localhost"}}} |
||||
client := hbase.NewClient(config) |
||||
|
||||
values := map[string]map[string][]byte{"name": {"firstname": []byte("hello"), "lastname": []byte("world")}} |
||||
ctx := context.Background() |
||||
|
||||
_, err := client.PutStr(ctx, "user", "user1", values) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
|
||||
result, err := client.GetStr(ctx, "user", "user1") |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
fmt.Printf("%v", result) |
||||
} |
||||
``` |
||||
|
||||
##### 依赖包 |
||||
|
||||
1.[gohbase](https://github.com/tsuna/gohbase) |
@ -0,0 +1,23 @@ |
||||
package hbase |
||||
|
||||
import ( |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
) |
||||
|
||||
// ZKConfig Server&Client settings.
|
||||
type ZKConfig struct { |
||||
Root string |
||||
Addrs []string |
||||
Timeout xtime.Duration |
||||
} |
||||
|
||||
// Config hbase config
|
||||
type Config struct { |
||||
Zookeeper *ZKConfig |
||||
RPCQueueSize int |
||||
FlushInterval xtime.Duration |
||||
EffectiveUser string |
||||
RegionLookupTimeout xtime.Duration |
||||
RegionReadTimeout xtime.Duration |
||||
TestRowKey string |
||||
} |
@ -0,0 +1,297 @@ |
||||
package hbase |
||||
|
||||
import ( |
||||
"context" |
||||
"io" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/tsuna/gohbase" |
||||
"github.com/tsuna/gohbase/hrpc" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
) |
||||
|
||||
// HookFunc hook function call before every method and hook return function will call after finish.
|
||||
type HookFunc func(ctx context.Context, call hrpc.Call, customName string) func(err error) |
||||
|
||||
// Client hbase client.
|
||||
type Client struct { |
||||
hc gohbase.Client |
||||
addr string |
||||
config *Config |
||||
hooks []HookFunc |
||||
} |
||||
|
||||
// AddHook add hook function.
|
||||
func (c *Client) AddHook(hookFn HookFunc) { |
||||
c.hooks = append(c.hooks, hookFn) |
||||
} |
||||
|
||||
func (c *Client) invokeHook(ctx context.Context, call hrpc.Call, customName string) func(error) { |
||||
finishHooks := make([]func(error), 0, len(c.hooks)) |
||||
for _, fn := range c.hooks { |
||||
finishHooks = append(finishHooks, fn(ctx, call, customName)) |
||||
} |
||||
return func(err error) { |
||||
for _, fn := range finishHooks { |
||||
fn(err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// NewClient new a hbase client.
|
||||
func NewClient(config *Config, options ...gohbase.Option) *Client { |
||||
rawcli := NewRawClient(config, options...) |
||||
rawcli.AddHook(NewSlowLogHook(250 * time.Millisecond)) |
||||
rawcli.AddHook(MetricsHook(nil)) |
||||
rawcli.AddHook(TraceHook("database/hbase", strings.Join(config.Zookeeper.Addrs, ","))) |
||||
return rawcli |
||||
} |
||||
|
||||
// NewRawClient new a hbase client without prometheus metrics and dapper trace hook.
|
||||
func NewRawClient(config *Config, options ...gohbase.Option) *Client { |
||||
zkquorum := strings.Join(config.Zookeeper.Addrs, ",") |
||||
if config.Zookeeper.Root != "" { |
||||
options = append(options, gohbase.ZookeeperRoot(config.Zookeeper.Root)) |
||||
} |
||||
if config.Zookeeper.Timeout != 0 { |
||||
options = append(options, gohbase.ZookeeperTimeout(time.Duration(config.Zookeeper.Timeout))) |
||||
} |
||||
|
||||
if config.RPCQueueSize != 0 { |
||||
log.Warn("RPCQueueSize configuration be ignored") |
||||
} |
||||
// force RpcQueueSize = 1, don't change it !!! it has reason (゜-゜)つロ
|
||||
options = append(options, gohbase.RpcQueueSize(1)) |
||||
|
||||
if config.FlushInterval != 0 { |
||||
options = append(options, gohbase.FlushInterval(time.Duration(config.FlushInterval))) |
||||
} |
||||
if config.EffectiveUser != "" { |
||||
options = append(options, gohbase.EffectiveUser(config.EffectiveUser)) |
||||
} |
||||
if config.RegionLookupTimeout != 0 { |
||||
options = append(options, gohbase.RegionLookupTimeout(time.Duration(config.RegionLookupTimeout))) |
||||
} |
||||
if config.RegionReadTimeout != 0 { |
||||
options = append(options, gohbase.RegionReadTimeout(time.Duration(config.RegionReadTimeout))) |
||||
} |
||||
hc := gohbase.NewClient(zkquorum, options...) |
||||
return &Client{ |
||||
hc: hc, |
||||
addr: zkquorum, |
||||
config: config, |
||||
} |
||||
} |
||||
|
||||
// ScanAll do scan command and return all result
|
||||
// NOTE: if err != nil the results is safe for range operate even not result found
|
||||
func (c *Client) ScanAll(ctx context.Context, table []byte, options ...func(hrpc.Call) error) (results []*hrpc.Result, err error) { |
||||
cursor, err := c.Scan(ctx, table, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
for { |
||||
result, err := cursor.Next() |
||||
if err != nil { |
||||
if err == io.EOF { |
||||
break |
||||
} |
||||
return nil, err |
||||
} |
||||
results = append(results, result) |
||||
} |
||||
return results, nil |
||||
} |
||||
|
||||
type scanTrace struct { |
||||
hrpc.Scanner |
||||
finishHook func(error) |
||||
} |
||||
|
||||
func (s *scanTrace) Next() (*hrpc.Result, error) { |
||||
result, err := s.Scanner.Next() |
||||
if err != nil { |
||||
s.finishHook(err) |
||||
} |
||||
return result, err |
||||
} |
||||
|
||||
func (s *scanTrace) Close() error { |
||||
err := s.Scanner.Close() |
||||
s.finishHook(err) |
||||
return err |
||||
} |
||||
|
||||
// Scan do a scan command.
|
||||
func (c *Client) Scan(ctx context.Context, table []byte, options ...func(hrpc.Call) error) (scanner hrpc.Scanner, err error) { |
||||
var scan *hrpc.Scan |
||||
scan, err = hrpc.NewScan(ctx, table, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
st := &scanTrace{} |
||||
st.finishHook = c.invokeHook(ctx, scan, "Scan") |
||||
st.Scanner = c.hc.Scan(scan) |
||||
return st, nil |
||||
} |
||||
|
||||
// ScanStr scan string
|
||||
func (c *Client) ScanStr(ctx context.Context, table string, options ...func(hrpc.Call) error) (hrpc.Scanner, error) { |
||||
return c.Scan(ctx, []byte(table), options...) |
||||
} |
||||
|
||||
// ScanStrAll scan string
|
||||
// NOTE: if err != nil the results is safe for range operate even not result found
|
||||
func (c *Client) ScanStrAll(ctx context.Context, table string, options ...func(hrpc.Call) error) ([]*hrpc.Result, error) { |
||||
return c.ScanAll(ctx, []byte(table), options...) |
||||
} |
||||
|
||||
// ScanRange get a scanner for the given table and key range.
|
||||
// The range is half-open, i.e. [startRow; stopRow[ -- stopRow is not
|
||||
// included in the range.
|
||||
func (c *Client) ScanRange(ctx context.Context, table, startRow, stopRow []byte, options ...func(hrpc.Call) error) (scanner hrpc.Scanner, err error) { |
||||
var scan *hrpc.Scan |
||||
scan, err = hrpc.NewScanRange(ctx, table, startRow, stopRow, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
st := &scanTrace{} |
||||
st.finishHook = c.invokeHook(ctx, scan, "ScanRange") |
||||
st.Scanner = c.hc.Scan(scan) |
||||
return st, nil |
||||
} |
||||
|
||||
// ScanRangeStr get a scanner for the given table and key range.
|
||||
// The range is half-open, i.e. [startRow; stopRow[ -- stopRow is not
|
||||
// included in the range.
|
||||
func (c *Client) ScanRangeStr(ctx context.Context, table, startRow, stopRow string, options ...func(hrpc.Call) error) (hrpc.Scanner, error) { |
||||
return c.ScanRange(ctx, []byte(table), []byte(startRow), []byte(stopRow), options...) |
||||
} |
||||
|
||||
// Get get result for the given table and row key.
|
||||
// NOTE: if err != nil then result != nil, if result not exists result.Cells length is 0
|
||||
func (c *Client) Get(ctx context.Context, table, key []byte, options ...func(hrpc.Call) error) (result *hrpc.Result, err error) { |
||||
var get *hrpc.Get |
||||
get, err = hrpc.NewGet(ctx, table, key, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
finishHook := c.invokeHook(ctx, get, "GET") |
||||
result, err = c.hc.Get(get) |
||||
finishHook(err) |
||||
return |
||||
} |
||||
|
||||
// GetStr do a get command.
|
||||
// NOTE: if err != nil then result != nil, if result not exists result.Cells length is 0
|
||||
func (c *Client) GetStr(ctx context.Context, table, key string, options ...func(hrpc.Call) error) (result *hrpc.Result, err error) { |
||||
return c.Get(ctx, []byte(table), []byte(key), options...) |
||||
} |
||||
|
||||
// PutStr insert the given family-column-values in the given row key of the given table.
|
||||
func (c *Client) PutStr(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) { |
||||
put, err := hrpc.NewPutStr(ctx, table, key, values, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
finishHook := c.invokeHook(ctx, put, "PUT") |
||||
result, err := c.hc.Put(put) |
||||
finishHook(err) |
||||
return result, err |
||||
} |
||||
|
||||
// Delete is used to perform Delete operations on a single row.
|
||||
// To delete entire row, values should be nil.
|
||||
//
|
||||
// To delete specific families, qualifiers map should be nil:
|
||||
// map[string]map[string][]byte{
|
||||
// "cf1": nil,
|
||||
// "cf2": nil,
|
||||
// }
|
||||
//
|
||||
// To delete specific qualifiers:
|
||||
// map[string]map[string][]byte{
|
||||
// "cf": map[string][]byte{
|
||||
// "q1": nil,
|
||||
// "q2": nil,
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// To delete all versions before and at a timestamp, pass hrpc.Timestamp() option.
|
||||
// By default all versions will be removed.
|
||||
//
|
||||
// To delete only a specific version at a timestamp, pass hrpc.DeleteOneVersion() option
|
||||
// along with a timestamp. For delete specific qualifiers request, if timestamp is not
|
||||
// passed, only the latest version will be removed. For delete specific families request,
|
||||
// the timestamp should be passed or it will have no effect as it's an expensive
|
||||
// operation to perform.
|
||||
func (c *Client) Delete(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) { |
||||
del, err := hrpc.NewDelStr(ctx, table, key, values, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
finishHook := c.invokeHook(ctx, del, "Delete") |
||||
result, err := c.hc.Delete(del) |
||||
finishHook(err) |
||||
return result, err |
||||
} |
||||
|
||||
// Append do a append command.
|
||||
func (c *Client) Append(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) { |
||||
appd, err := hrpc.NewAppStr(ctx, table, key, values, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
finishHook := c.invokeHook(ctx, appd, "Append") |
||||
result, err := c.hc.Append(appd) |
||||
finishHook(err) |
||||
return result, err |
||||
} |
||||
|
||||
// Increment the given values in HBase under the given table and key.
|
||||
func (c *Client) Increment(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (int64, error) { |
||||
increment, err := hrpc.NewIncStr(ctx, table, key, values, options...) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
finishHook := c.invokeHook(ctx, increment, "Increment") |
||||
result, err := c.hc.Increment(increment) |
||||
finishHook(err) |
||||
return result, err |
||||
} |
||||
|
||||
// IncrementSingle increment the given value by amount in HBase under the given table, key, family and qualifier.
|
||||
func (c *Client) IncrementSingle(ctx context.Context, table string, key string, family string, qualifier string, amount int64, options ...func(hrpc.Call) error) (int64, error) { |
||||
increment, err := hrpc.NewIncStrSingle(ctx, table, key, family, qualifier, amount, options...) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
|
||||
finishHook := c.invokeHook(ctx, increment, "IncrementSingle") |
||||
result, err := c.hc.Increment(increment) |
||||
finishHook(err) |
||||
return result, err |
||||
} |
||||
|
||||
// Ping ping.
|
||||
func (c *Client) Ping(ctx context.Context) (err error) { |
||||
testRowKey := "test" |
||||
if c.config.TestRowKey != "" { |
||||
testRowKey = c.config.TestRowKey |
||||
} |
||||
values := map[string]map[string][]byte{"test": map[string][]byte{"test": []byte("test")}} |
||||
_, err = c.PutStr(ctx, "test", testRowKey, values) |
||||
return |
||||
} |
||||
|
||||
// Close close client.
|
||||
func (c *Client) Close() error { |
||||
c.hc.Close() |
||||
return nil |
||||
} |
@ -0,0 +1,48 @@ |
||||
package hbase |
||||
|
||||
import ( |
||||
"context" |
||||
"io" |
||||
"time" |
||||
|
||||
"github.com/tsuna/gohbase" |
||||
"github.com/tsuna/gohbase/hrpc" |
||||
|
||||
"github.com/bilibili/kratos/pkg/stat" |
||||
) |
||||
|
||||
func codeFromErr(err error) string { |
||||
code := "unknown_error" |
||||
switch err { |
||||
case gohbase.ErrClientClosed: |
||||
code = "client_closed" |
||||
case gohbase.ErrConnotFindRegion: |
||||
code = "connot_find_region" |
||||
case gohbase.TableNotFound: |
||||
code = "table_not_found" |
||||
case gohbase.ErrRegionUnavailable: |
||||
code = "region_unavailable" |
||||
} |
||||
return code |
||||
} |
||||
|
||||
// MetricsHook if stats is nil use stat.DB as default.
|
||||
func MetricsHook(stats stat.Stat) HookFunc { |
||||
if stats == nil { |
||||
stats = stat.DB |
||||
} |
||||
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) { |
||||
now := time.Now() |
||||
if customName == "" { |
||||
customName = call.Name() |
||||
} |
||||
method := "hbase:" + customName |
||||
return func(err error) { |
||||
durationMs := int64(time.Since(now) / time.Millisecond) |
||||
stats.Timing(method, durationMs) |
||||
if err != nil && err != io.EOF { |
||||
stats.Incr(method, codeFromErr(err)) |
||||
} |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,24 @@ |
||||
package hbase |
||||
|
||||
import ( |
||||
"context" |
||||
"time" |
||||
|
||||
"github.com/tsuna/gohbase/hrpc" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
) |
||||
|
||||
// NewSlowLogHook log slow operation.
|
||||
func NewSlowLogHook(threshold time.Duration) HookFunc { |
||||
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) { |
||||
start := time.Now() |
||||
return func(error) { |
||||
duration := time.Since(start) |
||||
if duration < threshold { |
||||
return |
||||
} |
||||
log.Warn("hbase slow log: %s %s %s time: %s", customName, call.Table(), call.Key(), duration) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,40 @@ |
||||
package hbase |
||||
|
||||
import ( |
||||
"context" |
||||
"io" |
||||
|
||||
"github.com/tsuna/gohbase/hrpc" |
||||
|
||||
"github.com/bilibili/kratos/pkg/net/trace" |
||||
) |
||||
|
||||
// TraceHook create new hbase trace hook.
|
||||
func TraceHook(component, instance string) HookFunc { |
||||
var internalTags []trace.Tag |
||||
internalTags = append(internalTags, trace.TagString(trace.TagComponent, component)) |
||||
internalTags = append(internalTags, trace.TagString(trace.TagDBInstance, instance)) |
||||
internalTags = append(internalTags, trace.TagString(trace.TagPeerService, "hbase")) |
||||
internalTags = append(internalTags, trace.TagString(trace.TagSpanKind, "client")) |
||||
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) { |
||||
noop := func(error) {} |
||||
root, ok := trace.FromContext(ctx) |
||||
if !ok { |
||||
return noop |
||||
} |
||||
if customName == "" { |
||||
customName = call.Name() |
||||
} |
||||
span := root.Fork("", "Hbase:"+customName) |
||||
span.SetTag(internalTags...) |
||||
statement := string(call.Table()) + " " + string(call.Key()) |
||||
span.SetTag(trace.TagString(trace.TagDBStatement, statement)) |
||||
return func(err error) { |
||||
if err == io.EOF { |
||||
// reset error for trace.
|
||||
err = nil |
||||
} |
||||
span.Finish(&err) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,9 @@ |
||||
#### database/sql |
||||
|
||||
##### 项目简介 |
||||
MySQL数据库驱动,进行封装加入了链路追踪和统计。 |
||||
|
||||
如果需要SQL级别的超时管理 可以在业务代码里面使用context.WithDeadline实现 推荐超时配置放到application.toml里面 方便热加载 |
||||
|
||||
##### 依赖包 |
||||
1. [Go-MySQL-Driver](https://github.com/go-sql-driver/mysql) |
@ -0,0 +1,40 @@ |
||||
package sql |
||||
|
||||
import ( |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/kratos/pkg/stat" |
||||
"github.com/bilibili/kratos/pkg/time" |
||||
|
||||
// database driver
|
||||
_ "github.com/go-sql-driver/mysql" |
||||
) |
||||
|
||||
var stats = stat.DB |
||||
|
||||
// Config mysql config.
|
||||
type Config struct { |
||||
Addr string // for trace
|
||||
DSN string // write data source name.
|
||||
ReadDSN []string // read data source name.
|
||||
Active int // pool
|
||||
Idle int // pool
|
||||
IdleTimeout time.Duration // connect max life time.
|
||||
QueryTimeout time.Duration // query sql timeout
|
||||
ExecTimeout time.Duration // execute sql timeout
|
||||
TranTimeout time.Duration // transaction sql timeout
|
||||
Breaker *breaker.Config // breaker
|
||||
} |
||||
|
||||
// NewMySQL new db and retry connection when has error.
|
||||
func NewMySQL(c *Config) (db *DB) { |
||||
if c.QueryTimeout == 0 || c.ExecTimeout == 0 || c.TranTimeout == 0 { |
||||
panic("mysql must be set query/execute/transction timeout") |
||||
} |
||||
db, err := Open(c) |
||||
if err != nil { |
||||
log.Error("open mysql error(%v)", err) |
||||
panic(err) |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,678 @@ |
||||
package sql |
||||
|
||||
import ( |
||||
"context" |
||||
"database/sql" |
||||
"fmt" |
||||
"strings" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode" |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/kratos/pkg/net/trace" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
const ( |
||||
_family = "sql_client" |
||||
_slowLogDuration = time.Millisecond * 250 |
||||
) |
||||
|
||||
var ( |
||||
// ErrStmtNil prepared stmt error
|
||||
ErrStmtNil = errors.New("sql: prepare failed and stmt nil") |
||||
// ErrNoMaster is returned by Master when call master multiple times.
|
||||
ErrNoMaster = errors.New("sql: no master instance") |
||||
// ErrNoRows is returned by Scan when QueryRow doesn't return a row.
|
||||
// In such a case, QueryRow returns a placeholder *Row value that defers
|
||||
// this error until a Scan.
|
||||
ErrNoRows = sql.ErrNoRows |
||||
// ErrTxDone transaction done.
|
||||
ErrTxDone = sql.ErrTxDone |
||||
) |
||||
|
||||
// DB database.
|
||||
type DB struct { |
||||
write *conn |
||||
read []*conn |
||||
idx int64 |
||||
master *DB |
||||
} |
||||
|
||||
// conn database connection
|
||||
type conn struct { |
||||
*sql.DB |
||||
breaker breaker.Breaker |
||||
conf *Config |
||||
} |
||||
|
||||
// Tx transaction.
|
||||
type Tx struct { |
||||
db *conn |
||||
tx *sql.Tx |
||||
t trace.Trace |
||||
c context.Context |
||||
cancel func() |
||||
} |
||||
|
||||
// Row row.
|
||||
type Row struct { |
||||
err error |
||||
*sql.Row |
||||
db *conn |
||||
query string |
||||
args []interface{} |
||||
t trace.Trace |
||||
cancel func() |
||||
} |
||||
|
||||
// Scan copies the columns from the matched row into the values pointed at by dest.
|
||||
func (r *Row) Scan(dest ...interface{}) (err error) { |
||||
defer slowLog(fmt.Sprintf("Scan query(%s) args(%+v)", r.query, r.args), time.Now()) |
||||
if r.t != nil { |
||||
defer r.t.Finish(&err) |
||||
} |
||||
if r.err != nil { |
||||
err = r.err |
||||
} else if r.Row == nil { |
||||
err = ErrStmtNil |
||||
} |
||||
if err != nil { |
||||
return |
||||
} |
||||
err = r.Row.Scan(dest...) |
||||
if r.cancel != nil { |
||||
r.cancel() |
||||
} |
||||
r.db.onBreaker(&err) |
||||
if err != ErrNoRows { |
||||
err = errors.Wrapf(err, "query %s args %+v", r.query, r.args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Rows rows.
|
||||
type Rows struct { |
||||
*sql.Rows |
||||
cancel func() |
||||
} |
||||
|
||||
// Close closes the Rows, preventing further enumeration. If Next is called
|
||||
// and returns false and there are no further result sets,
|
||||
// the Rows are closed automatically and it will suffice to check the
|
||||
// result of Err. Close is idempotent and does not affect the result of Err.
|
||||
func (rs *Rows) Close() (err error) { |
||||
err = errors.WithStack(rs.Rows.Close()) |
||||
if rs.cancel != nil { |
||||
rs.cancel() |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Stmt prepared stmt.
|
||||
type Stmt struct { |
||||
db *conn |
||||
tx bool |
||||
query string |
||||
stmt atomic.Value |
||||
t trace.Trace |
||||
} |
||||
|
||||
// Open opens a database specified by its database driver name and a
|
||||
// driver-specific data source name, usually consisting of at least a database
|
||||
// name and connection information.
|
||||
func Open(c *Config) (*DB, error) { |
||||
db := new(DB) |
||||
d, err := connect(c, c.DSN) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
brkGroup := breaker.NewGroup(c.Breaker) |
||||
brk := brkGroup.Get(c.Addr) |
||||
w := &conn{DB: d, breaker: brk, conf: c} |
||||
rs := make([]*conn, 0, len(c.ReadDSN)) |
||||
for _, rd := range c.ReadDSN { |
||||
d, err := connect(c, rd) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
brk := brkGroup.Get(parseDSNAddr(rd)) |
||||
r := &conn{DB: d, breaker: brk, conf: c} |
||||
rs = append(rs, r) |
||||
} |
||||
db.write = w |
||||
db.read = rs |
||||
db.master = &DB{write: db.write} |
||||
return db, nil |
||||
} |
||||
|
||||
func connect(c *Config, dataSourceName string) (*sql.DB, error) { |
||||
d, err := sql.Open("mysql", dataSourceName) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
return nil, err |
||||
} |
||||
d.SetMaxOpenConns(c.Active) |
||||
d.SetMaxIdleConns(c.Idle) |
||||
d.SetConnMaxLifetime(time.Duration(c.IdleTimeout)) |
||||
return d, nil |
||||
} |
||||
|
||||
// Begin starts a transaction. The isolation level is dependent on the driver.
|
||||
func (db *DB) Begin(c context.Context) (tx *Tx, err error) { |
||||
return db.write.begin(c) |
||||
} |
||||
|
||||
// Exec executes a query without returning any rows.
|
||||
// The args are for any placeholder parameters in the query.
|
||||
func (db *DB) Exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) { |
||||
return db.write.exec(c, query, args...) |
||||
} |
||||
|
||||
// Prepare creates a prepared statement for later queries or executions.
|
||||
// Multiple queries or executions may be run concurrently from the returned
|
||||
// statement. The caller must call the statement's Close method when the
|
||||
// statement is no longer needed.
|
||||
func (db *DB) Prepare(query string) (*Stmt, error) { |
||||
return db.write.prepare(query) |
||||
} |
||||
|
||||
// Prepared creates a prepared statement for later queries or executions.
|
||||
// Multiple queries or executions may be run concurrently from the returned
|
||||
// statement. The caller must call the statement's Close method when the
|
||||
// statement is no longer needed.
|
||||
func (db *DB) Prepared(query string) (stmt *Stmt) { |
||||
return db.write.prepared(query) |
||||
} |
||||
|
||||
// Query executes a query that returns rows, typically a SELECT. The args are
|
||||
// for any placeholder parameters in the query.
|
||||
func (db *DB) Query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) { |
||||
idx := db.readIndex() |
||||
for i := range db.read { |
||||
if rows, err = db.read[(idx+i)%len(db.read)].query(c, query, args...); !ecode.ServiceUnavailable.Equal(err) { |
||||
return |
||||
} |
||||
} |
||||
return db.write.query(c, query, args...) |
||||
} |
||||
|
||||
// QueryRow executes a query that is expected to return at most one row.
|
||||
// QueryRow always returns a non-nil value. Errors are deferred until Row's
|
||||
// Scan method is called.
|
||||
func (db *DB) QueryRow(c context.Context, query string, args ...interface{}) *Row { |
||||
idx := db.readIndex() |
||||
for i := range db.read { |
||||
if row := db.read[(idx+i)%len(db.read)].queryRow(c, query, args...); !ecode.ServiceUnavailable.Equal(row.err) { |
||||
return row |
||||
} |
||||
} |
||||
return db.write.queryRow(c, query, args...) |
||||
} |
||||
|
||||
func (db *DB) readIndex() int { |
||||
if len(db.read) == 0 { |
||||
return 0 |
||||
} |
||||
v := atomic.AddInt64(&db.idx, 1) |
||||
return int(v) % len(db.read) |
||||
} |
||||
|
||||
// Close closes the write and read database, releasing any open resources.
|
||||
func (db *DB) Close() (err error) { |
||||
if e := db.write.Close(); e != nil { |
||||
err = errors.WithStack(e) |
||||
} |
||||
for _, rd := range db.read { |
||||
if e := rd.Close(); e != nil { |
||||
err = errors.WithStack(e) |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Ping verifies a connection to the database is still alive, establishing a
|
||||
// connection if necessary.
|
||||
func (db *DB) Ping(c context.Context) (err error) { |
||||
if err = db.write.ping(c); err != nil { |
||||
return |
||||
} |
||||
for _, rd := range db.read { |
||||
if err = rd.ping(c); err != nil { |
||||
return |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Master return *DB instance direct use master conn
|
||||
// use this *DB instance only when you have some reason need to get result without any delay.
|
||||
func (db *DB) Master() *DB { |
||||
if db.master == nil { |
||||
panic(ErrNoMaster) |
||||
} |
||||
return db.master |
||||
} |
||||
|
||||
func (db *conn) onBreaker(err *error) { |
||||
if err != nil && *err != nil && *err != sql.ErrNoRows && *err != sql.ErrTxDone { |
||||
db.breaker.MarkFailed() |
||||
} else { |
||||
db.breaker.MarkSuccess() |
||||
} |
||||
} |
||||
|
||||
func (db *conn) begin(c context.Context) (tx *Tx, err error) { |
||||
now := time.Now() |
||||
defer slowLog("Begin", now) |
||||
t, ok := trace.FromContext(c) |
||||
if ok { |
||||
t = t.Fork(_family, "begin") |
||||
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, "")) |
||||
defer func() { |
||||
if err != nil { |
||||
t.Finish(&err) |
||||
} |
||||
}() |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:begin", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.TranTimeout.Shrink(c) |
||||
rtx, err := db.BeginTx(c, nil) |
||||
stats.Timing("mysql:begin", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
cancel() |
||||
return |
||||
} |
||||
tx = &Tx{tx: rtx, t: t, db: db, c: c, cancel: cancel} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", query, args), now) |
||||
if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "exec") |
||||
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:exec", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.ExecTimeout.Shrink(c) |
||||
res, err = db.ExecContext(c, query, args...) |
||||
cancel() |
||||
db.onBreaker(&err) |
||||
stats.Timing("mysql:exec", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "exec:%s, args:%+v", query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) ping(c context.Context) (err error) { |
||||
now := time.Now() |
||||
defer slowLog("Ping", now) |
||||
if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "ping") |
||||
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, "")) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:ping", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.ExecTimeout.Shrink(c) |
||||
err = db.PingContext(c) |
||||
cancel() |
||||
db.onBreaker(&err) |
||||
stats.Timing("mysql:ping", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) prepare(query string) (*Stmt, error) { |
||||
defer slowLog(fmt.Sprintf("Prepare query(%s)", query), time.Now()) |
||||
stmt, err := db.Prepare(query) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "prepare %s", query) |
||||
return nil, err |
||||
} |
||||
st := &Stmt{query: query, db: db} |
||||
st.stmt.Store(stmt) |
||||
return st, nil |
||||
} |
||||
|
||||
func (db *conn) prepared(query string) (stmt *Stmt) { |
||||
defer slowLog(fmt.Sprintf("Prepared query(%s)", query), time.Now()) |
||||
stmt = &Stmt{query: query, db: db} |
||||
s, err := db.Prepare(query) |
||||
if err == nil { |
||||
stmt.stmt.Store(s) |
||||
return |
||||
} |
||||
go func() { |
||||
for { |
||||
s, err := db.Prepare(query) |
||||
if err != nil { |
||||
time.Sleep(time.Second) |
||||
continue |
||||
} |
||||
stmt.stmt.Store(s) |
||||
return |
||||
} |
||||
}() |
||||
return |
||||
} |
||||
|
||||
func (db *conn) query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", query, args), now) |
||||
if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "query") |
||||
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:query", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.QueryTimeout.Shrink(c) |
||||
rs, err := db.DB.QueryContext(c, query, args...) |
||||
db.onBreaker(&err) |
||||
stats.Timing("mysql:query", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "query:%s, args:%+v", query, args) |
||||
cancel() |
||||
return |
||||
} |
||||
rows = &Rows{Rows: rs, cancel: cancel} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) queryRow(c context.Context, query string, args ...interface{}) *Row { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", query, args), now) |
||||
t, ok := trace.FromContext(c) |
||||
if ok { |
||||
t = t.Fork(_family, "queryrow") |
||||
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query)) |
||||
} |
||||
if err := db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:queryrow", "breaker") |
||||
return &Row{db: db, t: t, err: err} |
||||
} |
||||
_, c, cancel := db.conf.QueryTimeout.Shrink(c) |
||||
r := db.DB.QueryRowContext(c, query, args...) |
||||
stats.Timing("mysql:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||
return &Row{db: db, Row: r, query: query, args: args, t: t, cancel: cancel} |
||||
} |
||||
|
||||
// Close closes the statement.
|
||||
func (s *Stmt) Close() (err error) { |
||||
if s == nil { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||
if ok { |
||||
err = errors.WithStack(stmt.Close()) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Exec executes a prepared statement with the given arguments and returns a
|
||||
// Result summarizing the effect of the statement.
|
||||
func (s *Stmt) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) { |
||||
if s == nil { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", s.query, args), now) |
||||
if s.tx { |
||||
if s.t != nil { |
||||
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||
} |
||||
} else if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "exec") |
||||
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = s.db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:stmt:exec", "breaker") |
||||
return |
||||
} |
||||
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||
if !ok { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
_, c, cancel := s.db.conf.ExecTimeout.Shrink(c) |
||||
res, err = stmt.ExecContext(c, args...) |
||||
cancel() |
||||
s.db.onBreaker(&err) |
||||
stats.Timing("mysql:stmt:exec", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "exec:%s, args:%+v", s.query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Query executes a prepared query statement with the given arguments and
|
||||
// returns the query results as a *Rows.
|
||||
func (s *Stmt) Query(c context.Context, args ...interface{}) (rows *Rows, err error) { |
||||
if s == nil { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", s.query, args), now) |
||||
if s.tx { |
||||
if s.t != nil { |
||||
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||
} |
||||
} else if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "query") |
||||
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = s.db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:stmt:query", "breaker") |
||||
return |
||||
} |
||||
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||
if !ok { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c) |
||||
rs, err := stmt.QueryContext(c, args...) |
||||
s.db.onBreaker(&err) |
||||
stats.Timing("mysql:stmt:query", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "query:%s, args:%+v", s.query, args) |
||||
cancel() |
||||
return |
||||
} |
||||
rows = &Rows{Rows: rs, cancel: cancel} |
||||
return |
||||
} |
||||
|
||||
// QueryRow executes a prepared query statement with the given arguments.
|
||||
// If an error occurs during the execution of the statement, that error will
|
||||
// be returned by a call to Scan on the returned *Row, which is always non-nil.
|
||||
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
|
||||
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
|
||||
func (s *Stmt) QueryRow(c context.Context, args ...interface{}) (row *Row) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", s.query, args), now) |
||||
row = &Row{db: s.db, query: s.query, args: args} |
||||
if s == nil { |
||||
row.err = ErrStmtNil |
||||
return |
||||
} |
||||
if s.tx { |
||||
if s.t != nil { |
||||
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||
} |
||||
} else if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "queryrow") |
||||
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query)) |
||||
row.t = t |
||||
} |
||||
if row.err = s.db.breaker.Allow(); row.err != nil { |
||||
stats.Incr("mysql:stmt:queryrow", "breaker") |
||||
return |
||||
} |
||||
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||
if !ok { |
||||
return |
||||
} |
||||
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c) |
||||
row.Row = stmt.QueryRowContext(c, args...) |
||||
row.cancel = cancel |
||||
stats.Timing("mysql:stmt:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||
return |
||||
} |
||||
|
||||
// Commit commits the transaction.
|
||||
func (tx *Tx) Commit() (err error) { |
||||
err = tx.tx.Commit() |
||||
tx.cancel() |
||||
tx.db.onBreaker(&err) |
||||
if tx.t != nil { |
||||
tx.t.Finish(&err) |
||||
} |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Rollback aborts the transaction.
|
||||
func (tx *Tx) Rollback() (err error) { |
||||
err = tx.tx.Rollback() |
||||
tx.cancel() |
||||
tx.db.onBreaker(&err) |
||||
if tx.t != nil { |
||||
tx.t.Finish(&err) |
||||
} |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Exec executes a query that doesn't return rows. For example: an INSERT and
|
||||
// UPDATE.
|
||||
func (tx *Tx) Exec(query string, args ...interface{}) (res sql.Result, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", query, args), now) |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("exec %s", query))) |
||||
} |
||||
res, err = tx.tx.ExecContext(tx.c, query, args...) |
||||
stats.Timing("mysql:tx:exec", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "exec:%s, args:%+v", query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Query executes a query that returns rows, typically a SELECT.
|
||||
func (tx *Tx) Query(query string, args ...interface{}) (rows *Rows, err error) { |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("query %s", query))) |
||||
} |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", query, args), now) |
||||
defer func() { |
||||
stats.Timing("mysql:tx:query", int64(time.Since(now)/time.Millisecond)) |
||||
}() |
||||
rs, err := tx.tx.QueryContext(tx.c, query, args...) |
||||
if err == nil { |
||||
rows = &Rows{Rows: rs} |
||||
} else { |
||||
err = errors.Wrapf(err, "query:%s, args:%+v", query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// QueryRow executes a query that is expected to return at most one row.
|
||||
// QueryRow always returns a non-nil value. Errors are deferred until Row's
|
||||
// Scan method is called.
|
||||
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("queryrow %s", query))) |
||||
} |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", query, args), now) |
||||
defer func() { |
||||
stats.Timing("mysql:tx:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||
}() |
||||
r := tx.tx.QueryRowContext(tx.c, query, args...) |
||||
return &Row{Row: r, db: tx.db, query: query, args: args} |
||||
} |
||||
|
||||
// Stmt returns a transaction-specific prepared statement from an existing statement.
|
||||
func (tx *Tx) Stmt(stmt *Stmt) *Stmt { |
||||
as, ok := stmt.stmt.Load().(*sql.Stmt) |
||||
if !ok { |
||||
return nil |
||||
} |
||||
ts := tx.tx.StmtContext(tx.c, as) |
||||
st := &Stmt{query: stmt.query, tx: true, t: tx.t, db: tx.db} |
||||
st.stmt.Store(ts) |
||||
return st |
||||
} |
||||
|
||||
// Prepare creates a prepared statement for use within a transaction.
|
||||
// The returned statement operates within the transaction and can no longer be
|
||||
// used once the transaction has been committed or rolled back.
|
||||
// To use an existing prepared statement on this transaction, see Tx.Stmt.
|
||||
func (tx *Tx) Prepare(query string) (*Stmt, error) { |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("prepare %s", query))) |
||||
} |
||||
defer slowLog(fmt.Sprintf("Prepare query(%s)", query), time.Now()) |
||||
stmt, err := tx.tx.Prepare(query) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "prepare %s", query) |
||||
return nil, err |
||||
} |
||||
st := &Stmt{query: query, tx: true, t: tx.t, db: tx.db} |
||||
st.stmt.Store(stmt) |
||||
return st, nil |
||||
} |
||||
|
||||
// parseDSNAddr parse dsn name and return addr.
|
||||
func parseDSNAddr(dsn string) (addr string) { |
||||
if dsn == "" { |
||||
return |
||||
} |
||||
part0 := strings.Split(dsn, "@") |
||||
if len(part0) > 1 { |
||||
part1 := strings.Split(part0[1], "?") |
||||
if len(part1) > 0 { |
||||
addr = part1[0] |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
func slowLog(statement string, now time.Time) { |
||||
du := time.Since(now) |
||||
if du > _slowLogDuration { |
||||
log.Warn("%s slow log statement: %s time: %v", _family, statement, du) |
||||
} |
||||
} |
@ -0,0 +1,14 @@ |
||||
#### database/tidb |
||||
|
||||
##### 项目简介 |
||||
TiDB数据库驱动 对mysql驱动进行封装 |
||||
|
||||
##### 功能 |
||||
1. 支持discovery服务发现 多节点直连 |
||||
2. 支持通过lvs单一地址连接 |
||||
3. 支持prepare绑定多个节点 |
||||
4. 支持动态增减节点负载均衡 |
||||
5. 日志区分运行节点 |
||||
|
||||
##### 依赖包 |
||||
1.[Go-MySQL-Driver](https://github.com/go-sql-driver/mysql) |
@ -0,0 +1,58 @@ |
||||
package tidb |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env" |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/naming" |
||||
"github.com/bilibili/kratos/pkg/naming/discovery" |
||||
) |
||||
|
||||
var _schema = "tidb://" |
||||
|
||||
func (db *DB) nodeList() (nodes []string) { |
||||
var ( |
||||
insInfo *naming.InstancesInfo |
||||
insMap map[string][]*naming.Instance |
||||
ins []*naming.Instance |
||||
ok bool |
||||
) |
||||
if insInfo, ok = db.dis.Fetch(context.Background()); !ok { |
||||
return |
||||
} |
||||
insMap = insInfo.Instances |
||||
if ins, ok = insMap[env.Zone]; !ok || len(ins) == 0 { |
||||
return |
||||
} |
||||
for _, in := range ins { |
||||
for _, addr := range in.Addrs { |
||||
if strings.HasPrefix(addr, _schema) { |
||||
addr = strings.Replace(addr, _schema, "", -1) |
||||
nodes = append(nodes, addr) |
||||
} |
||||
} |
||||
} |
||||
log.Info("tidb get %s instances(%v)", db.appid, nodes) |
||||
return |
||||
} |
||||
|
||||
func (db *DB) disc() (nodes []string) { |
||||
db.dis = discovery.Build(db.appid) |
||||
e := db.dis.Watch() |
||||
select { |
||||
case <-e: |
||||
nodes = db.nodeList() |
||||
case <-time.After(10 * time.Second): |
||||
panic("tidb init discovery err") |
||||
} |
||||
if len(nodes) == 0 { |
||||
panic(fmt.Sprintf("tidb %s no instance", db.appid)) |
||||
} |
||||
go db.nodeproc(e) |
||||
log.Info("init tidb discvoery info successfully") |
||||
return |
||||
} |
@ -0,0 +1,82 @@ |
||||
package tidb |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
) |
||||
|
||||
func (db *DB) nodeproc(e <-chan struct{}) { |
||||
if db.dis == nil { |
||||
return |
||||
} |
||||
for { |
||||
<-e |
||||
nodes := db.nodeList() |
||||
if len(nodes) == 0 { |
||||
continue |
||||
} |
||||
cm := make(map[string]*conn) |
||||
var conns []*conn |
||||
for _, conn := range db.conns { |
||||
cm[conn.addr] = conn |
||||
} |
||||
for _, node := range nodes { |
||||
if cm[node] != nil { |
||||
conns = append(conns, cm[node]) |
||||
continue |
||||
} |
||||
c, err := db.connectDSN(genDSN(db.conf.DSN, node)) |
||||
if err == nil { |
||||
conns = append(conns, c) |
||||
} else { |
||||
log.Error("tidb: connect addr: %s err: %+v", node, err) |
||||
} |
||||
} |
||||
if len(conns) == 0 { |
||||
log.Error("tidb: no nodes ignore event") |
||||
continue |
||||
} |
||||
oldConns := db.conns |
||||
db.mutex.Lock() |
||||
db.conns = conns |
||||
db.mutex.Unlock() |
||||
log.Info("tidb: new nodes: %v", nodes) |
||||
var removedConn []*conn |
||||
for _, conn := range oldConns { |
||||
var exist bool |
||||
for _, c := range conns { |
||||
if c.addr == conn.addr { |
||||
exist = true |
||||
break |
||||
} |
||||
} |
||||
if !exist { |
||||
removedConn = append(removedConn, conn) |
||||
} |
||||
} |
||||
go db.closeConns(removedConn) |
||||
} |
||||
} |
||||
|
||||
func (db *DB) closeConns(conns []*conn) { |
||||
if len(conns) == 0 { |
||||
return |
||||
} |
||||
du := db.conf.QueryTimeout |
||||
if db.conf.ExecTimeout > du { |
||||
du = db.conf.ExecTimeout |
||||
} |
||||
if db.conf.TranTimeout > du { |
||||
du = db.conf.TranTimeout |
||||
} |
||||
time.Sleep(time.Duration(du)) |
||||
for _, conn := range conns { |
||||
err := conn.Close() |
||||
if err != nil { |
||||
log.Error("tidb: close removed conn: %s err: %v", conn.addr, err) |
||||
} else { |
||||
log.Info("tidb: close removed conn: %s", conn.addr) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,739 @@ |
||||
package tidb |
||||
|
||||
import ( |
||||
"context" |
||||
"database/sql" |
||||
"fmt" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/naming" |
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/kratos/pkg/net/trace" |
||||
|
||||
"github.com/go-sql-driver/mysql" |
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
const ( |
||||
_family = "tidb_client" |
||||
_slowLogDuration = time.Millisecond * 250 |
||||
) |
||||
|
||||
var ( |
||||
// ErrStmtNil prepared stmt error
|
||||
ErrStmtNil = errors.New("sql: prepare failed and stmt nil") |
||||
// ErrNoRows is returned by Scan when QueryRow doesn't return a row.
|
||||
// In such a case, QueryRow returns a placeholder *Row value that defers
|
||||
// this error until a Scan.
|
||||
ErrNoRows = sql.ErrNoRows |
||||
// ErrTxDone transaction done.
|
||||
ErrTxDone = sql.ErrTxDone |
||||
) |
||||
|
||||
// DB database.
|
||||
type DB struct { |
||||
conf *Config |
||||
conns []*conn |
||||
idx int64 |
||||
dis naming.Resolver |
||||
appid string |
||||
mutex sync.RWMutex |
||||
breakerGroup *breaker.Group |
||||
} |
||||
|
||||
// conn database connection
|
||||
type conn struct { |
||||
*sql.DB |
||||
breaker breaker.Breaker |
||||
conf *Config |
||||
addr string |
||||
} |
||||
|
||||
// Tx transaction.
|
||||
type Tx struct { |
||||
db *conn |
||||
tx *sql.Tx |
||||
t trace.Trace |
||||
c context.Context |
||||
cancel func() |
||||
} |
||||
|
||||
// Row row.
|
||||
type Row struct { |
||||
err error |
||||
*sql.Row |
||||
db *conn |
||||
query string |
||||
args []interface{} |
||||
t trace.Trace |
||||
cancel func() |
||||
} |
||||
|
||||
// Scan copies the columns from the matched row into the values pointed at by dest.
|
||||
func (r *Row) Scan(dest ...interface{}) (err error) { |
||||
defer slowLog(fmt.Sprintf("Scan addr: %s query(%s) args(%+v)", r.db.addr, r.query, r.args), time.Now()) |
||||
if r.t != nil { |
||||
defer r.t.Finish(&err) |
||||
} |
||||
if r.err != nil { |
||||
err = r.err |
||||
} else if r.Row == nil { |
||||
err = ErrStmtNil |
||||
} |
||||
if err != nil { |
||||
return |
||||
} |
||||
err = r.Row.Scan(dest...) |
||||
if r.cancel != nil { |
||||
r.cancel() |
||||
} |
||||
r.db.onBreaker(&err) |
||||
if err != ErrNoRows { |
||||
err = errors.Wrapf(err, "addr: %s, query %s args %+v", r.db.addr, r.query, r.args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Rows rows.
|
||||
type Rows struct { |
||||
*sql.Rows |
||||
cancel func() |
||||
} |
||||
|
||||
// Close closes the Rows, preventing further enumeration. If Next is called
|
||||
// and returns false and there are no further result sets,
|
||||
// the Rows are closed automatically and it will suffice to check the
|
||||
// result of Err. Close is idempotent and does not affect the result of Err.
|
||||
func (rs *Rows) Close() (err error) { |
||||
err = errors.WithStack(rs.Rows.Close()) |
||||
if rs.cancel != nil { |
||||
rs.cancel() |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Stmt prepared stmt.
|
||||
type Stmt struct { |
||||
db *conn |
||||
tx bool |
||||
query string |
||||
stmt atomic.Value |
||||
t trace.Trace |
||||
} |
||||
|
||||
// Stmts random prepared stmt.
|
||||
type Stmts struct { |
||||
query string |
||||
sts map[string]*Stmt |
||||
mu sync.RWMutex |
||||
db *DB |
||||
} |
||||
|
||||
// Open opens a database specified by its database driver name and a
|
||||
// driver-specific data source name, usually consisting of at least a database
|
||||
// name and connection information.
|
||||
func Open(c *Config) (db *DB, err error) { |
||||
db = &DB{conf: c, breakerGroup: breaker.NewGroup(c.Breaker)} |
||||
cfg, err := mysql.ParseDSN(c.DSN) |
||||
if err != nil { |
||||
return |
||||
} |
||||
var dsns []string |
||||
if cfg.Net == "discovery" { |
||||
db.appid = cfg.Addr |
||||
for _, addr := range db.disc() { |
||||
dsns = append(dsns, genDSN(c.DSN, addr)) |
||||
} |
||||
} else { |
||||
dsns = append(dsns, c.DSN) |
||||
} |
||||
|
||||
cs := make([]*conn, 0, len(dsns)) |
||||
for _, dsn := range dsns { |
||||
r, err := db.connectDSN(dsn) |
||||
if err != nil { |
||||
return db, err |
||||
} |
||||
cs = append(cs, r) |
||||
} |
||||
db.conns = cs |
||||
return |
||||
} |
||||
|
||||
func (db *DB) connectDSN(dsn string) (c *conn, err error) { |
||||
d, err := connect(db.conf, dsn) |
||||
if err != nil { |
||||
return |
||||
} |
||||
addr := parseDSNAddr(dsn) |
||||
brk := db.breakerGroup.Get(addr) |
||||
c = &conn{DB: d, breaker: brk, conf: db.conf, addr: addr} |
||||
return |
||||
} |
||||
|
||||
func connect(c *Config, dataSourceName string) (*sql.DB, error) { |
||||
d, err := sql.Open("mysql", dataSourceName) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
return nil, err |
||||
} |
||||
d.SetMaxOpenConns(c.Active) |
||||
d.SetMaxIdleConns(c.Idle) |
||||
d.SetConnMaxLifetime(time.Duration(c.IdleTimeout)) |
||||
return d, nil |
||||
} |
||||
|
||||
func (db *DB) conn() (c *conn) { |
||||
db.mutex.RLock() |
||||
c = db.conns[db.index()] |
||||
db.mutex.RUnlock() |
||||
return |
||||
} |
||||
|
||||
// Begin starts a transaction. The isolation level is dependent on the driver.
|
||||
func (db *DB) Begin(c context.Context) (tx *Tx, err error) { |
||||
return db.conn().begin(c) |
||||
} |
||||
|
||||
// Exec executes a query without returning any rows.
|
||||
// The args are for any placeholder parameters in the query.
|
||||
func (db *DB) Exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) { |
||||
return db.conn().exec(c, query, args...) |
||||
} |
||||
|
||||
// Prepare creates a prepared statement for later queries or executions.
|
||||
// Multiple queries or executions may be run concurrently from the returned
|
||||
// statement. The caller must call the statement's Close method when the
|
||||
// statement is no longer needed.
|
||||
func (db *DB) Prepare(query string) (*Stmt, error) { |
||||
return db.conn().prepare(query) |
||||
} |
||||
|
||||
// Prepared creates a prepared statement for later queries or executions.
|
||||
// Multiple queries or executions may be run concurrently from the returned
|
||||
// statement. The caller must call the statement's Close method when the
|
||||
// statement is no longer needed.
|
||||
func (db *DB) Prepared(query string) (s *Stmts) { |
||||
s = &Stmts{query: query, sts: make(map[string]*Stmt), db: db} |
||||
for _, c := range db.conns { |
||||
st := c.prepared(query) |
||||
s.mu.Lock() |
||||
s.sts[c.addr] = st |
||||
s.mu.Unlock() |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Query executes a query that returns rows, typically a SELECT. The args are
|
||||
// for any placeholder parameters in the query.
|
||||
func (db *DB) Query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) { |
||||
return db.conn().query(c, query, args...) |
||||
} |
||||
|
||||
// QueryRow executes a query that is expected to return at most one row.
|
||||
// QueryRow always returns a non-nil value. Errors are deferred until Row's
|
||||
// Scan method is called.
|
||||
func (db *DB) QueryRow(c context.Context, query string, args ...interface{}) *Row { |
||||
return db.conn().queryRow(c, query, args...) |
||||
} |
||||
|
||||
func (db *DB) index() int { |
||||
if len(db.conns) == 1 { |
||||
return 0 |
||||
} |
||||
v := atomic.AddInt64(&db.idx, 1) |
||||
return int(v) % len(db.conns) |
||||
} |
||||
|
||||
// Close closes the databases, releasing any open resources.
|
||||
func (db *DB) Close() (err error) { |
||||
db.mutex.RLock() |
||||
defer db.mutex.RUnlock() |
||||
for _, d := range db.conns { |
||||
if e := d.Close(); e != nil { |
||||
err = errors.WithStack(e) |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Ping verifies a connection to the database is still alive, establishing a
|
||||
// connection if necessary.
|
||||
func (db *DB) Ping(c context.Context) (err error) { |
||||
if err = db.conn().ping(c); err != nil { |
||||
return |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) onBreaker(err *error) { |
||||
if err != nil && *err != nil && *err != sql.ErrNoRows && *err != sql.ErrTxDone { |
||||
db.breaker.MarkFailed() |
||||
} else { |
||||
db.breaker.MarkSuccess() |
||||
} |
||||
} |
||||
|
||||
func (db *conn) begin(c context.Context) (tx *Tx, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Begin addr: %s", db.addr), now) |
||||
t, ok := trace.FromContext(c) |
||||
if ok { |
||||
t = t.Fork(_family, "begin") |
||||
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, "")) |
||||
defer func() { |
||||
if err != nil { |
||||
t.Finish(&err) |
||||
} |
||||
}() |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:begin", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.TranTimeout.Shrink(c) |
||||
rtx, err := db.BeginTx(c, nil) |
||||
stats.Timing("tidb:begin", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
cancel() |
||||
return |
||||
} |
||||
tx = &Tx{tx: rtx, t: t, db: db, c: c, cancel: cancel} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", db.addr, query, args), now) |
||||
if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "exec") |
||||
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:exec", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.ExecTimeout.Shrink(c) |
||||
res, err = db.ExecContext(c, query, args...) |
||||
cancel() |
||||
db.onBreaker(&err) |
||||
stats.Timing("tidb:exec", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", db.addr, query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) ping(c context.Context) (err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Ping addr: %s", db.addr), now) |
||||
if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "ping") |
||||
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, "")) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:ping", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.ExecTimeout.Shrink(c) |
||||
err = db.PingContext(c) |
||||
cancel() |
||||
db.onBreaker(&err) |
||||
stats.Timing("tidb:ping", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) prepare(query string) (*Stmt, error) { |
||||
defer slowLog(fmt.Sprintf("Prepare addr: %s query(%s)", db.addr, query), time.Now()) |
||||
stmt, err := db.Prepare(query) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s prepare %s", db.addr, query) |
||||
return nil, err |
||||
} |
||||
st := &Stmt{query: query, db: db} |
||||
st.stmt.Store(stmt) |
||||
return st, nil |
||||
} |
||||
|
||||
func (db *conn) prepared(query string) (stmt *Stmt) { |
||||
defer slowLog(fmt.Sprintf("Prepared addr: %s query(%s)", db.addr, query), time.Now()) |
||||
stmt = &Stmt{query: query, db: db} |
||||
s, err := db.Prepare(query) |
||||
if err == nil { |
||||
stmt.stmt.Store(s) |
||||
return |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", db.addr, query, args), now) |
||||
if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "query") |
||||
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:query", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.QueryTimeout.Shrink(c) |
||||
rs, err := db.DB.QueryContext(c, query, args...) |
||||
db.onBreaker(&err) |
||||
stats.Timing("tidb:query", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", db.addr, query, args) |
||||
cancel() |
||||
return |
||||
} |
||||
rows = &Rows{Rows: rs, cancel: cancel} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) queryRow(c context.Context, query string, args ...interface{}) *Row { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", db.addr, query, args), now) |
||||
t, ok := trace.FromContext(c) |
||||
if ok { |
||||
t = t.Fork(_family, "queryrow") |
||||
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query)) |
||||
} |
||||
if err := db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:queryrow", "breaker") |
||||
return &Row{db: db, t: t, err: err} |
||||
} |
||||
_, c, cancel := db.conf.QueryTimeout.Shrink(c) |
||||
r := db.DB.QueryRowContext(c, query, args...) |
||||
stats.Timing("tidb:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||
return &Row{db: db, Row: r, query: query, args: args, t: t, cancel: cancel} |
||||
} |
||||
|
||||
// Close closes the statement.
|
||||
func (s *Stmt) Close() (err error) { |
||||
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||
if ok { |
||||
err = errors.WithStack(stmt.Close()) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (s *Stmt) prepare() (st *sql.Stmt) { |
||||
var ok bool |
||||
if st, ok = s.stmt.Load().(*sql.Stmt); ok { |
||||
return |
||||
} |
||||
var err error |
||||
if st, err = s.db.Prepare(s.query); err == nil { |
||||
s.stmt.Store(st) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Exec executes a prepared statement with the given arguments and returns a
|
||||
// Result summarizing the effect of the statement.
|
||||
func (s *Stmt) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now) |
||||
if s.tx { |
||||
if s.t != nil { |
||||
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||
} |
||||
} else if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "exec") |
||||
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = s.db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:stmt:exec", "breaker") |
||||
return |
||||
} |
||||
stmt := s.prepare() |
||||
if stmt == nil { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
_, c, cancel := s.db.conf.ExecTimeout.Shrink(c) |
||||
res, err = stmt.ExecContext(c, args...) |
||||
cancel() |
||||
s.db.onBreaker(&err) |
||||
stats.Timing("tidb:stmt:exec", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", s.db.addr, s.query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Query executes a prepared query statement with the given arguments and
|
||||
// returns the query results as a *Rows.
|
||||
func (s *Stmt) Query(c context.Context, args ...interface{}) (rows *Rows, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now) |
||||
if s.tx { |
||||
if s.t != nil { |
||||
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||
} |
||||
} else if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "query") |
||||
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = s.db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:stmt:query", "breaker") |
||||
return |
||||
} |
||||
stmt := s.prepare() |
||||
if stmt == nil { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c) |
||||
rs, err := stmt.QueryContext(c, args...) |
||||
s.db.onBreaker(&err) |
||||
stats.Timing("tidb:stmt:query", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", s.db.addr, s.query, args) |
||||
cancel() |
||||
return |
||||
} |
||||
rows = &Rows{Rows: rs, cancel: cancel} |
||||
return |
||||
} |
||||
|
||||
// QueryRow executes a prepared query statement with the given arguments.
|
||||
// If an error occurs during the execution of the statement, that error will
|
||||
// be returned by a call to Scan on the returned *Row, which is always non-nil.
|
||||
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
|
||||
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
|
||||
func (s *Stmt) QueryRow(c context.Context, args ...interface{}) (row *Row) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now) |
||||
row = &Row{db: s.db, query: s.query, args: args} |
||||
if s.tx { |
||||
if s.t != nil { |
||||
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||
} |
||||
} else if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "queryrow") |
||||
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query)) |
||||
row.t = t |
||||
} |
||||
if row.err = s.db.breaker.Allow(); row.err != nil { |
||||
stats.Incr("tidb:stmt:queryrow", "breaker") |
||||
return |
||||
} |
||||
stmt := s.prepare() |
||||
if stmt == nil { |
||||
return |
||||
} |
||||
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c) |
||||
row.Row = stmt.QueryRowContext(c, args...) |
||||
row.cancel = cancel |
||||
stats.Timing("tidb:stmt:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||
return |
||||
} |
||||
|
||||
func (s *Stmts) prepare(conn *conn) (st *Stmt) { |
||||
if conn == nil { |
||||
conn = s.db.conn() |
||||
} |
||||
s.mu.RLock() |
||||
st = s.sts[conn.addr] |
||||
s.mu.RUnlock() |
||||
if st == nil { |
||||
st = conn.prepared(s.query) |
||||
s.mu.Lock() |
||||
s.sts[conn.addr] = st |
||||
s.mu.Unlock() |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Exec executes a prepared statement with the given arguments and returns a
|
||||
// Result summarizing the effect of the statement.
|
||||
func (s *Stmts) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) { |
||||
return s.prepare(nil).Exec(c, args...) |
||||
} |
||||
|
||||
// Query executes a prepared query statement with the given arguments and
|
||||
// returns the query results as a *Rows.
|
||||
func (s *Stmts) Query(c context.Context, args ...interface{}) (rows *Rows, err error) { |
||||
return s.prepare(nil).Query(c, args...) |
||||
} |
||||
|
||||
// QueryRow executes a prepared query statement with the given arguments.
|
||||
// If an error occurs during the execution of the statement, that error will
|
||||
// be returned by a call to Scan on the returned *Row, which is always non-nil.
|
||||
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
|
||||
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
|
||||
func (s *Stmts) QueryRow(c context.Context, args ...interface{}) (row *Row) { |
||||
return s.prepare(nil).QueryRow(c, args...) |
||||
} |
||||
|
||||
// Close closes the statement.
|
||||
func (s *Stmts) Close() (err error) { |
||||
for _, st := range s.sts { |
||||
if err = errors.WithStack(st.Close()); err != nil { |
||||
return |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Commit commits the transaction.
|
||||
func (tx *Tx) Commit() (err error) { |
||||
err = tx.tx.Commit() |
||||
tx.cancel() |
||||
tx.db.onBreaker(&err) |
||||
if tx.t != nil { |
||||
tx.t.Finish(&err) |
||||
} |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Rollback aborts the transaction.
|
||||
func (tx *Tx) Rollback() (err error) { |
||||
err = tx.tx.Rollback() |
||||
tx.cancel() |
||||
tx.db.onBreaker(&err) |
||||
if tx.t != nil { |
||||
tx.t.Finish(&err) |
||||
} |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Exec executes a query that doesn't return rows. For example: an INSERT and
|
||||
// UPDATE.
|
||||
func (tx *Tx) Exec(query string, args ...interface{}) (res sql.Result, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now) |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("exec %s", query))) |
||||
} |
||||
res, err = tx.tx.ExecContext(tx.c, query, args...) |
||||
stats.Timing("tidb:tx:exec", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", tx.db.addr, query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Query executes a query that returns rows, typically a SELECT.
|
||||
func (tx *Tx) Query(query string, args ...interface{}) (rows *Rows, err error) { |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("query %s", query))) |
||||
} |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now) |
||||
defer func() { |
||||
stats.Timing("tidb:tx:query", int64(time.Since(now)/time.Millisecond)) |
||||
}() |
||||
rs, err := tx.tx.QueryContext(tx.c, query, args...) |
||||
if err == nil { |
||||
rows = &Rows{Rows: rs} |
||||
} else { |
||||
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", tx.db.addr, query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// QueryRow executes a query that is expected to return at most one row.
|
||||
// QueryRow always returns a non-nil value. Errors are deferred until Row's
|
||||
// Scan method is called.
|
||||
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("queryrow %s", query))) |
||||
} |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now) |
||||
defer func() { |
||||
stats.Timing("tidb:tx:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||
}() |
||||
r := tx.tx.QueryRowContext(tx.c, query, args...) |
||||
return &Row{Row: r, db: tx.db, query: query, args: args} |
||||
} |
||||
|
||||
// Stmt returns a transaction-specific prepared statement from an existing statement.
|
||||
func (tx *Tx) Stmt(stmt *Stmt) *Stmt { |
||||
if stmt == nil { |
||||
return nil |
||||
} |
||||
as, ok := stmt.stmt.Load().(*sql.Stmt) |
||||
if !ok { |
||||
return nil |
||||
} |
||||
ts := tx.tx.StmtContext(tx.c, as) |
||||
st := &Stmt{query: stmt.query, tx: true, t: tx.t, db: tx.db} |
||||
st.stmt.Store(ts) |
||||
return st |
||||
} |
||||
|
||||
// Stmts returns a transaction-specific prepared statement from an existing statement.
|
||||
func (tx *Tx) Stmts(stmt *Stmts) *Stmt { |
||||
return tx.Stmt(stmt.prepare(tx.db)) |
||||
} |
||||
|
||||
// Prepare creates a prepared statement for use within a transaction.
|
||||
// The returned statement operates within the transaction and can no longer be
|
||||
// used once the transaction has been committed or rolled back.
|
||||
// To use an existing prepared statement on this transaction, see Tx.Stmt.
|
||||
func (tx *Tx) Prepare(query string) (*Stmt, error) { |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("prepare %s", query))) |
||||
} |
||||
defer slowLog(fmt.Sprintf("Prepare addr: %s query(%s)", tx.db.addr, query), time.Now()) |
||||
stmt, err := tx.tx.Prepare(query) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s prepare %s", tx.db.addr, query) |
||||
return nil, err |
||||
} |
||||
st := &Stmt{query: query, tx: true, t: tx.t, db: tx.db} |
||||
st.stmt.Store(stmt) |
||||
return st, nil |
||||
} |
||||
|
||||
// parseDSNAddr parse dsn name and return addr.
|
||||
func parseDSNAddr(dsn string) (addr string) { |
||||
if dsn == "" { |
||||
return |
||||
} |
||||
cfg, err := mysql.ParseDSN(dsn) |
||||
if err != nil { |
||||
return |
||||
} |
||||
addr = cfg.Addr |
||||
return |
||||
} |
||||
|
||||
func genDSN(dsn, addr string) (res string) { |
||||
cfg, err := mysql.ParseDSN(dsn) |
||||
if err != nil { |
||||
return |
||||
} |
||||
cfg.Addr = addr |
||||
cfg.Net = "tcp" |
||||
res = cfg.FormatDSN() |
||||
return |
||||
} |
||||
|
||||
func slowLog(statement string, now time.Time) { |
||||
du := time.Since(now) |
||||
if du > _slowLogDuration { |
||||
log.Warn("%s slow log statement: %s time: %v", _family, statement, du) |
||||
} |
||||
} |
@ -0,0 +1,38 @@ |
||||
package tidb |
||||
|
||||
import ( |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/kratos/pkg/stat" |
||||
"github.com/bilibili/kratos/pkg/time" |
||||
|
||||
// database driver
|
||||
_ "github.com/go-sql-driver/mysql" |
||||
) |
||||
|
||||
var stats = stat.DB |
||||
|
||||
// Config mysql config.
|
||||
type Config struct { |
||||
DSN string // dsn
|
||||
Active int // pool
|
||||
Idle int // pool
|
||||
IdleTimeout time.Duration // connect max life time.
|
||||
QueryTimeout time.Duration // query sql timeout
|
||||
ExecTimeout time.Duration // execute sql timeout
|
||||
TranTimeout time.Duration // transaction sql timeout
|
||||
Breaker *breaker.Config // breaker
|
||||
} |
||||
|
||||
// NewTiDB new db and retry connection when has error.
|
||||
func NewTiDB(c *Config) (db *DB) { |
||||
if c.QueryTimeout == 0 || c.ExecTimeout == 0 || c.TranTimeout == 0 { |
||||
panic("tidb must be set query/execute/transction timeout") |
||||
} |
||||
db, err := Open(c) |
||||
if err != nil { |
||||
log.Error("open tidb error(%v)", err) |
||||
panic(err) |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,85 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"net/http" |
||||
"strings" |
||||
|
||||
"gopkg.in/go-playground/validator.v9" |
||||
) |
||||
|
||||
// MIME
|
||||
const ( |
||||
MIMEJSON = "application/json" |
||||
MIMEHTML = "text/html" |
||||
MIMEXML = "application/xml" |
||||
MIMEXML2 = "text/xml" |
||||
MIMEPlain = "text/plain" |
||||
MIMEPOSTForm = "application/x-www-form-urlencoded" |
||||
MIMEMultipartPOSTForm = "multipart/form-data" |
||||
) |
||||
|
||||
// Binding http binding request interface.
|
||||
type Binding interface { |
||||
Name() string |
||||
Bind(*http.Request, interface{}) error |
||||
} |
||||
|
||||
// StructValidator http validator interface.
|
||||
type StructValidator interface { |
||||
// ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
|
||||
// If the received type is not a struct, any validation should be skipped and nil must be returned.
|
||||
// If the received type is a struct or pointer to a struct, the validation should be performed.
|
||||
// If the struct is not valid or the validation itself fails, a descriptive error should be returned.
|
||||
// Otherwise nil must be returned.
|
||||
ValidateStruct(interface{}) error |
||||
|
||||
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
|
||||
// NOTE: if the key already exists, the previous validation function will be replaced.
|
||||
// NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
|
||||
RegisterValidation(string, validator.Func) error |
||||
} |
||||
|
||||
// Validator default validator.
|
||||
var Validator StructValidator = &defaultValidator{} |
||||
|
||||
// Binding
|
||||
var ( |
||||
JSON = jsonBinding{} |
||||
XML = xmlBinding{} |
||||
Form = formBinding{} |
||||
Query = queryBinding{} |
||||
FormPost = formPostBinding{} |
||||
FormMultipart = formMultipartBinding{} |
||||
) |
||||
|
||||
// Default get by binding type by method and contexttype.
|
||||
func Default(method, contentType string) Binding { |
||||
if method == "GET" { |
||||
return Form |
||||
} |
||||
|
||||
contentType = stripContentTypeParam(contentType) |
||||
switch contentType { |
||||
case MIMEJSON: |
||||
return JSON |
||||
case MIMEXML, MIMEXML2: |
||||
return XML |
||||
default: //case MIMEPOSTForm, MIMEMultipartPOSTForm:
|
||||
return Form |
||||
} |
||||
} |
||||
|
||||
func validate(obj interface{}) error { |
||||
if Validator == nil { |
||||
return nil |
||||
} |
||||
return Validator.ValidateStruct(obj) |
||||
} |
||||
|
||||
func stripContentTypeParam(contentType string) string { |
||||
i := strings.Index(contentType, ";") |
||||
if i != -1 { |
||||
contentType = contentType[:i] |
||||
} |
||||
return contentType |
||||
} |
@ -0,0 +1,342 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"bytes" |
||||
"mime/multipart" |
||||
"net/http" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
type FooStruct struct { |
||||
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" validate:"required"` |
||||
} |
||||
|
||||
type FooBarStruct struct { |
||||
FooStruct |
||||
Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" validate:"required"` |
||||
Slice []string `form:"slice" validate:"max=10"` |
||||
} |
||||
|
||||
type ComplexDefaultStruct struct { |
||||
Int int `form:"int" default:"999"` |
||||
String string `form:"string" default:"default-string"` |
||||
Bool bool `form:"bool" default:"false"` |
||||
Int64Slice []int64 `form:"int64_slice,split" default:"1,2,3,4"` |
||||
Int8Slice []int8 `form:"int8_slice,split" default:"1,2,3,4"` |
||||
} |
||||
|
||||
type Int8SliceStruct struct { |
||||
State []int8 `form:"state,split"` |
||||
} |
||||
|
||||
type Int64SliceStruct struct { |
||||
State []int64 `form:"state,split"` |
||||
} |
||||
|
||||
type StringSliceStruct struct { |
||||
State []string `form:"state,split"` |
||||
} |
||||
|
||||
func TestBindingDefault(t *testing.T) { |
||||
assert.Equal(t, Default("GET", ""), Form) |
||||
assert.Equal(t, Default("GET", MIMEJSON), Form) |
||||
assert.Equal(t, Default("GET", MIMEJSON+"; charset=utf-8"), Form) |
||||
|
||||
assert.Equal(t, Default("POST", MIMEJSON), JSON) |
||||
assert.Equal(t, Default("PUT", MIMEJSON), JSON) |
||||
|
||||
assert.Equal(t, Default("POST", MIMEJSON+"; charset=utf-8"), JSON) |
||||
assert.Equal(t, Default("PUT", MIMEJSON+"; charset=utf-8"), JSON) |
||||
|
||||
assert.Equal(t, Default("POST", MIMEXML), XML) |
||||
assert.Equal(t, Default("PUT", MIMEXML2), XML) |
||||
|
||||
assert.Equal(t, Default("POST", MIMEPOSTForm), Form) |
||||
assert.Equal(t, Default("PUT", MIMEPOSTForm), Form) |
||||
|
||||
assert.Equal(t, Default("POST", MIMEPOSTForm+"; charset=utf-8"), Form) |
||||
assert.Equal(t, Default("PUT", MIMEPOSTForm+"; charset=utf-8"), Form) |
||||
|
||||
assert.Equal(t, Default("POST", MIMEMultipartPOSTForm), Form) |
||||
assert.Equal(t, Default("PUT", MIMEMultipartPOSTForm), Form) |
||||
|
||||
} |
||||
|
||||
func TestStripContentType(t *testing.T) { |
||||
c1 := "application/vnd.mozilla.xul+xml" |
||||
c2 := "application/vnd.mozilla.xul+xml; charset=utf-8" |
||||
assert.Equal(t, stripContentTypeParam(c1), c1) |
||||
assert.Equal(t, stripContentTypeParam(c2), "application/vnd.mozilla.xul+xml") |
||||
} |
||||
|
||||
func TestBindInt8Form(t *testing.T) { |
||||
params := "state=1,2,3" |
||||
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q := new(Int8SliceStruct) |
||||
Form.Bind(req, q) |
||||
assert.EqualValues(t, []int8{1, 2, 3}, q.State) |
||||
|
||||
params = "state=1,2,3,256" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(Int8SliceStruct) |
||||
assert.Error(t, Form.Bind(req, q)) |
||||
|
||||
params = "state=" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(Int8SliceStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Len(t, q.State, 0) |
||||
|
||||
params = "state=1,,2" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(Int8SliceStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.EqualValues(t, []int8{1, 2}, q.State) |
||||
} |
||||
|
||||
func TestBindInt64Form(t *testing.T) { |
||||
params := "state=1,2,3" |
||||
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q := new(Int64SliceStruct) |
||||
Form.Bind(req, q) |
||||
assert.EqualValues(t, []int64{1, 2, 3}, q.State) |
||||
|
||||
params = "state=" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(Int64SliceStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Len(t, q.State, 0) |
||||
} |
||||
|
||||
func TestBindStringForm(t *testing.T) { |
||||
params := "state=1,2,3" |
||||
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q := new(StringSliceStruct) |
||||
Form.Bind(req, q) |
||||
assert.EqualValues(t, []string{"1", "2", "3"}, q.State) |
||||
|
||||
params = "state=" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(StringSliceStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Len(t, q.State, 0) |
||||
|
||||
params = "state=p,,p" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(StringSliceStruct) |
||||
Form.Bind(req, q) |
||||
assert.EqualValues(t, []string{"p", "p"}, q.State) |
||||
} |
||||
|
||||
func TestBindingJSON(t *testing.T) { |
||||
testBodyBinding(t, |
||||
JSON, "json", |
||||
"/", "/", |
||||
`{"foo": "bar"}`, `{"bar": "foo"}`) |
||||
} |
||||
|
||||
func TestBindingForm(t *testing.T) { |
||||
testFormBinding(t, "POST", |
||||
"/", "/", |
||||
"foo=bar&bar=foo&slice=a&slice=b", "bar2=foo") |
||||
} |
||||
|
||||
func TestBindingForm2(t *testing.T) { |
||||
testFormBinding(t, "GET", |
||||
"/?foo=bar&bar=foo", "/?bar2=foo", |
||||
"", "") |
||||
} |
||||
|
||||
func TestBindingQuery(t *testing.T) { |
||||
testQueryBinding(t, "POST", |
||||
"/?foo=bar&bar=foo", "/", |
||||
"foo=unused", "bar2=foo") |
||||
} |
||||
|
||||
func TestBindingQuery2(t *testing.T) { |
||||
testQueryBinding(t, "GET", |
||||
"/?foo=bar&bar=foo", "/?bar2=foo", |
||||
"foo=unused", "") |
||||
} |
||||
|
||||
func TestBindingXML(t *testing.T) { |
||||
testBodyBinding(t, |
||||
XML, "xml", |
||||
"/", "/", |
||||
"<map><foo>bar</foo></map>", "<map><bar>foo</bar></map>") |
||||
} |
||||
|
||||
func createFormPostRequest() *http.Request { |
||||
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo")) |
||||
req.Header.Set("Content-Type", MIMEPOSTForm) |
||||
return req |
||||
} |
||||
|
||||
func createFormMultipartRequest() *http.Request { |
||||
boundary := "--testboundary" |
||||
body := new(bytes.Buffer) |
||||
mw := multipart.NewWriter(body) |
||||
defer mw.Close() |
||||
|
||||
mw.SetBoundary(boundary) |
||||
mw.WriteField("foo", "bar") |
||||
mw.WriteField("bar", "foo") |
||||
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) |
||||
req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) |
||||
return req |
||||
} |
||||
|
||||
func TestBindingFormPost(t *testing.T) { |
||||
req := createFormPostRequest() |
||||
var obj FooBarStruct |
||||
FormPost.Bind(req, &obj) |
||||
|
||||
assert.Equal(t, obj.Foo, "bar") |
||||
assert.Equal(t, obj.Bar, "foo") |
||||
} |
||||
|
||||
func TestBindingFormMultipart(t *testing.T) { |
||||
req := createFormMultipartRequest() |
||||
var obj FooBarStruct |
||||
FormMultipart.Bind(req, &obj) |
||||
|
||||
assert.Equal(t, obj.Foo, "bar") |
||||
assert.Equal(t, obj.Bar, "foo") |
||||
} |
||||
|
||||
func TestValidationFails(t *testing.T) { |
||||
var obj FooStruct |
||||
req := requestWithBody("POST", "/", `{"bar": "foo"}`) |
||||
err := JSON.Bind(req, &obj) |
||||
assert.Error(t, err) |
||||
} |
||||
|
||||
func TestValidationDisabled(t *testing.T) { |
||||
backup := Validator |
||||
Validator = nil |
||||
defer func() { Validator = backup }() |
||||
|
||||
var obj FooStruct |
||||
req := requestWithBody("POST", "/", `{"bar": "foo"}`) |
||||
err := JSON.Bind(req, &obj) |
||||
assert.NoError(t, err) |
||||
} |
||||
|
||||
func TestExistsSucceeds(t *testing.T) { |
||||
type HogeStruct struct { |
||||
Hoge *int `json:"hoge" binding:"exists"` |
||||
} |
||||
|
||||
var obj HogeStruct |
||||
req := requestWithBody("POST", "/", `{"hoge": 0}`) |
||||
err := JSON.Bind(req, &obj) |
||||
assert.NoError(t, err) |
||||
} |
||||
|
||||
func TestFormDefaultValue(t *testing.T) { |
||||
params := "int=333&string=hello&bool=true&int64_slice=5,6,7,8&int8_slice=5,6,7,8" |
||||
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q := new(ComplexDefaultStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Equal(t, 333, q.Int) |
||||
assert.Equal(t, "hello", q.String) |
||||
assert.Equal(t, true, q.Bool) |
||||
assert.EqualValues(t, []int64{5, 6, 7, 8}, q.Int64Slice) |
||||
assert.EqualValues(t, []int8{5, 6, 7, 8}, q.Int8Slice) |
||||
|
||||
params = "string=hello&bool=false" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(ComplexDefaultStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Equal(t, 999, q.Int) |
||||
assert.Equal(t, "hello", q.String) |
||||
assert.Equal(t, false, q.Bool) |
||||
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) |
||||
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) |
||||
|
||||
params = "strings=hello" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(ComplexDefaultStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Equal(t, 999, q.Int) |
||||
assert.Equal(t, "default-string", q.String) |
||||
assert.Equal(t, false, q.Bool) |
||||
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) |
||||
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) |
||||
|
||||
params = "int=&string=&bool=true&int64_slice=&int8_slice=" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(ComplexDefaultStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Equal(t, 999, q.Int) |
||||
assert.Equal(t, "default-string", q.String) |
||||
assert.Equal(t, true, q.Bool) |
||||
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) |
||||
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) |
||||
} |
||||
|
||||
func testFormBinding(t *testing.T, method, path, badPath, body, badBody string) { |
||||
b := Form |
||||
assert.Equal(t, b.Name(), "form") |
||||
|
||||
obj := FooBarStruct{} |
||||
req := requestWithBody(method, path, body) |
||||
if method == "POST" { |
||||
req.Header.Add("Content-Type", MIMEPOSTForm) |
||||
} |
||||
err := b.Bind(req, &obj) |
||||
assert.NoError(t, err) |
||||
assert.Equal(t, obj.Foo, "bar") |
||||
assert.Equal(t, obj.Bar, "foo") |
||||
|
||||
obj = FooBarStruct{} |
||||
req = requestWithBody(method, badPath, badBody) |
||||
err = JSON.Bind(req, &obj) |
||||
assert.Error(t, err) |
||||
} |
||||
|
||||
func testQueryBinding(t *testing.T, method, path, badPath, body, badBody string) { |
||||
b := Query |
||||
assert.Equal(t, b.Name(), "query") |
||||
|
||||
obj := FooBarStruct{} |
||||
req := requestWithBody(method, path, body) |
||||
if method == "POST" { |
||||
req.Header.Add("Content-Type", MIMEPOSTForm) |
||||
} |
||||
err := b.Bind(req, &obj) |
||||
assert.NoError(t, err) |
||||
assert.Equal(t, obj.Foo, "bar") |
||||
assert.Equal(t, obj.Bar, "foo") |
||||
} |
||||
|
||||
func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { |
||||
assert.Equal(t, b.Name(), name) |
||||
|
||||
obj := FooStruct{} |
||||
req := requestWithBody("POST", path, body) |
||||
err := b.Bind(req, &obj) |
||||
assert.NoError(t, err) |
||||
assert.Equal(t, obj.Foo, "bar") |
||||
|
||||
obj = FooStruct{} |
||||
req = requestWithBody("POST", badPath, badBody) |
||||
err = JSON.Bind(req, &obj) |
||||
assert.Error(t, err) |
||||
} |
||||
|
||||
func requestWithBody(method, path, body string) (req *http.Request) { |
||||
req, _ = http.NewRequest(method, path, bytes.NewBufferString(body)) |
||||
return |
||||
} |
||||
func BenchmarkBindingForm(b *testing.B) { |
||||
req := requestWithBody("POST", "/", "foo=bar&bar=foo&slice=a&slice=b&slice=c&slice=w") |
||||
req.Header.Add("Content-Type", MIMEPOSTForm) |
||||
f := Form |
||||
for i := 0; i < b.N; i++ { |
||||
obj := FooBarStruct{} |
||||
f.Bind(req, &obj) |
||||
} |
||||
} |
@ -0,0 +1,45 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"reflect" |
||||
"sync" |
||||
|
||||
"gopkg.in/go-playground/validator.v9" |
||||
) |
||||
|
||||
type defaultValidator struct { |
||||
once sync.Once |
||||
validate *validator.Validate |
||||
} |
||||
|
||||
var _ StructValidator = &defaultValidator{} |
||||
|
||||
func (v *defaultValidator) ValidateStruct(obj interface{}) error { |
||||
if kindOfData(obj) == reflect.Struct { |
||||
v.lazyinit() |
||||
if err := v.validate.Struct(obj); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (v *defaultValidator) RegisterValidation(key string, fn validator.Func) error { |
||||
v.lazyinit() |
||||
return v.validate.RegisterValidation(key, fn) |
||||
} |
||||
|
||||
func (v *defaultValidator) lazyinit() { |
||||
v.once.Do(func() { |
||||
v.validate = validator.New() |
||||
}) |
||||
} |
||||
|
||||
func kindOfData(data interface{}) reflect.Kind { |
||||
value := reflect.ValueOf(data) |
||||
valueType := value.Kind() |
||||
if valueType == reflect.Ptr { |
||||
valueType = value.Elem().Kind() |
||||
} |
||||
return valueType |
||||
} |
@ -0,0 +1,113 @@ |
||||
// Code generated by protoc-gen-go.
|
||||
// source: test.proto
|
||||
// DO NOT EDIT!
|
||||
|
||||
/* |
||||
Package example is a generated protocol buffer package. |
||||
|
||||
It is generated from these files: |
||||
test.proto |
||||
|
||||
It has these top-level messages: |
||||
Test |
||||
*/ |
||||
package example |
||||
|
||||
import proto "github.com/golang/protobuf/proto" |
||||
import math "math" |
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal |
||||
var _ = math.Inf |
||||
|
||||
type FOO int32 |
||||
|
||||
const ( |
||||
FOO_X FOO = 17 |
||||
) |
||||
|
||||
var FOO_name = map[int32]string{ |
||||
17: "X", |
||||
} |
||||
var FOO_value = map[string]int32{ |
||||
"X": 17, |
||||
} |
||||
|
||||
func (x FOO) Enum() *FOO { |
||||
p := new(FOO) |
||||
*p = x |
||||
return p |
||||
} |
||||
func (x FOO) String() string { |
||||
return proto.EnumName(FOO_name, int32(x)) |
||||
} |
||||
func (x *FOO) UnmarshalJSON(data []byte) error { |
||||
value, err := proto.UnmarshalJSONEnum(FOO_value, data, "FOO") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
*x = FOO(value) |
||||
return nil |
||||
} |
||||
|
||||
type Test struct { |
||||
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"` |
||||
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"` |
||||
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"` |
||||
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"` |
||||
XXX_unrecognized []byte `json:"-"` |
||||
} |
||||
|
||||
func (m *Test) Reset() { *m = Test{} } |
||||
func (m *Test) String() string { return proto.CompactTextString(m) } |
||||
func (*Test) ProtoMessage() {} |
||||
|
||||
const Default_Test_Type int32 = 77 |
||||
|
||||
func (m *Test) GetLabel() string { |
||||
if m != nil && m.Label != nil { |
||||
return *m.Label |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
func (m *Test) GetType() int32 { |
||||
if m != nil && m.Type != nil { |
||||
return *m.Type |
||||
} |
||||
return Default_Test_Type |
||||
} |
||||
|
||||
func (m *Test) GetReps() []int64 { |
||||
if m != nil { |
||||
return m.Reps |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (m *Test) GetOptionalgroup() *Test_OptionalGroup { |
||||
if m != nil { |
||||
return m.Optionalgroup |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
type Test_OptionalGroup struct { |
||||
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"` |
||||
XXX_unrecognized []byte `json:"-"` |
||||
} |
||||
|
||||
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} } |
||||
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) } |
||||
func (*Test_OptionalGroup) ProtoMessage() {} |
||||
|
||||
func (m *Test_OptionalGroup) GetRequiredField() string { |
||||
if m != nil && m.RequiredField != nil { |
||||
return *m.RequiredField |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
func init() { |
||||
proto.RegisterEnum("example.FOO", FOO_name, FOO_value) |
||||
} |
@ -0,0 +1,12 @@ |
||||
package example; |
||||
|
||||
enum FOO {X=17;}; |
||||
|
||||
message Test { |
||||
required string label = 1; |
||||
optional int32 type = 2[default=77]; |
||||
repeated int64 reps = 3; |
||||
optional group OptionalGroup = 4{ |
||||
required string RequiredField = 5; |
||||
} |
||||
} |
@ -0,0 +1,36 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"fmt" |
||||
"log" |
||||
"net/http" |
||||
) |
||||
|
||||
type Arg struct { |
||||
Max int64 `form:"max" validate:"max=10"` |
||||
Min int64 `form:"min" validate:"min=2"` |
||||
Range int64 `form:"range" validate:"min=1,max=10"` |
||||
// use split option to split arg 1,2,3 into slice [1 2 3]
|
||||
// otherwise slice type with parse url.Values (eg:a=b&a=c) default.
|
||||
Slice []int64 `form:"slice,split" validate:"min=1"` |
||||
} |
||||
|
||||
func ExampleBinding() { |
||||
req := initHTTP("max=9&min=3&range=3&slice=1,2,3") |
||||
arg := new(Arg) |
||||
if err := Form.Bind(req, arg); err != nil { |
||||
log.Fatal(err) |
||||
} |
||||
fmt.Printf("arg.Max %d\narg.Min %d\narg.Range %d\narg.Slice %v", arg.Max, arg.Min, arg.Range, arg.Slice) |
||||
// Output:
|
||||
// arg.Max 9
|
||||
// arg.Min 3
|
||||
// arg.Range 3
|
||||
// arg.Slice [1 2 3]
|
||||
} |
||||
|
||||
func initHTTP(params string) (req *http.Request) { |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
req.ParseForm() |
||||
return |
||||
} |
@ -0,0 +1,55 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
const defaultMemory = 32 * 1024 * 1024 |
||||
|
||||
type formBinding struct{} |
||||
type formPostBinding struct{} |
||||
type formMultipartBinding struct{} |
||||
|
||||
func (f formBinding) Name() string { |
||||
return "form" |
||||
} |
||||
|
||||
func (f formBinding) Bind(req *http.Request, obj interface{}) error { |
||||
if err := req.ParseForm(); err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
if err := mapForm(obj, req.Form); err != nil { |
||||
return err |
||||
} |
||||
return validate(obj) |
||||
} |
||||
|
||||
func (f formPostBinding) Name() string { |
||||
return "form-urlencoded" |
||||
} |
||||
|
||||
func (f formPostBinding) Bind(req *http.Request, obj interface{}) error { |
||||
if err := req.ParseForm(); err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
if err := mapForm(obj, req.PostForm); err != nil { |
||||
return err |
||||
} |
||||
return validate(obj) |
||||
} |
||||
|
||||
func (f formMultipartBinding) Name() string { |
||||
return "multipart/form-data" |
||||
} |
||||
|
||||
func (f formMultipartBinding) Bind(req *http.Request, obj interface{}) error { |
||||
if err := req.ParseMultipartForm(defaultMemory); err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
if err := mapForm(obj, req.MultipartForm.Value); err != nil { |
||||
return err |
||||
} |
||||
return validate(obj) |
||||
} |
@ -0,0 +1,276 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"reflect" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// scache struct reflect type cache.
|
||||
var scache = &cache{ |
||||
data: make(map[reflect.Type]*sinfo), |
||||
} |
||||
|
||||
type cache struct { |
||||
data map[reflect.Type]*sinfo |
||||
mutex sync.RWMutex |
||||
} |
||||
|
||||
func (c *cache) get(obj reflect.Type) (s *sinfo) { |
||||
var ok bool |
||||
c.mutex.RLock() |
||||
if s, ok = c.data[obj]; !ok { |
||||
c.mutex.RUnlock() |
||||
s = c.set(obj) |
||||
return |
||||
} |
||||
c.mutex.RUnlock() |
||||
return |
||||
} |
||||
|
||||
func (c *cache) set(obj reflect.Type) (s *sinfo) { |
||||
s = new(sinfo) |
||||
tp := obj.Elem() |
||||
for i := 0; i < tp.NumField(); i++ { |
||||
fd := new(field) |
||||
fd.tp = tp.Field(i) |
||||
tag := fd.tp.Tag.Get("form") |
||||
fd.name, fd.option = parseTag(tag) |
||||
if defV := fd.tp.Tag.Get("default"); defV != "" { |
||||
dv := reflect.New(fd.tp.Type).Elem() |
||||
setWithProperType(fd.tp.Type.Kind(), []string{defV}, dv, fd.option) |
||||
fd.hasDefault = true |
||||
fd.defaultValue = dv |
||||
} |
||||
s.field = append(s.field, fd) |
||||
} |
||||
c.mutex.Lock() |
||||
c.data[obj] = s |
||||
c.mutex.Unlock() |
||||
return |
||||
} |
||||
|
||||
type sinfo struct { |
||||
field []*field |
||||
} |
||||
|
||||
type field struct { |
||||
tp reflect.StructField |
||||
name string |
||||
option tagOptions |
||||
|
||||
hasDefault bool // if field had default value
|
||||
defaultValue reflect.Value // field default value
|
||||
} |
||||
|
||||
func mapForm(ptr interface{}, form map[string][]string) error { |
||||
sinfo := scache.get(reflect.TypeOf(ptr)) |
||||
val := reflect.ValueOf(ptr).Elem() |
||||
for i, fd := range sinfo.field { |
||||
typeField := fd.tp |
||||
structField := val.Field(i) |
||||
if !structField.CanSet() { |
||||
continue |
||||
} |
||||
|
||||
structFieldKind := structField.Kind() |
||||
inputFieldName := fd.name |
||||
if inputFieldName == "" { |
||||
inputFieldName = typeField.Name |
||||
|
||||
// if "form" tag is nil, we inspect if the field is a struct.
|
||||
// this would not make sense for JSON parsing but it does for a form
|
||||
// since data is flatten
|
||||
if structFieldKind == reflect.Struct { |
||||
err := mapForm(structField.Addr().Interface(), form) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
continue |
||||
} |
||||
} |
||||
inputValue, exists := form[inputFieldName] |
||||
if !exists { |
||||
// Set the field as default value when the input value is not exist
|
||||
if fd.hasDefault { |
||||
structField.Set(fd.defaultValue) |
||||
} |
||||
continue |
||||
} |
||||
// Set the field as default value when the input value is empty
|
||||
if fd.hasDefault && inputValue[0] == "" { |
||||
structField.Set(fd.defaultValue) |
||||
continue |
||||
} |
||||
if _, isTime := structField.Interface().(time.Time); isTime { |
||||
if err := setTimeField(inputValue[0], typeField, structField); err != nil { |
||||
return err |
||||
} |
||||
continue |
||||
} |
||||
if err := setWithProperType(typeField.Type.Kind(), inputValue, structField, fd.option); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func setWithProperType(valueKind reflect.Kind, val []string, structField reflect.Value, option tagOptions) error { |
||||
switch valueKind { |
||||
case reflect.Int: |
||||
return setIntField(val[0], 0, structField) |
||||
case reflect.Int8: |
||||
return setIntField(val[0], 8, structField) |
||||
case reflect.Int16: |
||||
return setIntField(val[0], 16, structField) |
||||
case reflect.Int32: |
||||
return setIntField(val[0], 32, structField) |
||||
case reflect.Int64: |
||||
return setIntField(val[0], 64, structField) |
||||
case reflect.Uint: |
||||
return setUintField(val[0], 0, structField) |
||||
case reflect.Uint8: |
||||
return setUintField(val[0], 8, structField) |
||||
case reflect.Uint16: |
||||
return setUintField(val[0], 16, structField) |
||||
case reflect.Uint32: |
||||
return setUintField(val[0], 32, structField) |
||||
case reflect.Uint64: |
||||
return setUintField(val[0], 64, structField) |
||||
case reflect.Bool: |
||||
return setBoolField(val[0], structField) |
||||
case reflect.Float32: |
||||
return setFloatField(val[0], 32, structField) |
||||
case reflect.Float64: |
||||
return setFloatField(val[0], 64, structField) |
||||
case reflect.String: |
||||
structField.SetString(val[0]) |
||||
case reflect.Slice: |
||||
if option.Contains("split") { |
||||
val = strings.Split(val[0], ",") |
||||
} |
||||
filtered := filterEmpty(val) |
||||
switch structField.Type().Elem().Kind() { |
||||
case reflect.Int64: |
||||
valSli := make([]int64, 0, len(filtered)) |
||||
for i := 0; i < len(filtered); i++ { |
||||
d, err := strconv.ParseInt(filtered[i], 10, 64) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
valSli = append(valSli, d) |
||||
} |
||||
structField.Set(reflect.ValueOf(valSli)) |
||||
case reflect.String: |
||||
valSli := make([]string, 0, len(filtered)) |
||||
for i := 0; i < len(filtered); i++ { |
||||
valSli = append(valSli, filtered[i]) |
||||
} |
||||
structField.Set(reflect.ValueOf(valSli)) |
||||
default: |
||||
sliceOf := structField.Type().Elem().Kind() |
||||
numElems := len(filtered) |
||||
slice := reflect.MakeSlice(structField.Type(), len(filtered), len(filtered)) |
||||
for i := 0; i < numElems; i++ { |
||||
if err := setWithProperType(sliceOf, filtered[i:], slice.Index(i), ""); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
structField.Set(slice) |
||||
} |
||||
default: |
||||
return errors.New("Unknown type") |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func setIntField(val string, bitSize int, field reflect.Value) error { |
||||
if val == "" { |
||||
val = "0" |
||||
} |
||||
intVal, err := strconv.ParseInt(val, 10, bitSize) |
||||
if err == nil { |
||||
field.SetInt(intVal) |
||||
} |
||||
return errors.WithStack(err) |
||||
} |
||||
|
||||
func setUintField(val string, bitSize int, field reflect.Value) error { |
||||
if val == "" { |
||||
val = "0" |
||||
} |
||||
uintVal, err := strconv.ParseUint(val, 10, bitSize) |
||||
if err == nil { |
||||
field.SetUint(uintVal) |
||||
} |
||||
return errors.WithStack(err) |
||||
} |
||||
|
||||
func setBoolField(val string, field reflect.Value) error { |
||||
if val == "" { |
||||
val = "false" |
||||
} |
||||
boolVal, err := strconv.ParseBool(val) |
||||
if err == nil { |
||||
field.SetBool(boolVal) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func setFloatField(val string, bitSize int, field reflect.Value) error { |
||||
if val == "" { |
||||
val = "0.0" |
||||
} |
||||
floatVal, err := strconv.ParseFloat(val, bitSize) |
||||
if err == nil { |
||||
field.SetFloat(floatVal) |
||||
} |
||||
return errors.WithStack(err) |
||||
} |
||||
|
||||
func setTimeField(val string, structField reflect.StructField, value reflect.Value) error { |
||||
timeFormat := structField.Tag.Get("time_format") |
||||
if timeFormat == "" { |
||||
return errors.New("Blank time format") |
||||
} |
||||
|
||||
if val == "" { |
||||
value.Set(reflect.ValueOf(time.Time{})) |
||||
return nil |
||||
} |
||||
|
||||
l := time.Local |
||||
if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC { |
||||
l = time.UTC |
||||
} |
||||
|
||||
if locTag := structField.Tag.Get("time_location"); locTag != "" { |
||||
loc, err := time.LoadLocation(locTag) |
||||
if err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
l = loc |
||||
} |
||||
|
||||
t, err := time.ParseInLocation(timeFormat, val, l) |
||||
if err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
|
||||
value.Set(reflect.ValueOf(t)) |
||||
return nil |
||||
} |
||||
|
||||
func filterEmpty(val []string) []string { |
||||
filtered := make([]string, 0, len(val)) |
||||
for _, v := range val { |
||||
if v != "" { |
||||
filtered = append(filtered, v) |
||||
} |
||||
} |
||||
return filtered |
||||
} |
@ -0,0 +1,22 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
type jsonBinding struct{} |
||||
|
||||
func (jsonBinding) Name() string { |
||||
return "json" |
||||
} |
||||
|
||||
func (jsonBinding) Bind(req *http.Request, obj interface{}) error { |
||||
decoder := json.NewDecoder(req.Body) |
||||
if err := decoder.Decode(obj); err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
return validate(obj) |
||||
} |
@ -0,0 +1,19 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"net/http" |
||||
) |
||||
|
||||
type queryBinding struct{} |
||||
|
||||
func (queryBinding) Name() string { |
||||
return "query" |
||||
} |
||||
|
||||
func (queryBinding) Bind(req *http.Request, obj interface{}) error { |
||||
values := req.URL.Query() |
||||
if err := mapForm(obj, values); err != nil { |
||||
return err |
||||
} |
||||
return validate(obj) |
||||
} |
@ -0,0 +1,44 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package binding |
||||
|
||||
import ( |
||||
"strings" |
||||
) |
||||
|
||||
// tagOptions is the string following a comma in a struct field's "json"
|
||||
// tag, or the empty string. It does not include the leading comma.
|
||||
type tagOptions string |
||||
|
||||
// parseTag splits a struct field's json tag into its name and
|
||||
// comma-separated options.
|
||||
func parseTag(tag string) (string, tagOptions) { |
||||
if idx := strings.Index(tag, ","); idx != -1 { |
||||
return tag[:idx], tagOptions(tag[idx+1:]) |
||||
} |
||||
return tag, tagOptions("") |
||||
} |
||||
|
||||
// Contains reports whether a comma-separated list of options
|
||||
// contains a particular substr flag. substr must be surrounded by a
|
||||
// string boundary or commas.
|
||||
func (o tagOptions) Contains(optionName string) bool { |
||||
if len(o) == 0 { |
||||
return false |
||||
} |
||||
s := string(o) |
||||
for s != "" { |
||||
var next string |
||||
i := strings.Index(s, ",") |
||||
if i >= 0 { |
||||
s, next = s[:i], s[i+1:] |
||||
} |
||||
if s == optionName { |
||||
return true |
||||
} |
||||
s = next |
||||
} |
||||
return false |
||||
} |
@ -0,0 +1,209 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"bytes" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
type testInterface interface { |
||||
String() string |
||||
} |
||||
|
||||
type substructNoValidation struct { |
||||
IString string |
||||
IInt int |
||||
} |
||||
|
||||
type mapNoValidationSub map[string]substructNoValidation |
||||
|
||||
type structNoValidationValues struct { |
||||
substructNoValidation |
||||
|
||||
Boolean bool |
||||
|
||||
Uinteger uint |
||||
Integer int |
||||
Integer8 int8 |
||||
Integer16 int16 |
||||
Integer32 int32 |
||||
Integer64 int64 |
||||
Uinteger8 uint8 |
||||
Uinteger16 uint16 |
||||
Uinteger32 uint32 |
||||
Uinteger64 uint64 |
||||
|
||||
Float32 float32 |
||||
Float64 float64 |
||||
|
||||
String string |
||||
|
||||
Date time.Time |
||||
|
||||
Struct substructNoValidation |
||||
InlinedStruct struct { |
||||
String []string |
||||
Integer int |
||||
} |
||||
|
||||
IntSlice []int |
||||
IntPointerSlice []*int |
||||
StructPointerSlice []*substructNoValidation |
||||
StructSlice []substructNoValidation |
||||
InterfaceSlice []testInterface |
||||
|
||||
UniversalInterface interface{} |
||||
CustomInterface testInterface |
||||
|
||||
FloatMap map[string]float32 |
||||
StructMap mapNoValidationSub |
||||
} |
||||
|
||||
func createNoValidationValues() structNoValidationValues { |
||||
integer := 1 |
||||
s := structNoValidationValues{ |
||||
Boolean: true, |
||||
Uinteger: 1 << 29, |
||||
Integer: -10000, |
||||
Integer8: 120, |
||||
Integer16: -20000, |
||||
Integer32: 1 << 29, |
||||
Integer64: 1 << 61, |
||||
Uinteger8: 250, |
||||
Uinteger16: 50000, |
||||
Uinteger32: 1 << 31, |
||||
Uinteger64: 1 << 62, |
||||
Float32: 123.456, |
||||
Float64: 123.456789, |
||||
String: "text", |
||||
Date: time.Time{}, |
||||
CustomInterface: &bytes.Buffer{}, |
||||
Struct: substructNoValidation{}, |
||||
IntSlice: []int{-3, -2, 1, 0, 1, 2, 3}, |
||||
IntPointerSlice: []*int{&integer}, |
||||
StructSlice: []substructNoValidation{}, |
||||
UniversalInterface: 1.2, |
||||
FloatMap: map[string]float32{ |
||||
"foo": 1.23, |
||||
"bar": 232.323, |
||||
}, |
||||
StructMap: mapNoValidationSub{ |
||||
"foo": substructNoValidation{}, |
||||
"bar": substructNoValidation{}, |
||||
}, |
||||
// StructPointerSlice []noValidationSub
|
||||
// InterfaceSlice []testInterface
|
||||
} |
||||
s.InlinedStruct.Integer = 1000 |
||||
s.InlinedStruct.String = []string{"first", "second"} |
||||
s.IString = "substring" |
||||
s.IInt = 987654 |
||||
return s |
||||
} |
||||
|
||||
func TestValidateNoValidationValues(t *testing.T) { |
||||
origin := createNoValidationValues() |
||||
test := createNoValidationValues() |
||||
empty := structNoValidationValues{} |
||||
|
||||
assert.Nil(t, validate(test)) |
||||
assert.Nil(t, validate(&test)) |
||||
assert.Nil(t, validate(empty)) |
||||
assert.Nil(t, validate(&empty)) |
||||
|
||||
assert.Equal(t, origin, test) |
||||
} |
||||
|
||||
type structNoValidationPointer struct { |
||||
// substructNoValidation
|
||||
|
||||
Boolean bool |
||||
|
||||
Uinteger *uint |
||||
Integer *int |
||||
Integer8 *int8 |
||||
Integer16 *int16 |
||||
Integer32 *int32 |
||||
Integer64 *int64 |
||||
Uinteger8 *uint8 |
||||
Uinteger16 *uint16 |
||||
Uinteger32 *uint32 |
||||
Uinteger64 *uint64 |
||||
|
||||
Float32 *float32 |
||||
Float64 *float64 |
||||
|
||||
String *string |
||||
|
||||
Date *time.Time |
||||
|
||||
Struct *substructNoValidation |
||||
|
||||
IntSlice *[]int |
||||
IntPointerSlice *[]*int |
||||
StructPointerSlice *[]*substructNoValidation |
||||
StructSlice *[]substructNoValidation |
||||
InterfaceSlice *[]testInterface |
||||
|
||||
FloatMap *map[string]float32 |
||||
StructMap *mapNoValidationSub |
||||
} |
||||
|
||||
func TestValidateNoValidationPointers(t *testing.T) { |
||||
//origin := createNoValidation_values()
|
||||
//test := createNoValidation_values()
|
||||
empty := structNoValidationPointer{} |
||||
|
||||
//assert.Nil(t, validate(test))
|
||||
//assert.Nil(t, validate(&test))
|
||||
assert.Nil(t, validate(empty)) |
||||
assert.Nil(t, validate(&empty)) |
||||
|
||||
//assert.Equal(t, origin, test)
|
||||
} |
||||
|
||||
type Object map[string]interface{} |
||||
|
||||
func TestValidatePrimitives(t *testing.T) { |
||||
obj := Object{"foo": "bar", "bar": 1} |
||||
assert.NoError(t, validate(obj)) |
||||
assert.NoError(t, validate(&obj)) |
||||
assert.Equal(t, obj, Object{"foo": "bar", "bar": 1}) |
||||
|
||||
obj2 := []Object{{"foo": "bar", "bar": 1}, {"foo": "bar", "bar": 1}} |
||||
assert.NoError(t, validate(obj2)) |
||||
assert.NoError(t, validate(&obj2)) |
||||
|
||||
nu := 10 |
||||
assert.NoError(t, validate(nu)) |
||||
assert.NoError(t, validate(&nu)) |
||||
assert.Equal(t, nu, 10) |
||||
|
||||
str := "value" |
||||
assert.NoError(t, validate(str)) |
||||
assert.NoError(t, validate(&str)) |
||||
assert.Equal(t, str, "value") |
||||
} |
||||
|
||||
// structCustomValidation is a helper struct we use to check that
|
||||
// custom validation can be registered on it.
|
||||
// The `notone` binding directive is for custom validation and registered later.
|
||||
// type structCustomValidation struct {
|
||||
// Integer int `binding:"notone"`
|
||||
// }
|
||||
|
||||
// notOne is a custom validator meant to be used with `validator.v8` library.
|
||||
// The method signature for `v9` is significantly different and this function
|
||||
// would need to be changed for tests to pass after upgrade.
|
||||
// See https://github.com/gin-gonic/gin/pull/1015.
|
||||
// func notOne(
|
||||
// v *validator.Validate, topStruct reflect.Value, currentStructOrField reflect.Value,
|
||||
// field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string,
|
||||
// ) bool {
|
||||
// if val, ok := field.Interface().(int); ok {
|
||||
// return val != 1
|
||||
// }
|
||||
// return false
|
||||
// }
|
@ -0,0 +1,22 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"encoding/xml" |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
type xmlBinding struct{} |
||||
|
||||
func (xmlBinding) Name() string { |
||||
return "xml" |
||||
} |
||||
|
||||
func (xmlBinding) Bind(req *http.Request, obj interface{}) error { |
||||
decoder := xml.NewDecoder(req.Body) |
||||
if err := decoder.Decode(obj); err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
return validate(obj) |
||||
} |
@ -0,0 +1,306 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"context" |
||||
"math" |
||||
"net/http" |
||||
"strconv" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode" |
||||
"github.com/bilibili/kratos/pkg/net/http/blademaster/binding" |
||||
"github.com/bilibili/kratos/pkg/net/http/blademaster/render" |
||||
|
||||
"github.com/gogo/protobuf/proto" |
||||
"github.com/gogo/protobuf/types" |
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
const ( |
||||
_abortIndex int8 = math.MaxInt8 / 2 |
||||
) |
||||
|
||||
var ( |
||||
_openParen = []byte("(") |
||||
_closeParen = []byte(")") |
||||
) |
||||
|
||||
// Context is the most important part. It allows us to pass variables between
|
||||
// middleware, manage the flow, validate the JSON of a request and render a
|
||||
// JSON response for example.
|
||||
type Context struct { |
||||
context.Context |
||||
|
||||
Request *http.Request |
||||
Writer http.ResponseWriter |
||||
|
||||
// flow control
|
||||
index int8 |
||||
handlers []HandlerFunc |
||||
|
||||
// Keys is a key/value pair exclusively for the context of each request.
|
||||
Keys map[string]interface{} |
||||
|
||||
Error error |
||||
|
||||
method string |
||||
engine *Engine |
||||
} |
||||
|
||||
/************************************/ |
||||
/*********** FLOW CONTROL ***********/ |
||||
/************************************/ |
||||
|
||||
// Next should be used only inside middleware.
|
||||
// It executes the pending handlers in the chain inside the calling handler.
|
||||
// See example in godoc.
|
||||
func (c *Context) Next() { |
||||
c.index++ |
||||
s := int8(len(c.handlers)) |
||||
for ; c.index < s; c.index++ { |
||||
// only check method on last handler, otherwise middlewares
|
||||
// will never be effected if request method is not matched
|
||||
if c.index == s-1 && c.method != c.Request.Method { |
||||
code := http.StatusMethodNotAllowed |
||||
c.Error = ecode.MethodNotAllowed |
||||
http.Error(c.Writer, http.StatusText(code), code) |
||||
return |
||||
} |
||||
|
||||
c.handlers[c.index](c) |
||||
} |
||||
} |
||||
|
||||
// Abort prevents pending handlers from being called. Note that this will not stop the current handler.
|
||||
// Let's say you have an authorization middleware that validates that the current request is authorized.
|
||||
// If the authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers
|
||||
// for this request are not called.
|
||||
func (c *Context) Abort() { |
||||
c.index = _abortIndex |
||||
} |
||||
|
||||
// AbortWithStatus calls `Abort()` and writes the headers with the specified status code.
|
||||
// For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401).
|
||||
func (c *Context) AbortWithStatus(code int) { |
||||
c.Status(code) |
||||
c.Abort() |
||||
} |
||||
|
||||
// IsAborted returns true if the current context was aborted.
|
||||
func (c *Context) IsAborted() bool { |
||||
return c.index >= _abortIndex |
||||
} |
||||
|
||||
/************************************/ |
||||
/******** METADATA MANAGEMENT********/ |
||||
/************************************/ |
||||
|
||||
// Set is used to store a new key/value pair exclusively for this context.
|
||||
// It also lazy initializes c.Keys if it was not used previously.
|
||||
func (c *Context) Set(key string, value interface{}) { |
||||
if c.Keys == nil { |
||||
c.Keys = make(map[string]interface{}) |
||||
} |
||||
c.Keys[key] = value |
||||
} |
||||
|
||||
// Get returns the value for the given key, ie: (value, true).
|
||||
// If the value does not exists it returns (nil, false)
|
||||
func (c *Context) Get(key string) (value interface{}, exists bool) { |
||||
value, exists = c.Keys[key] |
||||
return |
||||
} |
||||
|
||||
/************************************/ |
||||
/******** RESPONSE RENDERING ********/ |
||||
/************************************/ |
||||
|
||||
// bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function.
|
||||
func bodyAllowedForStatus(status int) bool { |
||||
switch { |
||||
case status >= 100 && status <= 199: |
||||
return false |
||||
case status == 204: |
||||
return false |
||||
case status == 304: |
||||
return false |
||||
} |
||||
return true |
||||
} |
||||
|
||||
// Status sets the HTTP response code.
|
||||
func (c *Context) Status(code int) { |
||||
c.Writer.WriteHeader(code) |
||||
} |
||||
|
||||
// Render http response with http code by a render instance.
|
||||
func (c *Context) Render(code int, r render.Render) { |
||||
r.WriteContentType(c.Writer) |
||||
if code > 0 { |
||||
c.Status(code) |
||||
} |
||||
|
||||
if !bodyAllowedForStatus(code) { |
||||
return |
||||
} |
||||
|
||||
params := c.Request.Form |
||||
|
||||
cb := params.Get("callback") |
||||
jsonp := cb != "" && params.Get("jsonp") == "jsonp" |
||||
if jsonp { |
||||
c.Writer.Write([]byte(cb)) |
||||
c.Writer.Write(_openParen) |
||||
} |
||||
|
||||
if err := r.Render(c.Writer); err != nil { |
||||
c.Error = err |
||||
return |
||||
} |
||||
|
||||
if jsonp { |
||||
if _, err := c.Writer.Write(_closeParen); err != nil { |
||||
c.Error = errors.WithStack(err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// JSON serializes the given struct as JSON into the response body.
|
||||
// It also sets the Content-Type as "application/json".
|
||||
func (c *Context) JSON(data interface{}, err error) { |
||||
code := http.StatusOK |
||||
c.Error = err |
||||
bcode := ecode.Cause(err) |
||||
// TODO app allow 5xx?
|
||||
/* |
||||
if bcode.Code() == -500 { |
||||
code = http.StatusServiceUnavailable |
||||
} |
||||
*/ |
||||
writeStatusCode(c.Writer, bcode.Code()) |
||||
c.Render(code, render.JSON{ |
||||
Code: bcode.Code(), |
||||
Message: bcode.Message(), |
||||
Data: data, |
||||
}) |
||||
} |
||||
|
||||
// JSONMap serializes the given map as map JSON into the response body.
|
||||
// It also sets the Content-Type as "application/json".
|
||||
func (c *Context) JSONMap(data map[string]interface{}, err error) { |
||||
code := http.StatusOK |
||||
c.Error = err |
||||
bcode := ecode.Cause(err) |
||||
// TODO app allow 5xx?
|
||||
/* |
||||
if bcode.Code() == -500 { |
||||
code = http.StatusServiceUnavailable |
||||
} |
||||
*/ |
||||
writeStatusCode(c.Writer, bcode.Code()) |
||||
data["code"] = bcode.Code() |
||||
if _, ok := data["message"]; !ok { |
||||
data["message"] = bcode.Message() |
||||
} |
||||
c.Render(code, render.MapJSON(data)) |
||||
} |
||||
|
||||
// XML serializes the given struct as XML into the response body.
|
||||
// It also sets the Content-Type as "application/xml".
|
||||
func (c *Context) XML(data interface{}, err error) { |
||||
code := http.StatusOK |
||||
c.Error = err |
||||
bcode := ecode.Cause(err) |
||||
// TODO app allow 5xx?
|
||||
/* |
||||
if bcode.Code() == -500 { |
||||
code = http.StatusServiceUnavailable |
||||
} |
||||
*/ |
||||
writeStatusCode(c.Writer, bcode.Code()) |
||||
c.Render(code, render.XML{ |
||||
Code: bcode.Code(), |
||||
Message: bcode.Message(), |
||||
Data: data, |
||||
}) |
||||
} |
||||
|
||||
// Protobuf serializes the given struct as PB into the response body.
|
||||
// It also sets the ContentType as "application/x-protobuf".
|
||||
func (c *Context) Protobuf(data proto.Message, err error) { |
||||
var ( |
||||
bytes []byte |
||||
) |
||||
|
||||
code := http.StatusOK |
||||
c.Error = err |
||||
bcode := ecode.Cause(err) |
||||
|
||||
any := new(types.Any) |
||||
if data != nil { |
||||
if bytes, err = proto.Marshal(data); err != nil { |
||||
c.Error = errors.WithStack(err) |
||||
return |
||||
} |
||||
any.TypeUrl = "type.googleapis.com/" + proto.MessageName(data) |
||||
any.Value = bytes |
||||
} |
||||
writeStatusCode(c.Writer, bcode.Code()) |
||||
c.Render(code, render.PB{ |
||||
Code: int64(bcode.Code()), |
||||
Message: bcode.Message(), |
||||
Data: any, |
||||
}) |
||||
} |
||||
|
||||
// Bytes writes some data into the body stream and updates the HTTP code.
|
||||
func (c *Context) Bytes(code int, contentType string, data ...[]byte) { |
||||
c.Render(code, render.Data{ |
||||
ContentType: contentType, |
||||
Data: data, |
||||
}) |
||||
} |
||||
|
||||
// String writes the given string into the response body.
|
||||
func (c *Context) String(code int, format string, values ...interface{}) { |
||||
c.Render(code, render.String{Format: format, Data: values}) |
||||
} |
||||
|
||||
// Redirect returns a HTTP redirect to the specific location.
|
||||
func (c *Context) Redirect(code int, location string) { |
||||
c.Render(-1, render.Redirect{ |
||||
Code: code, |
||||
Location: location, |
||||
Request: c.Request, |
||||
}) |
||||
} |
||||
|
||||
// BindWith bind req arg with parser.
|
||||
func (c *Context) BindWith(obj interface{}, b binding.Binding) error { |
||||
return c.mustBindWith(obj, b) |
||||
} |
||||
|
||||
// Bind bind req arg with defult form binding.
|
||||
func (c *Context) Bind(obj interface{}) error { |
||||
return c.mustBindWith(obj, binding.Form) |
||||
} |
||||
|
||||
// mustBindWith binds the passed struct pointer using the specified binding engine.
|
||||
// It will abort the request with HTTP 400 if any error ocurrs.
|
||||
// See the binding package.
|
||||
func (c *Context) mustBindWith(obj interface{}, b binding.Binding) (err error) { |
||||
if err = b.Bind(c.Request, obj); err != nil { |
||||
c.Error = ecode.RequestErr |
||||
c.Render(http.StatusOK, render.JSON{ |
||||
Code: ecode.RequestErr.Code(), |
||||
Message: err.Error(), |
||||
Data: nil, |
||||
}) |
||||
c.Abort() |
||||
} |
||||
return |
||||
} |
||||
|
||||
func writeStatusCode(w http.ResponseWriter, ecode int) { |
||||
header := w.Header() |
||||
header.Set("kratos-status-code", strconv.FormatInt(int64(ecode), 10)) |
||||
} |
@ -0,0 +1,249 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"net/http" |
||||
"strconv" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// CORSConfig represents all available options for the middleware.
|
||||
type CORSConfig struct { |
||||
AllowAllOrigins bool |
||||
|
||||
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
|
||||
// If the special "*" value is present in the list, all origins will be allowed.
|
||||
// Default value is []
|
||||
AllowOrigins []string |
||||
|
||||
// AllowOriginFunc is a custom function to validate the origin. It take the origin
|
||||
// as argument and returns true if allowed or false otherwise. If this option is
|
||||
// set, the content of AllowedOrigins is ignored.
|
||||
AllowOriginFunc func(origin string) bool |
||||
|
||||
// AllowedMethods is a list of methods the client is allowed to use with
|
||||
// cross-domain requests. Default value is simple methods (GET and POST)
|
||||
AllowMethods []string |
||||
|
||||
// AllowedHeaders is list of non simple headers the client is allowed to use with
|
||||
// cross-domain requests.
|
||||
AllowHeaders []string |
||||
|
||||
// AllowCredentials indicates whether the request can include user credentials like
|
||||
// cookies, HTTP authentication or client side SSL certificates.
|
||||
AllowCredentials bool |
||||
|
||||
// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
|
||||
// API specification
|
||||
ExposeHeaders []string |
||||
|
||||
// MaxAge indicates how long (in seconds) the results of a preflight request
|
||||
// can be cached
|
||||
MaxAge time.Duration |
||||
} |
||||
|
||||
type cors struct { |
||||
allowAllOrigins bool |
||||
allowCredentials bool |
||||
allowOriginFunc func(string) bool |
||||
allowOrigins []string |
||||
normalHeaders http.Header |
||||
preflightHeaders http.Header |
||||
} |
||||
|
||||
type converter func(string) string |
||||
|
||||
// Validate is check configuration of user defined.
|
||||
func (c *CORSConfig) Validate() error { |
||||
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { |
||||
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed") |
||||
} |
||||
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 { |
||||
return errors.New("conflict settings: all origins disabled") |
||||
} |
||||
for _, origin := range c.AllowOrigins { |
||||
if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") { |
||||
return errors.New("bad origin: origins must either be '*' or include http:// or https://") |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// CORS returns the location middleware with default configuration.
|
||||
func CORS(allowOriginHosts []string) HandlerFunc { |
||||
config := &CORSConfig{ |
||||
AllowMethods: []string{"GET", "POST"}, |
||||
AllowHeaders: []string{"Origin", "Content-Length", "Content-Type"}, |
||||
AllowCredentials: true, |
||||
MaxAge: time.Duration(0), |
||||
AllowOriginFunc: func(origin string) bool { |
||||
for _, host := range allowOriginHosts { |
||||
if strings.HasSuffix(strings.ToLower(origin), host) { |
||||
return true |
||||
} |
||||
} |
||||
return false |
||||
}, |
||||
} |
||||
return newCORS(config) |
||||
} |
||||
|
||||
// newCORS returns the location middleware with user-defined custom configuration.
|
||||
func newCORS(config *CORSConfig) HandlerFunc { |
||||
if err := config.Validate(); err != nil { |
||||
panic(err.Error()) |
||||
} |
||||
cors := &cors{ |
||||
allowOriginFunc: config.AllowOriginFunc, |
||||
allowAllOrigins: config.AllowAllOrigins, |
||||
allowCredentials: config.AllowCredentials, |
||||
allowOrigins: normalize(config.AllowOrigins), |
||||
normalHeaders: generateNormalHeaders(config), |
||||
preflightHeaders: generatePreflightHeaders(config), |
||||
} |
||||
|
||||
return func(c *Context) { |
||||
cors.applyCORS(c) |
||||
} |
||||
} |
||||
|
||||
func (cors *cors) applyCORS(c *Context) { |
||||
origin := c.Request.Header.Get("Origin") |
||||
if len(origin) == 0 { |
||||
// request is not a CORS request
|
||||
return |
||||
} |
||||
if !cors.validateOrigin(origin) { |
||||
log.V(5).Info("The request's Origin header `%s` does not match any of allowed origins.", origin) |
||||
c.AbortWithStatus(http.StatusForbidden) |
||||
return |
||||
} |
||||
|
||||
if c.Request.Method == "OPTIONS" { |
||||
cors.handlePreflight(c) |
||||
defer c.AbortWithStatus(200) |
||||
} else { |
||||
cors.handleNormal(c) |
||||
} |
||||
|
||||
if !cors.allowAllOrigins { |
||||
header := c.Writer.Header() |
||||
header.Set("Access-Control-Allow-Origin", origin) |
||||
} |
||||
} |
||||
|
||||
func (cors *cors) validateOrigin(origin string) bool { |
||||
if cors.allowAllOrigins { |
||||
return true |
||||
} |
||||
for _, value := range cors.allowOrigins { |
||||
if value == origin { |
||||
return true |
||||
} |
||||
} |
||||
if cors.allowOriginFunc != nil { |
||||
return cors.allowOriginFunc(origin) |
||||
} |
||||
return false |
||||
} |
||||
|
||||
func (cors *cors) handlePreflight(c *Context) { |
||||
header := c.Writer.Header() |
||||
for key, value := range cors.preflightHeaders { |
||||
header[key] = value |
||||
} |
||||
} |
||||
|
||||
func (cors *cors) handleNormal(c *Context) { |
||||
header := c.Writer.Header() |
||||
for key, value := range cors.normalHeaders { |
||||
header[key] = value |
||||
} |
||||
} |
||||
|
||||
func generateNormalHeaders(c *CORSConfig) http.Header { |
||||
headers := make(http.Header) |
||||
if c.AllowCredentials { |
||||
headers.Set("Access-Control-Allow-Credentials", "true") |
||||
} |
||||
|
||||
// backport support for early browsers
|
||||
if len(c.AllowMethods) > 0 { |
||||
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper) |
||||
value := strings.Join(allowMethods, ",") |
||||
headers.Set("Access-Control-Allow-Methods", value) |
||||
} |
||||
|
||||
if len(c.ExposeHeaders) > 0 { |
||||
exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey) |
||||
headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ",")) |
||||
} |
||||
if c.AllowAllOrigins { |
||||
headers.Set("Access-Control-Allow-Origin", "*") |
||||
} else { |
||||
headers.Set("Vary", "Origin") |
||||
} |
||||
return headers |
||||
} |
||||
|
||||
func generatePreflightHeaders(c *CORSConfig) http.Header { |
||||
headers := make(http.Header) |
||||
if c.AllowCredentials { |
||||
headers.Set("Access-Control-Allow-Credentials", "true") |
||||
} |
||||
if len(c.AllowMethods) > 0 { |
||||
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper) |
||||
value := strings.Join(allowMethods, ",") |
||||
headers.Set("Access-Control-Allow-Methods", value) |
||||
} |
||||
if len(c.AllowHeaders) > 0 { |
||||
allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey) |
||||
value := strings.Join(allowHeaders, ",") |
||||
headers.Set("Access-Control-Allow-Headers", value) |
||||
} |
||||
if c.MaxAge > time.Duration(0) { |
||||
value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10) |
||||
headers.Set("Access-Control-Max-Age", value) |
||||
} |
||||
if c.AllowAllOrigins { |
||||
headers.Set("Access-Control-Allow-Origin", "*") |
||||
} else { |
||||
// Always set Vary headers
|
||||
// see https://github.com/rs/cors/issues/10,
|
||||
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
|
||||
|
||||
headers.Add("Vary", "Origin") |
||||
headers.Add("Vary", "Access-Control-Request-Method") |
||||
headers.Add("Vary", "Access-Control-Request-Headers") |
||||
} |
||||
return headers |
||||
} |
||||
|
||||
func normalize(values []string) []string { |
||||
if values == nil { |
||||
return nil |
||||
} |
||||
distinctMap := make(map[string]bool, len(values)) |
||||
normalized := make([]string, 0, len(values)) |
||||
for _, value := range values { |
||||
value = strings.TrimSpace(value) |
||||
value = strings.ToLower(value) |
||||
if _, seen := distinctMap[value]; !seen { |
||||
normalized = append(normalized, value) |
||||
distinctMap[value] = true |
||||
} |
||||
} |
||||
return normalized |
||||
} |
||||
|
||||
func convert(s []string, c converter) []string { |
||||
var out []string |
||||
for _, i := range s { |
||||
out = append(out, c(i)) |
||||
} |
||||
return out |
||||
} |
@ -0,0 +1,64 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"net/url" |
||||
"regexp" |
||||
"strings" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
) |
||||
|
||||
func matchHostSuffix(suffix string) func(*url.URL) bool { |
||||
return func(uri *url.URL) bool { |
||||
return strings.HasSuffix(strings.ToLower(uri.Host), suffix) |
||||
} |
||||
} |
||||
|
||||
func matchPattern(pattern *regexp.Regexp) func(*url.URL) bool { |
||||
return func(uri *url.URL) bool { |
||||
return pattern.MatchString(strings.ToLower(uri.String())) |
||||
} |
||||
} |
||||
|
||||
// CSRF returns the csrf middleware to prevent invalid cross site request.
|
||||
// Only referer is checked currently.
|
||||
func CSRF(allowHosts []string, allowPattern []string) HandlerFunc { |
||||
validations := []func(*url.URL) bool{} |
||||
|
||||
addHostSuffix := func(suffix string) { |
||||
validations = append(validations, matchHostSuffix(suffix)) |
||||
} |
||||
addPattern := func(pattern string) { |
||||
validations = append(validations, matchPattern(regexp.MustCompile(pattern))) |
||||
} |
||||
|
||||
for _, r := range allowHosts { |
||||
addHostSuffix(r) |
||||
} |
||||
for _, p := range allowPattern { |
||||
addPattern(p) |
||||
} |
||||
|
||||
return func(c *Context) { |
||||
referer := c.Request.Header.Get("Referer") |
||||
if referer == "" { |
||||
log.V(5).Info("The request's Referer or Origin header is empty.") |
||||
c.AbortWithStatus(403) |
||||
return |
||||
} |
||||
illegal := true |
||||
if uri, err := url.Parse(referer); err == nil && uri.Host != "" { |
||||
for _, validate := range validations { |
||||
if validate(uri) { |
||||
illegal = false |
||||
break |
||||
} |
||||
} |
||||
} |
||||
if illegal { |
||||
log.V(5).Info("The request's Referer header `%s` does not match any of allowed referers.", referer) |
||||
c.AbortWithStatus(403) |
||||
return |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,69 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"fmt" |
||||
"strconv" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode" |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/net/metadata" |
||||
) |
||||
|
||||
// Logger is logger middleware
|
||||
func Logger() HandlerFunc { |
||||
const noUser = "no_user" |
||||
return func(c *Context) { |
||||
now := time.Now() |
||||
ip := metadata.String(c, metadata.RemoteIP) |
||||
req := c.Request |
||||
path := req.URL.Path |
||||
params := req.Form |
||||
var quota float64 |
||||
if deadline, ok := c.Context.Deadline(); ok { |
||||
quota = time.Until(deadline).Seconds() |
||||
} |
||||
|
||||
c.Next() |
||||
|
||||
err := c.Error |
||||
cerr := ecode.Cause(err) |
||||
dt := time.Since(now) |
||||
caller := metadata.String(c, metadata.Caller) |
||||
if caller == "" { |
||||
caller = noUser |
||||
} |
||||
|
||||
stats.Incr(caller, path[1:], strconv.FormatInt(int64(cerr.Code()), 10)) |
||||
stats.Timing(caller, int64(dt/time.Millisecond), path[1:]) |
||||
|
||||
lf := log.Infov |
||||
errmsg := "" |
||||
isSlow := dt >= (time.Millisecond * 500) |
||||
if err != nil { |
||||
errmsg = err.Error() |
||||
lf = log.Errorv |
||||
if cerr.Code() > 0 { |
||||
lf = log.Warnv |
||||
} |
||||
} else { |
||||
if isSlow { |
||||
lf = log.Warnv |
||||
} |
||||
} |
||||
lf(c, |
||||
log.KVString("method", req.Method), |
||||
log.KVString("ip", ip), |
||||
log.KVString("user", caller), |
||||
log.KVString("path", path), |
||||
log.KVString("params", params.Encode()), |
||||
log.KVInt("ret", cerr.Code()), |
||||
log.KVString("msg", cerr.Message()), |
||||
log.KVString("stack", fmt.Sprintf("%+v", err)), |
||||
log.KVString("err", errmsg), |
||||
log.KVFloat64("timeout_quota", quota), |
||||
log.KVFloat64("ts", dt.Seconds()), |
||||
log.KVString("source", "http-access-log"), |
||||
) |
||||
} |
||||
} |
@ -0,0 +1,46 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"flag" |
||||
"net/http" |
||||
"net/http/pprof" |
||||
"os" |
||||
"sync" |
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/dsn" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
var ( |
||||
_perfOnce sync.Once |
||||
_perfDSN string |
||||
) |
||||
|
||||
func init() { |
||||
v := os.Getenv("HTTP_PERF") |
||||
if v == "" { |
||||
v = "tcp://0.0.0.0:2333" |
||||
} |
||||
flag.StringVar(&_perfDSN, "http.perf", v, "listen http perf dsn, or use HTTP_PERF env variable.") |
||||
} |
||||
|
||||
func startPerf() { |
||||
_perfOnce.Do(func() { |
||||
mux := http.NewServeMux() |
||||
mux.HandleFunc("/debug/pprof/", pprof.Index) |
||||
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) |
||||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile) |
||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) |
||||
|
||||
go func() { |
||||
d, err := dsn.Parse(_perfDSN) |
||||
if err != nil { |
||||
panic(errors.Errorf("blademaster: http perf dsn must be tcp://$host:port, %s:error(%v)", _perfDSN, err)) |
||||
} |
||||
if err := http.ListenAndServe(d.Host, mux); err != nil { |
||||
panic(errors.Errorf("blademaster: listen %s: error(%v)", d.Host, err)) |
||||
} |
||||
}() |
||||
}) |
||||
} |
@ -0,0 +1,12 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"github.com/prometheus/client_golang/prometheus/promhttp" |
||||
) |
||||
|
||||
func monitor() HandlerFunc { |
||||
return func(c *Context) { |
||||
h := promhttp.Handler() |
||||
h.ServeHTTP(c.Writer, c.Request) |
||||
} |
||||
} |
@ -0,0 +1,32 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net/http/httputil" |
||||
"os" |
||||
"runtime" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
) |
||||
|
||||
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
|
||||
func Recovery() HandlerFunc { |
||||
return func(c *Context) { |
||||
defer func() { |
||||
var rawReq []byte |
||||
if err := recover(); err != nil { |
||||
const size = 64 << 10 |
||||
buf := make([]byte, size) |
||||
buf = buf[:runtime.Stack(buf, false)] |
||||
if c.Request != nil { |
||||
rawReq, _ = httputil.DumpRequest(c.Request, false) |
||||
} |
||||
pl := fmt.Sprintf("http call panic: %s\n%v\n%s\n", string(rawReq), err, buf) |
||||
fmt.Fprintf(os.Stderr, pl) |
||||
log.Error(pl) |
||||
c.AbortWithStatus(500) |
||||
} |
||||
}() |
||||
c.Next() |
||||
} |
||||
} |
@ -0,0 +1,30 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// Data common bytes struct.
|
||||
type Data struct { |
||||
ContentType string |
||||
Data [][]byte |
||||
} |
||||
|
||||
// Render (Data) writes data with custom ContentType.
|
||||
func (r Data) Render(w http.ResponseWriter) (err error) { |
||||
r.WriteContentType(w) |
||||
for _, d := range r.Data { |
||||
if _, err = w.Write(d); err != nil { |
||||
err = errors.WithStack(err) |
||||
return |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// WriteContentType writes data with custom ContentType.
|
||||
func (r Data) WriteContentType(w http.ResponseWriter) { |
||||
writeContentType(w, []string{r.ContentType}) |
||||
} |
@ -0,0 +1,58 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
var jsonContentType = []string{"application/json; charset=utf-8"} |
||||
|
||||
// JSON common json struct.
|
||||
type JSON struct { |
||||
Code int `json:"code"` |
||||
Message string `json:"message"` |
||||
TTL int `json:"ttl"` |
||||
Data interface{} `json:"data,omitempty"` |
||||
} |
||||
|
||||
func writeJSON(w http.ResponseWriter, obj interface{}) (err error) { |
||||
var jsonBytes []byte |
||||
writeContentType(w, jsonContentType) |
||||
if jsonBytes, err = json.Marshal(obj); err != nil { |
||||
err = errors.WithStack(err) |
||||
return |
||||
} |
||||
if _, err = w.Write(jsonBytes); err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Render (JSON) writes data with json ContentType.
|
||||
func (r JSON) Render(w http.ResponseWriter) error { |
||||
// FIXME(zhoujiahui): the TTL field will be configurable in the future
|
||||
if r.TTL <= 0 { |
||||
r.TTL = 1 |
||||
} |
||||
return writeJSON(w, r) |
||||
} |
||||
|
||||
// WriteContentType write json ContentType.
|
||||
func (r JSON) WriteContentType(w http.ResponseWriter) { |
||||
writeContentType(w, jsonContentType) |
||||
} |
||||
|
||||
// MapJSON common map json struct.
|
||||
type MapJSON map[string]interface{} |
||||
|
||||
// Render (MapJSON) writes data with json ContentType.
|
||||
func (m MapJSON) Render(w http.ResponseWriter) error { |
||||
return writeJSON(w, m) |
||||
} |
||||
|
||||
// WriteContentType write json ContentType.
|
||||
func (m MapJSON) WriteContentType(w http.ResponseWriter) { |
||||
writeContentType(w, jsonContentType) |
||||
} |
@ -0,0 +1,38 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"net/http" |
||||
|
||||
"github.com/gogo/protobuf/proto" |
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
var pbContentType = []string{"application/x-protobuf"} |
||||
|
||||
// Render (PB) writes data with protobuf ContentType.
|
||||
func (r PB) Render(w http.ResponseWriter) error { |
||||
if r.TTL <= 0 { |
||||
r.TTL = 1 |
||||
} |
||||
return writePB(w, r) |
||||
} |
||||
|
||||
// WriteContentType write protobuf ContentType.
|
||||
func (r PB) WriteContentType(w http.ResponseWriter) { |
||||
writeContentType(w, pbContentType) |
||||
} |
||||
|
||||
func writePB(w http.ResponseWriter, obj PB) (err error) { |
||||
var pbBytes []byte |
||||
writeContentType(w, pbContentType) |
||||
|
||||
if pbBytes, err = proto.Marshal(&obj); err != nil { |
||||
err = errors.WithStack(err) |
||||
return |
||||
} |
||||
|
||||
if _, err = w.Write(pbBytes); err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,26 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// Redirect render for redirect to specified location.
|
||||
type Redirect struct { |
||||
Code int |
||||
Request *http.Request |
||||
Location string |
||||
} |
||||
|
||||
// Render (Redirect) redirect to specified location.
|
||||
func (r Redirect) Render(w http.ResponseWriter) error { |
||||
if (r.Code < 300 || r.Code > 308) && r.Code != 201 { |
||||
return errors.Errorf("Cannot redirect with status code %d", r.Code) |
||||
} |
||||
http.Redirect(w, r.Request, r.Location, r.Code) |
||||
return nil |
||||
} |
||||
|
||||
// WriteContentType noneContentType.
|
||||
func (r Redirect) WriteContentType(http.ResponseWriter) {} |
@ -0,0 +1,30 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"net/http" |
||||
) |
||||
|
||||
// Render http reponse render.
|
||||
type Render interface { |
||||
// Render render it to http response writer.
|
||||
Render(http.ResponseWriter) error |
||||
// WriteContentType write content-type to http response writer.
|
||||
WriteContentType(w http.ResponseWriter) |
||||
} |
||||
|
||||
var ( |
||||
_ Render = JSON{} |
||||
_ Render = MapJSON{} |
||||
_ Render = XML{} |
||||
_ Render = String{} |
||||
_ Render = Redirect{} |
||||
_ Render = Data{} |
||||
_ Render = PB{} |
||||
) |
||||
|
||||
func writeContentType(w http.ResponseWriter, value []string) { |
||||
header := w.Header() |
||||
if val := header["Content-Type"]; len(val) == 0 { |
||||
header["Content-Type"] = value |
||||
} |
||||
} |
@ -0,0 +1,89 @@ |
||||
// Code generated by protoc-gen-gogo. DO NOT EDIT.
|
||||
// source: pb.proto
|
||||
|
||||
/* |
||||
Package render is a generated protocol buffer package. |
||||
|
||||
It is generated from these files: |
||||
pb.proto |
||||
|
||||
It has these top-level messages: |
||||
PB |
||||
*/ |
||||
package render |
||||
|
||||
import proto "github.com/gogo/protobuf/proto" |
||||
import fmt "fmt" |
||||
import math "math" |
||||
import google_protobuf "github.com/gogo/protobuf/types" |
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal |
||||
var _ = fmt.Errorf |
||||
var _ = math.Inf |
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type PB struct { |
||||
Code int64 `protobuf:"varint,1,opt,name=Code,proto3" json:"Code,omitempty"` |
||||
Message string `protobuf:"bytes,2,opt,name=Message,proto3" json:"Message,omitempty"` |
||||
TTL uint64 `protobuf:"varint,3,opt,name=TTL,proto3" json:"TTL,omitempty"` |
||||
Data *google_protobuf.Any `protobuf:"bytes,4,opt,name=Data" json:"Data,omitempty"` |
||||
} |
||||
|
||||
func (m *PB) Reset() { *m = PB{} } |
||||
func (m *PB) String() string { return proto.CompactTextString(m) } |
||||
func (*PB) ProtoMessage() {} |
||||
func (*PB) Descriptor() ([]byte, []int) { return fileDescriptorPb, []int{0} } |
||||
|
||||
func (m *PB) GetCode() int64 { |
||||
if m != nil { |
||||
return m.Code |
||||
} |
||||
return 0 |
||||
} |
||||
|
||||
func (m *PB) GetMessage() string { |
||||
if m != nil { |
||||
return m.Message |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
func (m *PB) GetTTL() uint64 { |
||||
if m != nil { |
||||
return m.TTL |
||||
} |
||||
return 0 |
||||
} |
||||
|
||||
func (m *PB) GetData() *google_protobuf.Any { |
||||
if m != nil { |
||||
return m.Data |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func init() { |
||||
proto.RegisterType((*PB)(nil), "render.PB") |
||||
} |
||||
|
||||
func init() { proto.RegisterFile("pb.proto", fileDescriptorPb) } |
||||
|
||||
var fileDescriptorPb = []byte{ |
||||
// 154 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x28, 0x48, 0xd2, 0x2b, |
||||
0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b, 0x4a, 0xcd, 0x4b, 0x49, 0x2d, 0x92, 0x92, 0x4c, 0xcf, |
||||
0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07, 0x8b, 0x26, 0x95, 0xa6, 0xe9, 0x27, 0xe6, 0x55, 0x42, 0x94, |
||||
0x28, 0xe5, 0x71, 0x31, 0x05, 0x38, 0x09, 0x09, 0x71, 0xb1, 0x38, 0xe7, 0xa7, 0xa4, 0x4a, 0x30, |
||||
0x2a, 0x30, 0x6a, 0x30, 0x07, 0x81, 0xd9, 0x42, 0x12, 0x5c, 0xec, 0xbe, 0xa9, 0xc5, 0xc5, 0x89, |
||||
0xe9, 0xa9, 0x12, 0x4c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 0xae, 0x90, 0x00, 0x17, 0x73, 0x48, |
||||
0x88, 0x8f, 0x04, 0xb3, 0x02, 0xa3, 0x06, 0x4b, 0x10, 0x88, 0x29, 0xa4, 0xc1, 0xc5, 0xe2, 0x92, |
||||
0x58, 0x92, 0x28, 0xc1, 0xa2, 0xc0, 0xa8, 0xc1, 0x6d, 0x24, 0xa2, 0x07, 0xb1, 0x4f, 0x0f, 0x66, |
||||
0x9f, 0x9e, 0x63, 0x5e, 0x65, 0x10, 0x58, 0x45, 0x12, 0x1b, 0x58, 0xcc, 0x18, 0x10, 0x00, 0x00, |
||||
0xff, 0xff, 0x7a, 0x92, 0x16, 0x71, 0xa5, 0x00, 0x00, 0x00, |
||||
} |
@ -0,0 +1,14 @@ |
||||
// use under command to generate pb.pb.go |
||||
// protoc --proto_path=.:$GOPATH/src/github.com/gogo/protobuf --gogo_out=Mgoogle/protobuf/any.proto=github.com/gogo/protobuf/types:. *.proto |
||||
syntax = "proto3"; |
||||
package render; |
||||
|
||||
import "google/protobuf/any.proto"; |
||||
import "github.com/gogo/protobuf/gogoproto/gogo.proto"; |
||||
|
||||
message PB { |
||||
int64 Code = 1; |
||||
string Message = 2; |
||||
uint64 TTL = 3; |
||||
google.protobuf.Any Data = 4; |
||||
} |
@ -0,0 +1,40 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
var plainContentType = []string{"text/plain; charset=utf-8"} |
||||
|
||||
// String common string struct.
|
||||
type String struct { |
||||
Format string |
||||
Data []interface{} |
||||
} |
||||
|
||||
// Render (String) writes data with custom ContentType.
|
||||
func (r String) Render(w http.ResponseWriter) error { |
||||
return writeString(w, r.Format, r.Data) |
||||
} |
||||
|
||||
// WriteContentType writes string with text/plain ContentType.
|
||||
func (r String) WriteContentType(w http.ResponseWriter) { |
||||
writeContentType(w, plainContentType) |
||||
} |
||||
|
||||
func writeString(w http.ResponseWriter, format string, data []interface{}) (err error) { |
||||
writeContentType(w, plainContentType) |
||||
if len(data) > 0 { |
||||
_, err = fmt.Fprintf(w, format, data...) |
||||
} else { |
||||
_, err = io.WriteString(w, format) |
||||
} |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,31 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"encoding/xml" |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// XML common xml struct.
|
||||
type XML struct { |
||||
Code int |
||||
Message string |
||||
Data interface{} |
||||
} |
||||
|
||||
var xmlContentType = []string{"application/xml; charset=utf-8"} |
||||
|
||||
// Render (XML) writes data with xml ContentType.
|
||||
func (r XML) Render(w http.ResponseWriter) (err error) { |
||||
r.WriteContentType(w) |
||||
if err = xml.NewEncoder(w).Encode(r.Data); err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// WriteContentType write xml ContentType.
|
||||
func (r XML) WriteContentType(w http.ResponseWriter) { |
||||
writeContentType(w, xmlContentType) |
||||
} |
@ -0,0 +1,166 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"regexp" |
||||
) |
||||
|
||||
// IRouter http router framework interface.
|
||||
type IRouter interface { |
||||
IRoutes |
||||
Group(string, ...HandlerFunc) *RouterGroup |
||||
} |
||||
|
||||
// IRoutes http router interface.
|
||||
type IRoutes interface { |
||||
UseFunc(...HandlerFunc) IRoutes |
||||
Use(...Handler) IRoutes |
||||
|
||||
Handle(string, string, ...HandlerFunc) IRoutes |
||||
HEAD(string, ...HandlerFunc) IRoutes |
||||
GET(string, ...HandlerFunc) IRoutes |
||||
POST(string, ...HandlerFunc) IRoutes |
||||
PUT(string, ...HandlerFunc) IRoutes |
||||
DELETE(string, ...HandlerFunc) IRoutes |
||||
} |
||||
|
||||
// RouterGroup is used internally to configure router, a RouterGroup is associated with a prefix
|
||||
// and an array of handlers (middleware).
|
||||
type RouterGroup struct { |
||||
Handlers []HandlerFunc |
||||
basePath string |
||||
engine *Engine |
||||
root bool |
||||
baseConfig *MethodConfig |
||||
} |
||||
|
||||
var _ IRouter = &RouterGroup{} |
||||
|
||||
// Use adds middleware to the group, see example code in doc.
|
||||
func (group *RouterGroup) Use(middleware ...Handler) IRoutes { |
||||
for _, m := range middleware { |
||||
group.Handlers = append(group.Handlers, m.ServeHTTP) |
||||
} |
||||
return group.returnObj() |
||||
} |
||||
|
||||
// UseFunc adds middleware to the group, see example code in doc.
|
||||
func (group *RouterGroup) UseFunc(middleware ...HandlerFunc) IRoutes { |
||||
group.Handlers = append(group.Handlers, middleware...) |
||||
return group.returnObj() |
||||
} |
||||
|
||||
// Group creates a new router group. You should add all the routes that have common middlwares or the same path prefix.
|
||||
// For example, all the routes that use a common middlware for authorization could be grouped.
|
||||
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup { |
||||
return &RouterGroup{ |
||||
Handlers: group.combineHandlers(handlers), |
||||
basePath: group.calculateAbsolutePath(relativePath), |
||||
engine: group.engine, |
||||
root: false, |
||||
} |
||||
} |
||||
|
||||
// SetMethodConfig is used to set config on specified method
|
||||
func (group *RouterGroup) SetMethodConfig(config *MethodConfig) *RouterGroup { |
||||
group.baseConfig = config |
||||
return group |
||||
} |
||||
|
||||
// BasePath router group base path.
|
||||
func (group *RouterGroup) BasePath() string { |
||||
return group.basePath |
||||
} |
||||
|
||||
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
absolutePath := group.calculateAbsolutePath(relativePath) |
||||
injections := group.injections(relativePath) |
||||
handlers = group.combineHandlers(injections, handlers) |
||||
group.engine.addRoute(httpMethod, absolutePath, handlers...) |
||||
if group.baseConfig != nil { |
||||
group.engine.SetMethodConfig(absolutePath, group.baseConfig) |
||||
} |
||||
return group.returnObj() |
||||
} |
||||
|
||||
// Handle registers a new request handle and middleware with the given path and method.
|
||||
// The last handler should be the real handler, the other ones should be middleware that can and should be shared among different routes.
|
||||
// See the example code in doc.
|
||||
//
|
||||
// For HEAD, GET, POST, PUT, and DELETE requests the respective shortcut
|
||||
// functions can be used.
|
||||
//
|
||||
// This function is intended for bulk loading and to allow the usage of less
|
||||
// frequently used, non-standardized or custom methods (e.g. for internal
|
||||
// communication with a proxy).
|
||||
func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
if matches, err := regexp.MatchString("^[A-Z]+$", httpMethod); !matches || err != nil { |
||||
panic("http method " + httpMethod + " is not valid") |
||||
} |
||||
return group.handle(httpMethod, relativePath, handlers...) |
||||
} |
||||
|
||||
// HEAD is a shortcut for router.Handle("HEAD", path, handle).
|
||||
func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
return group.handle("HEAD", relativePath, handlers...) |
||||
} |
||||
|
||||
// GET is a shortcut for router.Handle("GET", path, handle).
|
||||
func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
return group.handle("GET", relativePath, handlers...) |
||||
} |
||||
|
||||
// POST is a shortcut for router.Handle("POST", path, handle).
|
||||
func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
return group.handle("POST", relativePath, handlers...) |
||||
} |
||||
|
||||
// PUT is a shortcut for router.Handle("PUT", path, handle).
|
||||
func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
return group.handle("PUT", relativePath, handlers...) |
||||
} |
||||
|
||||
// DELETE is a shortcut for router.Handle("DELETE", path, handle).
|
||||
func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
return group.handle("DELETE", relativePath, handlers...) |
||||
} |
||||
|
||||
func (group *RouterGroup) combineHandlers(handlerGroups ...[]HandlerFunc) []HandlerFunc { |
||||
finalSize := len(group.Handlers) |
||||
for _, handlers := range handlerGroups { |
||||
finalSize += len(handlers) |
||||
} |
||||
if finalSize >= int(_abortIndex) { |
||||
panic("too many handlers") |
||||
} |
||||
mergedHandlers := make([]HandlerFunc, finalSize) |
||||
copy(mergedHandlers, group.Handlers) |
||||
position := len(group.Handlers) |
||||
for _, handlers := range handlerGroups { |
||||
copy(mergedHandlers[position:], handlers) |
||||
position += len(handlers) |
||||
} |
||||
return mergedHandlers |
||||
} |
||||
|
||||
func (group *RouterGroup) calculateAbsolutePath(relativePath string) string { |
||||
return joinPaths(group.basePath, relativePath) |
||||
} |
||||
|
||||
func (group *RouterGroup) returnObj() IRoutes { |
||||
if group.root { |
||||
return group.engine |
||||
} |
||||
return group |
||||
} |
||||
|
||||
// injections is
|
||||
func (group *RouterGroup) injections(relativePath string) []HandlerFunc { |
||||
absPath := group.calculateAbsolutePath(relativePath) |
||||
for _, injection := range group.engine.injections { |
||||
if !injection.pattern.MatchString(absPath) { |
||||
continue |
||||
} |
||||
return injection.handlers |
||||
} |
||||
return nil |
||||
} |
@ -0,0 +1,405 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"context" |
||||
"flag" |
||||
"fmt" |
||||
"net" |
||||
"net/http" |
||||
"os" |
||||
"regexp" |
||||
"strings" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/dsn" |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/net/ip" |
||||
"github.com/bilibili/kratos/pkg/net/metadata" |
||||
"github.com/bilibili/kratos/pkg/stat" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
const ( |
||||
defaultMaxMemory = 32 << 20 // 32 MB
|
||||
) |
||||
|
||||
var ( |
||||
_ IRouter = &Engine{} |
||||
stats = stat.HTTPServer |
||||
|
||||
_httpDSN string |
||||
) |
||||
|
||||
func init() { |
||||
addFlag(flag.CommandLine) |
||||
} |
||||
|
||||
func addFlag(fs *flag.FlagSet) { |
||||
v := os.Getenv("HTTP") |
||||
if v == "" { |
||||
v = "tcp://0.0.0.0:8000/?timeout=1s" |
||||
} |
||||
fs.StringVar(&_httpDSN, "http", v, "listen http dsn, or use HTTP env variable.") |
||||
} |
||||
|
||||
func parseDSN(rawdsn string) *ServerConfig { |
||||
conf := new(ServerConfig) |
||||
d, err := dsn.Parse(rawdsn) |
||||
if err != nil { |
||||
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn)) |
||||
} |
||||
if _, err = d.Bind(conf); err != nil { |
||||
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn)) |
||||
} |
||||
return conf |
||||
} |
||||
|
||||
// Handler responds to an HTTP request.
|
||||
type Handler interface { |
||||
ServeHTTP(c *Context) |
||||
} |
||||
|
||||
// HandlerFunc http request handler function.
|
||||
type HandlerFunc func(*Context) |
||||
|
||||
// ServeHTTP calls f(ctx).
|
||||
func (f HandlerFunc) ServeHTTP(c *Context) { |
||||
f(c) |
||||
} |
||||
|
||||
// ServerConfig is the bm server config model
|
||||
type ServerConfig struct { |
||||
Network string `dsn:"network"` |
||||
Addr string `dsn:"address"` |
||||
Timeout xtime.Duration `dsn:"query.timeout"` |
||||
ReadTimeout xtime.Duration `dsn:"query.readTimeout"` |
||||
WriteTimeout xtime.Duration `dsn:"query.writeTimeout"` |
||||
} |
||||
|
||||
// MethodConfig is
|
||||
type MethodConfig struct { |
||||
Timeout xtime.Duration |
||||
} |
||||
|
||||
// Start listen and serve bm engine by given DSN.
|
||||
func (engine *Engine) Start() error { |
||||
conf := engine.conf |
||||
l, err := net.Listen(conf.Network, conf.Addr) |
||||
if err != nil { |
||||
errors.Wrapf(err, "blademaster: listen tcp: %s", conf.Addr) |
||||
return err |
||||
} |
||||
|
||||
log.Info("blademaster: start http listen addr: %s", conf.Addr) |
||||
server := &http.Server{ |
||||
ReadTimeout: time.Duration(conf.ReadTimeout), |
||||
WriteTimeout: time.Duration(conf.WriteTimeout), |
||||
} |
||||
go func() { |
||||
if err := engine.RunServer(server, l); err != nil { |
||||
if errors.Cause(err) == http.ErrServerClosed { |
||||
log.Info("blademaster: server closed") |
||||
return |
||||
} |
||||
panic(errors.Wrapf(err, "blademaster: engine.ListenServer(%+v, %+v)", server, l)) |
||||
} |
||||
}() |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// Engine is the framework's instance, it contains the muxer, middleware and configuration settings.
|
||||
// Create an instance of Engine, by using New() or Default()
|
||||
type Engine struct { |
||||
RouterGroup |
||||
|
||||
lock sync.RWMutex |
||||
conf *ServerConfig |
||||
|
||||
address string |
||||
|
||||
mux *http.ServeMux // http mux router
|
||||
server atomic.Value // store *http.Server
|
||||
metastore map[string]map[string]interface{} // metastore is the path as key and the metadata of this path as value, it export via /metadata
|
||||
|
||||
pcLock sync.RWMutex |
||||
methodConfigs map[string]*MethodConfig |
||||
|
||||
injections []injection |
||||
} |
||||
|
||||
type injection struct { |
||||
pattern *regexp.Regexp |
||||
handlers []HandlerFunc |
||||
} |
||||
|
||||
// NewServer returns a new blank Engine instance without any middleware attached.
|
||||
func NewServer(conf *ServerConfig) *Engine { |
||||
if conf == nil { |
||||
if !flag.Parsed() { |
||||
fmt.Fprint(os.Stderr, "[blademaster] please call flag.Parse() before Init warden server, some configure may not effect.\n") |
||||
} |
||||
conf = parseDSN(_httpDSN) |
||||
} |
||||
engine := &Engine{ |
||||
RouterGroup: RouterGroup{ |
||||
Handlers: nil, |
||||
basePath: "/", |
||||
root: true, |
||||
}, |
||||
address: ip.InternalIP(), |
||||
mux: http.NewServeMux(), |
||||
metastore: make(map[string]map[string]interface{}), |
||||
methodConfigs: make(map[string]*MethodConfig), |
||||
} |
||||
if err := engine.SetConfig(conf); err != nil { |
||||
panic(err) |
||||
} |
||||
engine.RouterGroup.engine = engine |
||||
// NOTE add prometheus monitor location
|
||||
engine.addRoute("GET", "/metrics", monitor()) |
||||
engine.addRoute("GET", "/metadata", engine.metadata()) |
||||
startPerf() |
||||
return engine |
||||
} |
||||
|
||||
// SetMethodConfig is used to set config on specified path
|
||||
func (engine *Engine) SetMethodConfig(path string, mc *MethodConfig) { |
||||
engine.pcLock.Lock() |
||||
engine.methodConfigs[path] = mc |
||||
engine.pcLock.Unlock() |
||||
} |
||||
|
||||
// DefaultServer returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
|
||||
func DefaultServer(conf *ServerConfig) *Engine { |
||||
engine := NewServer(conf) |
||||
engine.Use(Recovery(), Trace(), Logger()) |
||||
return engine |
||||
} |
||||
|
||||
func (engine *Engine) addRoute(method, path string, handlers ...HandlerFunc) { |
||||
if path[0] != '/' { |
||||
panic("blademaster: path must begin with '/'") |
||||
} |
||||
if method == "" { |
||||
panic("blademaster: HTTP method can not be empty") |
||||
} |
||||
if len(handlers) == 0 { |
||||
panic("blademaster: there must be at least one handler") |
||||
} |
||||
if _, ok := engine.metastore[path]; !ok { |
||||
engine.metastore[path] = make(map[string]interface{}) |
||||
} |
||||
engine.metastore[path]["method"] = method |
||||
engine.mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) { |
||||
c := &Context{ |
||||
Context: nil, |
||||
engine: engine, |
||||
index: -1, |
||||
handlers: nil, |
||||
Keys: nil, |
||||
method: "", |
||||
Error: nil, |
||||
} |
||||
|
||||
c.Request = req |
||||
c.Writer = w |
||||
c.handlers = handlers |
||||
c.method = method |
||||
|
||||
engine.handleContext(c) |
||||
}) |
||||
} |
||||
|
||||
// SetConfig is used to set the engine configuration.
|
||||
// Only the valid config will be loaded.
|
||||
func (engine *Engine) SetConfig(conf *ServerConfig) (err error) { |
||||
if conf.Timeout <= 0 { |
||||
return errors.New("blademaster: config timeout must greater than 0") |
||||
} |
||||
if conf.Network == "" { |
||||
conf.Network = "tcp" |
||||
} |
||||
engine.lock.Lock() |
||||
engine.conf = conf |
||||
engine.lock.Unlock() |
||||
return |
||||
} |
||||
|
||||
func (engine *Engine) methodConfig(path string) *MethodConfig { |
||||
engine.pcLock.RLock() |
||||
mc := engine.methodConfigs[path] |
||||
engine.pcLock.RUnlock() |
||||
return mc |
||||
} |
||||
|
||||
func (engine *Engine) handleContext(c *Context) { |
||||
var cancel func() |
||||
req := c.Request |
||||
ctype := req.Header.Get("Content-Type") |
||||
switch { |
||||
case strings.Contains(ctype, "multipart/form-data"): |
||||
req.ParseMultipartForm(defaultMaxMemory) |
||||
default: |
||||
req.ParseForm() |
||||
} |
||||
// get derived timeout from http request header,
|
||||
// compare with the engine configured,
|
||||
// and use the minimum one
|
||||
engine.lock.RLock() |
||||
tm := time.Duration(engine.conf.Timeout) |
||||
engine.lock.RUnlock() |
||||
// the method config is preferred
|
||||
if pc := engine.methodConfig(c.Request.URL.Path); pc != nil { |
||||
tm = time.Duration(pc.Timeout) |
||||
} |
||||
if ctm := timeout(req); ctm > 0 && tm > ctm { |
||||
tm = ctm |
||||
} |
||||
md := metadata.MD{ |
||||
metadata.Color: color(req), |
||||
metadata.RemoteIP: remoteIP(req), |
||||
metadata.RemotePort: remotePort(req), |
||||
metadata.Caller: caller(req), |
||||
metadata.Mirror: mirror(req), |
||||
} |
||||
ctx := metadata.NewContext(context.Background(), md) |
||||
if tm > 0 { |
||||
c.Context, cancel = context.WithTimeout(ctx, tm) |
||||
} else { |
||||
c.Context, cancel = context.WithCancel(ctx) |
||||
} |
||||
defer cancel() |
||||
c.Next() |
||||
} |
||||
|
||||
// Router return a http.Handler for using http.ListenAndServe() directly.
|
||||
func (engine *Engine) Router() http.Handler { |
||||
return engine.mux |
||||
} |
||||
|
||||
// Server is used to load stored http server.
|
||||
func (engine *Engine) Server() *http.Server { |
||||
s, ok := engine.server.Load().(*http.Server) |
||||
if !ok { |
||||
return nil |
||||
} |
||||
return s |
||||
} |
||||
|
||||
// Shutdown the http server without interrupting active connections.
|
||||
func (engine *Engine) Shutdown(ctx context.Context) error { |
||||
server := engine.Server() |
||||
if server == nil { |
||||
return errors.New("blademaster: no server") |
||||
} |
||||
return errors.WithStack(server.Shutdown(ctx)) |
||||
} |
||||
|
||||
// UseFunc attachs a global middleware to the router. ie. the middleware attached though UseFunc() will be
|
||||
// included in the handlers chain for every single request. Even 404, 405, static files...
|
||||
// For example, this is the right place for a logger or error management middleware.
|
||||
func (engine *Engine) UseFunc(middleware ...HandlerFunc) IRoutes { |
||||
engine.RouterGroup.UseFunc(middleware...) |
||||
return engine |
||||
} |
||||
|
||||
// Use attachs a global middleware to the router. ie. the middleware attached though Use() will be
|
||||
// included in the handlers chain for every single request. Even 404, 405, static files...
|
||||
// For example, this is the right place for a logger or error management middleware.
|
||||
func (engine *Engine) Use(middleware ...Handler) IRoutes { |
||||
engine.RouterGroup.Use(middleware...) |
||||
return engine |
||||
} |
||||
|
||||
// Ping is used to set the general HTTP ping handler.
|
||||
func (engine *Engine) Ping(handler HandlerFunc) { |
||||
engine.GET("/ping", handler) |
||||
} |
||||
|
||||
// Register is used to export metadata to discovery.
|
||||
func (engine *Engine) Register(handler HandlerFunc) { |
||||
engine.GET("/register", handler) |
||||
} |
||||
|
||||
// Run attaches the router to a http.Server and starts listening and serving HTTP requests.
|
||||
// It is a shortcut for http.ListenAndServe(addr, router)
|
||||
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||
func (engine *Engine) Run(addr ...string) (err error) { |
||||
address := resolveAddress(addr) |
||||
server := &http.Server{ |
||||
Addr: address, |
||||
Handler: engine.mux, |
||||
} |
||||
engine.server.Store(server) |
||||
if err = server.ListenAndServe(); err != nil { |
||||
err = errors.Wrapf(err, "addrs: %v", addr) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests.
|
||||
// It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router)
|
||||
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||
func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) { |
||||
server := &http.Server{ |
||||
Addr: addr, |
||||
Handler: engine.mux, |
||||
} |
||||
engine.server.Store(server) |
||||
if err = server.ListenAndServeTLS(certFile, keyFile); err != nil { |
||||
err = errors.Wrapf(err, "tls: %s/%s:%s", addr, certFile, keyFile) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// RunUnix attaches the router to a http.Server and starts listening and serving HTTP requests
|
||||
// through the specified unix socket (ie. a file).
|
||||
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||
func (engine *Engine) RunUnix(file string) (err error) { |
||||
os.Remove(file) |
||||
listener, err := net.Listen("unix", file) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "unix: %s", file) |
||||
return |
||||
} |
||||
defer listener.Close() |
||||
server := &http.Server{ |
||||
Handler: engine.mux, |
||||
} |
||||
engine.server.Store(server) |
||||
if err = server.Serve(listener); err != nil { |
||||
err = errors.Wrapf(err, "unix: %s", file) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// RunServer will serve and start listening HTTP requests by given server and listener.
|
||||
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||
func (engine *Engine) RunServer(server *http.Server, l net.Listener) (err error) { |
||||
server.Handler = engine.mux |
||||
engine.server.Store(server) |
||||
if err = server.Serve(l); err != nil { |
||||
err = errors.Wrapf(err, "listen server: %+v/%+v", server, l) |
||||
return |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (engine *Engine) metadata() HandlerFunc { |
||||
return func(c *Context) { |
||||
c.JSON(engine.metastore, nil) |
||||
} |
||||
} |
||||
|
||||
// Inject is
|
||||
func (engine *Engine) Inject(pattern string, handlers ...HandlerFunc) { |
||||
engine.injections = append(engine.injections, injection{ |
||||
pattern: regexp.MustCompile(pattern), |
||||
handlers: handlers, |
||||
}) |
||||
} |
@ -0,0 +1,40 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"os" |
||||
"path" |
||||
) |
||||
|
||||
func lastChar(str string) uint8 { |
||||
if str == "" { |
||||
panic("The length of the string can't be 0") |
||||
} |
||||
return str[len(str)-1] |
||||
} |
||||
|
||||
func joinPaths(absolutePath, relativePath string) string { |
||||
if relativePath == "" { |
||||
return absolutePath |
||||
} |
||||
|
||||
finalPath := path.Join(absolutePath, relativePath) |
||||
appendSlash := lastChar(relativePath) == '/' && lastChar(finalPath) != '/' |
||||
if appendSlash { |
||||
return finalPath + "/" |
||||
} |
||||
return finalPath |
||||
} |
||||
|
||||
func resolveAddress(addr []string) string { |
||||
switch len(addr) { |
||||
case 0: |
||||
if port := os.Getenv("PORT"); port != "" { |
||||
return ":" + port |
||||
} |
||||
return ":8080" |
||||
case 1: |
||||
return addr[0] |
||||
default: |
||||
panic("too much parameters") |
||||
} |
||||
} |
@ -0,0 +1,84 @@ |
||||
package mocktrace |
||||
|
||||
import ( |
||||
"github.com/bilibili/kratos/pkg/net/trace" |
||||
) |
||||
|
||||
// MockTrace .
|
||||
type MockTrace struct { |
||||
Spans []*MockSpan |
||||
} |
||||
|
||||
// New .
|
||||
func (m *MockTrace) New(operationName string, opts ...trace.Option) trace.Trace { |
||||
span := &MockSpan{OperationName: operationName, MockTrace: m} |
||||
m.Spans = append(m.Spans, span) |
||||
return span |
||||
} |
||||
|
||||
// Inject .
|
||||
func (m *MockTrace) Inject(t trace.Trace, format interface{}, carrier interface{}) error { |
||||
return nil |
||||
} |
||||
|
||||
// Extract .
|
||||
func (m *MockTrace) Extract(format interface{}, carrier interface{}) (trace.Trace, error) { |
||||
return &MockSpan{}, nil |
||||
} |
||||
|
||||
// MockSpan .
|
||||
type MockSpan struct { |
||||
*MockTrace |
||||
OperationName string |
||||
FinishErr error |
||||
Finished bool |
||||
Tags []trace.Tag |
||||
Logs []trace.LogField |
||||
} |
||||
|
||||
// Fork .
|
||||
func (m *MockSpan) Fork(serviceName string, operationName string) trace.Trace { |
||||
span := &MockSpan{OperationName: operationName, MockTrace: m.MockTrace} |
||||
m.Spans = append(m.Spans, span) |
||||
return span |
||||
} |
||||
|
||||
// Follow .
|
||||
func (m *MockSpan) Follow(serviceName string, operationName string) trace.Trace { |
||||
span := &MockSpan{OperationName: operationName, MockTrace: m.MockTrace} |
||||
m.Spans = append(m.Spans, span) |
||||
return span |
||||
} |
||||
|
||||
// Finish .
|
||||
func (m *MockSpan) Finish(perr *error) { |
||||
if perr != nil { |
||||
m.FinishErr = *perr |
||||
} |
||||
m.Finished = true |
||||
} |
||||
|
||||
// SetTag .
|
||||
func (m *MockSpan) SetTag(tags ...trace.Tag) trace.Trace { |
||||
m.Tags = append(m.Tags, tags...) |
||||
return m |
||||
} |
||||
|
||||
// SetLog .
|
||||
func (m *MockSpan) SetLog(logs ...trace.LogField) trace.Trace { |
||||
m.Logs = append(m.Logs, logs...) |
||||
return m |
||||
} |
||||
|
||||
// Visit .
|
||||
func (m *MockSpan) Visit(fn func(k, v string)) {} |
||||
|
||||
// SetTitle .
|
||||
func (m *MockSpan) SetTitle(title string) { |
||||
m.OperationName = title |
||||
} |
||||
|
||||
// TraceID .
|
||||
func (m *MockSpan) TraceID() string { |
||||
return "" |
||||
} |
@ -0,0 +1,24 @@ |
||||
// Package mocktrace this ut just make ci happay.
|
||||
package mocktrace |
||||
|
||||
import ( |
||||
"fmt" |
||||
"testing" |
||||
) |
||||
|
||||
func TestMockTrace(t *testing.T) { |
||||
mocktrace := &MockTrace{} |
||||
mocktrace.Inject(nil, nil, nil) |
||||
mocktrace.Extract(nil, nil) |
||||
|
||||
root := mocktrace.New("test") |
||||
root.Fork("", "") |
||||
root.Follow("", "") |
||||
root.Finish(nil) |
||||
err := fmt.Errorf("test") |
||||
root.Finish(&err) |
||||
root.SetTag() |
||||
root.SetLog() |
||||
root.Visit(func(k, v string) {}) |
||||
root.SetTitle("") |
||||
} |
Loading…
Reference in new issue