Merge remote-tracking branch 'origin/master'

pull/13/head
felixhao 6 years ago
commit dbdfe47e5b
  1. 10
      go.mod
  2. 25
      pkg/cache/memcache/README.md
  3. 187
      pkg/cache/memcache/client.go
  4. 685
      pkg/cache/memcache/conn.go
  5. 76
      pkg/cache/memcache/errors.go
  6. 136
      pkg/cache/memcache/memcache.go
  7. 59
      pkg/cache/memcache/mock.go
  8. 197
      pkg/cache/memcache/pool.go
  9. 109
      pkg/cache/memcache/trace.go
  10. 32
      pkg/cache/memcache/util.go
  11. 7
      pkg/cache/redis/README.md
  12. 57
      pkg/cache/redis/commandinfo.go
  13. 597
      pkg/cache/redis/conn.go
  14. 169
      pkg/cache/redis/doc.go
  15. 33
      pkg/cache/redis/errors.go
  16. 117
      pkg/cache/redis/log.go
  17. 36
      pkg/cache/redis/mock.go
  18. 218
      pkg/cache/redis/pool.go
  19. 152
      pkg/cache/redis/pubsub.go
  20. 51
      pkg/cache/redis/redis.go
  21. 409
      pkg/cache/redis/reply.go
  22. 559
      pkg/cache/redis/scan.go
  23. 86
      pkg/cache/redis/script.go
  24. 142
      pkg/cache/redis/trace.go
  25. 40
      pkg/database/hbase/README.md
  26. 23
      pkg/database/hbase/config.go
  27. 297
      pkg/database/hbase/hbase.go
  28. 48
      pkg/database/hbase/metrics.go
  29. 24
      pkg/database/hbase/slowlog.go
  30. 40
      pkg/database/hbase/trace.go
  31. 9
      pkg/database/sql/README.md
  32. 40
      pkg/database/sql/mysql.go
  33. 678
      pkg/database/sql/sql.go
  34. 14
      pkg/database/tidb/README.md
  35. 58
      pkg/database/tidb/discovery.go
  36. 82
      pkg/database/tidb/node_proc.go
  37. 739
      pkg/database/tidb/sql.go
  38. 38
      pkg/database/tidb/tidb.go
  39. 85
      pkg/net/http/blademaster/binding/binding.go
  40. 342
      pkg/net/http/blademaster/binding/binding_test.go
  41. 45
      pkg/net/http/blademaster/binding/default_validator.go
  42. 113
      pkg/net/http/blademaster/binding/example/test.pb.go
  43. 12
      pkg/net/http/blademaster/binding/example/test.proto
  44. 36
      pkg/net/http/blademaster/binding/example_test.go
  45. 55
      pkg/net/http/blademaster/binding/form.go
  46. 276
      pkg/net/http/blademaster/binding/form_mapping.go
  47. 22
      pkg/net/http/blademaster/binding/json.go
  48. 19
      pkg/net/http/blademaster/binding/query.go
  49. 44
      pkg/net/http/blademaster/binding/tags.go
  50. 209
      pkg/net/http/blademaster/binding/validate_test.go
  51. 22
      pkg/net/http/blademaster/binding/xml.go
  52. 306
      pkg/net/http/blademaster/context.go
  53. 249
      pkg/net/http/blademaster/cors.go
  54. 64
      pkg/net/http/blademaster/csrf.go
  55. 69
      pkg/net/http/blademaster/logger.go
  56. 7
      pkg/net/http/blademaster/metadata.go
  57. 46
      pkg/net/http/blademaster/perf.go
  58. 12
      pkg/net/http/blademaster/prometheus.go
  59. 32
      pkg/net/http/blademaster/recovery.go
  60. 30
      pkg/net/http/blademaster/render/data.go
  61. 58
      pkg/net/http/blademaster/render/json.go
  62. 38
      pkg/net/http/blademaster/render/protobuf.go
  63. 26
      pkg/net/http/blademaster/render/redirect.go
  64. 30
      pkg/net/http/blademaster/render/render.go
  65. 89
      pkg/net/http/blademaster/render/render.pb.go
  66. 14
      pkg/net/http/blademaster/render/render.proto
  67. 40
      pkg/net/http/blademaster/render/string.go
  68. 31
      pkg/net/http/blademaster/render/xml.go
  69. 166
      pkg/net/http/blademaster/routergroup.go
  70. 405
      pkg/net/http/blademaster/server.go
  71. 30
      pkg/net/http/blademaster/trace.go
  72. 40
      pkg/net/http/blademaster/utils.go
  73. 84
      pkg/net/trace/mocktrace/mocktrace.go
  74. 24
      pkg/net/trace/mocktrace/mocktrace_test.go
  75. 2
      pkg/net/trace/noop.go
  76. 2
      pkg/net/trace/sample.go
  77. 4
      pkg/net/trace/span.go
  78. 2
      pkg/net/trace/tracer.go
  79. 9
      pkg/net/trace/util.go
  80. 4
      tool/kratos/README.MD

@ -2,18 +2,28 @@ module github.com/bilibili/kratos
require ( require (
github.com/BurntSushi/toml v0.3.1 github.com/BurntSushi/toml v0.3.1
github.com/aristanetworks/goarista v0.0.0-20190409234242-46f4bc7b73ef // indirect
github.com/cznic/b v0.0.0-20181122101859-a26611c4d92d // indirect
github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 // indirect
github.com/cznic/strutil v0.0.0-20181122101858-275e90344537 // indirect
github.com/fatih/color v1.7.0 github.com/fatih/color v1.7.0
github.com/fsnotify/fsnotify v1.4.7 github.com/fsnotify/fsnotify v1.4.7
github.com/go-playground/locales v0.12.1 // indirect github.com/go-playground/locales v0.12.1 // indirect
github.com/go-playground/universal-translator v0.16.0 // indirect github.com/go-playground/universal-translator v0.16.0 // indirect
github.com/go-sql-driver/mysql v1.4.1
github.com/gogo/protobuf v1.2.0 github.com/gogo/protobuf v1.2.0
github.com/golang/protobuf v1.2.0 github.com/golang/protobuf v1.2.0
github.com/kr/pty v1.1.4 github.com/kr/pty v1.1.4
github.com/leodido/go-urn v1.1.0 // indirect github.com/leodido/go-urn v1.1.0 // indirect
github.com/pkg/errors v0.8.1 github.com/pkg/errors v0.8.1
github.com/prometheus/client_golang v0.9.2 github.com/prometheus/client_golang v0.9.2
github.com/remyoudompheng/bigfft v0.0.0-20190321074620-2f0d2b0e0001 // indirect
github.com/samuel/go-zookeeper v0.0.0-20180130194729-c4fab1ac1bec // indirect
github.com/sirupsen/logrus v1.4.1 // indirect
github.com/stretchr/testify v1.3.0 github.com/stretchr/testify v1.3.0
github.com/tsuna/gohbase v0.0.0-20190201102810-d3184c1526df
github.com/urfave/cli v1.20.0 github.com/urfave/cli v1.20.0
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 // indirect
google.golang.org/grpc v1.18.0 google.golang.org/grpc v1.18.0
gopkg.in/AlecAivazis/survey.v1 v1.8.2 gopkg.in/AlecAivazis/survey.v1 v1.8.2
gopkg.in/go-playground/assert.v1 v1.2.1 // indirect gopkg.in/go-playground/assert.v1 v1.2.1 // indirect

@ -0,0 +1,25 @@
# cache/memcache
##### 项目简介
1. 提供protobuf,gob,json序列化方式,gzip的memcache接口
#### 使用方式
```golang
// 初始化 注意这里只是示例 展示用法 不能每次都New 只需要初始化一次
mc := memcache.New(&memcache.Config{})
// 程序关闭的时候调用close方法
defer mc.Close()
// 增加 key
err = mc.Set(c, &memcache.Item{})
// 删除key
err := mc.Delete(c,key)
// 获得某个key的内容
err := mc.Get(c,key).Scan(&v)
// 获取多个key的内容
replies, err := mc.GetMulti(c, keys)
for _, key := range replies.Keys() {
if err = replies.Scan(key, &v); err != nil {
return
}
}
```

@ -0,0 +1,187 @@
package memcache
import (
"context"
)
// Memcache memcache client
type Memcache struct {
pool *Pool
}
// Reply is the result of Get
type Reply struct {
err error
item *Item
conn Conn
closed bool
}
// Replies is the result of GetMulti
type Replies struct {
err error
items map[string]*Item
usedItems map[string]struct{}
conn Conn
closed bool
}
// New get a memcache client
func New(c *Config) *Memcache {
return &Memcache{pool: NewPool(c)}
}
// Close close connection pool
func (mc *Memcache) Close() error {
return mc.pool.Close()
}
// Conn direct get a connection
func (mc *Memcache) Conn(c context.Context) Conn {
return mc.pool.Get(c)
}
// Set writes the given item, unconditionally.
func (mc *Memcache) Set(c context.Context, item *Item) (err error) {
conn := mc.pool.Get(c)
err = conn.Set(item)
conn.Close()
return
}
// Add writes the given item, if no value already exists for its key.
// ErrNotStored is returned if that condition is not met.
func (mc *Memcache) Add(c context.Context, item *Item) (err error) {
conn := mc.pool.Get(c)
err = conn.Add(item)
conn.Close()
return
}
// Replace writes the given item, but only if the server *does* already hold data for this key.
func (mc *Memcache) Replace(c context.Context, item *Item) (err error) {
conn := mc.pool.Get(c)
err = conn.Replace(item)
conn.Close()
return
}
// CompareAndSwap writes the given item that was previously returned by Get
func (mc *Memcache) CompareAndSwap(c context.Context, item *Item) (err error) {
conn := mc.pool.Get(c)
err = conn.CompareAndSwap(item)
conn.Close()
return
}
// Get sends a command to the server for gets data.
func (mc *Memcache) Get(c context.Context, key string) *Reply {
conn := mc.pool.Get(c)
item, err := conn.Get(key)
if err != nil {
conn.Close()
}
return &Reply{err: err, item: item, conn: conn}
}
// Item get raw Item
func (r *Reply) Item() *Item {
return r.item
}
// Scan converts value, read from the memcache
func (r *Reply) Scan(v interface{}) (err error) {
if r.err != nil {
return r.err
}
err = r.conn.Scan(r.item, v)
if !r.closed {
r.conn.Close()
r.closed = true
}
return
}
// GetMulti is a batch version of Get
func (mc *Memcache) GetMulti(c context.Context, keys []string) (*Replies, error) {
conn := mc.pool.Get(c)
items, err := conn.GetMulti(keys)
rs := &Replies{err: err, items: items, conn: conn, usedItems: make(map[string]struct{}, len(keys))}
if (err != nil) || (len(items) == 0) {
rs.Close()
}
return rs, err
}
// Close close rows.
func (rs *Replies) Close() (err error) {
if !rs.closed {
err = rs.conn.Close()
rs.closed = true
}
return
}
// Item get Item from rows
func (rs *Replies) Item(key string) *Item {
return rs.items[key]
}
// Scan converts value, read from key in rows
func (rs *Replies) Scan(key string, v interface{}) (err error) {
if rs.err != nil {
return rs.err
}
item, ok := rs.items[key]
if !ok {
rs.Close()
return ErrNotFound
}
rs.usedItems[key] = struct{}{}
err = rs.conn.Scan(item, v)
if (err != nil) || (len(rs.items) == len(rs.usedItems)) {
rs.Close()
}
return
}
// Keys keys of result
func (rs *Replies) Keys() (keys []string) {
keys = make([]string, 0, len(rs.items))
for key := range rs.items {
keys = append(keys, key)
}
return
}
// Touch updates the expiry for the given key.
func (mc *Memcache) Touch(c context.Context, key string, timeout int32) (err error) {
conn := mc.pool.Get(c)
err = conn.Touch(key, timeout)
conn.Close()
return
}
// Delete deletes the item with the provided key.
func (mc *Memcache) Delete(c context.Context, key string) (err error) {
conn := mc.pool.Get(c)
err = conn.Delete(key)
conn.Close()
return
}
// Increment atomically increments key by delta.
func (mc *Memcache) Increment(c context.Context, key string, delta uint64) (newValue uint64, err error) {
conn := mc.pool.Get(c)
newValue, err = conn.Increment(key, delta)
conn.Close()
return
}
// Decrement atomically decrements key by delta.
func (mc *Memcache) Decrement(c context.Context, key string, delta uint64) (newValue uint64, err error) {
conn := mc.pool.Get(c)
newValue, err = conn.Decrement(key, delta)
conn.Close()
return
}

@ -0,0 +1,685 @@
package memcache
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"encoding/gob"
"encoding/json"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/gogo/protobuf/proto"
pkgerr "github.com/pkg/errors"
)
var (
crlf = []byte("\r\n")
spaceStr = string(" ")
replyOK = []byte("OK\r\n")
replyStored = []byte("STORED\r\n")
replyNotStored = []byte("NOT_STORED\r\n")
replyExists = []byte("EXISTS\r\n")
replyNotFound = []byte("NOT_FOUND\r\n")
replyDeleted = []byte("DELETED\r\n")
replyEnd = []byte("END\r\n")
replyTouched = []byte("TOUCHED\r\n")
replyValueStr = "VALUE"
replyClientErrorPrefix = []byte("CLIENT_ERROR ")
replyServerErrorPrefix = []byte("SERVER_ERROR ")
)
const (
_encodeBuf = 4096 // 4kb
// 1024*1024 - 1, set error???
_largeValue = 1000 * 1000 // 1MB
)
type reader struct {
io.Reader
}
func (r *reader) Reset(rd io.Reader) {
r.Reader = rd
}
// conn is the low-level implementation of Conn
type conn struct {
// Shared
mu sync.Mutex
err error
conn net.Conn
// Read & Write
readTimeout time.Duration
writeTimeout time.Duration
rw *bufio.ReadWriter
// Item Reader
ir bytes.Reader
// Compress
gr gzip.Reader
gw *gzip.Writer
cb bytes.Buffer
// Encoding
edb bytes.Buffer
// json
jr reader
jd *json.Decoder
je *json.Encoder
// protobuffer
ped *proto.Buffer
}
// DialOption specifies an option for dialing a Memcache server.
type DialOption struct {
f func(*dialOptions)
}
type dialOptions struct {
readTimeout time.Duration
writeTimeout time.Duration
dial func(network, addr string) (net.Conn, error)
}
// DialReadTimeout specifies the timeout for reading a single command reply.
func DialReadTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.readTimeout = d
}}
}
// DialWriteTimeout specifies the timeout for writing a single command.
func DialWriteTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.writeTimeout = d
}}
}
// DialConnectTimeout specifies the timeout for connecting to the Memcache server.
func DialConnectTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
dialer := net.Dialer{Timeout: d}
do.dial = dialer.Dial
}}
}
// DialNetDial specifies a custom dial function for creating TCP
// connections. If this option is left out, then net.Dial is
// used. DialNetDial overrides DialConnectTimeout.
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) {
do.dial = dial
}}
}
// Dial connects to the Memcache server at the given network and
// address using the specified options.
func Dial(network, address string, options ...DialOption) (Conn, error) {
do := dialOptions{
dial: net.Dial,
}
for _, option := range options {
option.f(&do)
}
netConn, err := do.dial(network, address)
if err != nil {
return nil, pkgerr.WithStack(err)
}
return NewConn(netConn, do.readTimeout, do.writeTimeout), nil
}
// NewConn returns a new memcache connection for the given net connection.
func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
if writeTimeout <= 0 || readTimeout <= 0 {
panic("must config memcache timeout")
}
c := &conn{
conn: netConn,
rw: bufio.NewReadWriter(bufio.NewReader(netConn),
bufio.NewWriter(netConn)),
readTimeout: readTimeout,
writeTimeout: writeTimeout,
}
c.jd = json.NewDecoder(&c.jr)
c.je = json.NewEncoder(&c.edb)
c.gw = gzip.NewWriter(&c.cb)
c.edb.Grow(_encodeBuf)
// NOTE reuse bytes.Buffer internal buf
// DON'T concurrency call Scan
c.ped = proto.NewBuffer(c.edb.Bytes())
return c
}
func (c *conn) Close() error {
c.mu.Lock()
err := c.err
if c.err == nil {
c.err = pkgerr.New("memcache: closed")
err = c.conn.Close()
}
c.mu.Unlock()
return err
}
func (c *conn) fatal(err error) error {
c.mu.Lock()
if c.err == nil {
c.err = pkgerr.WithStack(err)
// Close connection to force errors on subsequent calls and to unblock
// other reader or writer.
c.conn.Close()
}
c.mu.Unlock()
return c.err
}
func (c *conn) Err() error {
c.mu.Lock()
err := c.err
c.mu.Unlock()
return err
}
func (c *conn) Add(item *Item) error {
return c.populate("add", item)
}
func (c *conn) Set(item *Item) error {
return c.populate("set", item)
}
func (c *conn) Replace(item *Item) error {
return c.populate("replace", item)
}
func (c *conn) CompareAndSwap(item *Item) error {
return c.populate("cas", item)
}
func (c *conn) populate(cmd string, item *Item) (err error) {
if !legalKey(item.Key) {
return pkgerr.WithStack(ErrMalformedKey)
}
var res []byte
if res, err = c.encode(item); err != nil {
return
}
l := len(res)
count := l/(_largeValue) + 1
if count == 1 {
item.Value = res
return c.populateOne(cmd, item)
}
nItem := &Item{
Key: item.Key,
Value: []byte(strconv.Itoa(l)),
Expiration: item.Expiration,
Flags: item.Flags | flagLargeValue,
}
err = c.populateOne(cmd, nItem)
if err != nil {
return
}
k := item.Key
nItem.Flags = item.Flags
for i := 1; i <= count; i++ {
if i == count {
nItem.Value = res[_largeValue*(count-1):]
} else {
nItem.Value = res[_largeValue*(i-1) : _largeValue*i]
}
nItem.Key = fmt.Sprintf("%s%d", k, i)
if err = c.populateOne(cmd, nItem); err != nil {
return
}
}
return
}
func (c *conn) populateOne(cmd string, item *Item) (err error) {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
// <command name> <key> <flags> <exptime> <bytes> [noreply]\r\n
if cmd == "cas" {
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d %d\r\n",
cmd, item.Key, item.Flags, item.Expiration, len(item.Value), item.cas)
} else {
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d\r\n",
cmd, item.Key, item.Flags, item.Expiration, len(item.Value))
}
if err != nil {
return c.fatal(err)
}
c.rw.Write(item.Value)
c.rw.Write(crlf)
if err = c.rw.Flush(); err != nil {
return c.fatal(err)
}
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
line, err := c.rw.ReadSlice('\n')
if err != nil {
return c.fatal(err)
}
switch {
case bytes.Equal(line, replyStored):
return nil
case bytes.Equal(line, replyNotStored):
return ErrNotStored
case bytes.Equal(line, replyExists):
return ErrCASConflict
case bytes.Equal(line, replyNotFound):
return ErrNotFound
}
return pkgerr.WithStack(protocolError(string(line)))
}
func (c *conn) Get(key string) (r *Item, err error) {
if !legalKey(key) {
return nil, pkgerr.WithStack(ErrMalformedKey)
}
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", key); err != nil {
return nil, c.fatal(err)
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
if err = c.parseGetReply(func(it *Item) {
r = it
}); err != nil {
return
}
if r == nil {
err = ErrNotFound
return
}
if r.Flags&flagLargeValue != flagLargeValue {
return
}
if r, err = c.getLargeValue(r); err != nil {
return
}
return
}
func (c *conn) GetMulti(keys []string) (res map[string]*Item, err error) {
for _, key := range keys {
if !legalKey(key) {
return nil, pkgerr.WithStack(ErrMalformedKey)
}
}
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil {
return nil, c.fatal(err)
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
res = make(map[string]*Item, len(keys))
if err = c.parseGetReply(func(it *Item) {
res[it.Key] = it
}); err != nil {
return
}
for k, v := range res {
if v.Flags&flagLargeValue != flagLargeValue {
continue
}
r, err := c.getLargeValue(v)
if err != nil {
return res, err
}
res[k] = r
}
return
}
func (c *conn) getMulti(keys []string) (res map[string]*Item, err error) {
for _, key := range keys {
if !legalKey(key) {
return nil, pkgerr.WithStack(ErrMalformedKey)
}
}
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil {
return nil, c.fatal(err)
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
res = make(map[string]*Item, len(keys))
err = c.parseGetReply(func(it *Item) {
res[it.Key] = it
})
return
}
func (c *conn) getLargeValue(it *Item) (r *Item, err error) {
l, err := strconv.Atoi(string(it.Value))
if err != nil {
return
}
count := l/_largeValue + 1
keys := make([]string, 0, count)
for i := 1; i <= count; i++ {
keys = append(keys, fmt.Sprintf("%s%d", it.Key, i))
}
items, err := c.getMulti(keys)
if err != nil {
return
}
if len(items) < count {
err = ErrNotFound
return
}
v := make([]byte, 0, l)
for _, k := range keys {
if items[k] == nil || items[k].Value == nil {
err = ErrNotFound
return
}
v = append(v, items[k].Value...)
}
it.Value = v
it.Flags = it.Flags ^ flagLargeValue
r = it
return
}
func (c *conn) parseGetReply(f func(*Item)) error {
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
for {
line, err := c.rw.ReadSlice('\n')
if err != nil {
return c.fatal(err)
}
if bytes.Equal(line, replyEnd) {
return nil
}
if bytes.HasPrefix(line, replyServerErrorPrefix) {
errMsg := line[len(replyServerErrorPrefix):]
return c.fatal(protocolError(errMsg))
}
it := new(Item)
size, err := scanGetReply(line, it)
if err != nil {
return c.fatal(err)
}
it.Value = make([]byte, size+2)
if _, err = io.ReadFull(c.rw, it.Value); err != nil {
return c.fatal(err)
}
if !bytes.HasSuffix(it.Value, crlf) {
return c.fatal(protocolError("corrupt get reply, no except CRLF"))
}
it.Value = it.Value[:size]
f(it)
}
}
func scanGetReply(line []byte, item *Item) (size int, err error) {
if !bytes.HasSuffix(line, crlf) {
return 0, protocolError("corrupt get reply, no except CRLF")
}
// VALUE <key> <flags> <bytes> [<cas unique>]
chunks := strings.Split(string(line[:len(line)-2]), spaceStr)
if len(chunks) < 4 {
return 0, protocolError("corrupt get reply")
}
if chunks[0] != replyValueStr {
return 0, protocolError("corrupt get reply, no except VALUE")
}
item.Key = chunks[1]
flags64, err := strconv.ParseUint(chunks[2], 10, 32)
if err != nil {
return 0, err
}
item.Flags = uint32(flags64)
if size, err = strconv.Atoi(chunks[3]); err != nil {
return
}
if len(chunks) > 4 {
item.cas, err = strconv.ParseUint(chunks[4], 10, 64)
}
return
}
func (c *conn) Touch(key string, expire int32) (err error) {
if !legalKey(key) {
return pkgerr.WithStack(ErrMalformedKey)
}
line, err := c.writeReadLine("touch %s %d\r\n", key, expire)
if err != nil {
return err
}
switch {
case bytes.Equal(line, replyTouched):
return nil
case bytes.Equal(line, replyNotFound):
return ErrNotFound
default:
return pkgerr.WithStack(protocolError(string(line)))
}
}
func (c *conn) Increment(key string, delta uint64) (uint64, error) {
return c.incrDecr("incr", key, delta)
}
func (c *conn) Decrement(key string, delta uint64) (newValue uint64, err error) {
return c.incrDecr("decr", key, delta)
}
func (c *conn) incrDecr(cmd, key string, delta uint64) (uint64, error) {
if !legalKey(key) {
return 0, pkgerr.WithStack(ErrMalformedKey)
}
line, err := c.writeReadLine("%s %s %d\r\n", cmd, key, delta)
if err != nil {
return 0, err
}
switch {
case bytes.Equal(line, replyNotFound):
return 0, ErrNotFound
case bytes.HasPrefix(line, replyClientErrorPrefix):
errMsg := line[len(replyClientErrorPrefix):]
return 0, pkgerr.WithStack(protocolError(errMsg))
}
val, err := strconv.ParseUint(string(line[:len(line)-2]), 10, 64)
if err != nil {
return 0, err
}
return val, nil
}
func (c *conn) Delete(key string) (err error) {
if !legalKey(key) {
return pkgerr.WithStack(ErrMalformedKey)
}
line, err := c.writeReadLine("delete %s\r\n", key)
if err != nil {
return err
}
switch {
case bytes.Equal(line, replyOK):
return nil
case bytes.Equal(line, replyDeleted):
return nil
case bytes.Equal(line, replyNotStored):
return ErrNotStored
case bytes.Equal(line, replyExists):
return ErrCASConflict
case bytes.Equal(line, replyNotFound):
return ErrNotFound
}
return pkgerr.WithStack(protocolError(string(line)))
}
func (c *conn) writeReadLine(format string, args ...interface{}) ([]byte, error) {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
_, err := fmt.Fprintf(c.rw, format, args...)
if err != nil {
return nil, c.fatal(pkgerr.WithStack(err))
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(pkgerr.WithStack(err))
}
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
line, err := c.rw.ReadSlice('\n')
if err != nil {
return line, c.fatal(pkgerr.WithStack(err))
}
return line, nil
}
func (c *conn) Scan(item *Item, v interface{}) (err error) {
c.ir.Reset(item.Value)
if item.Flags&FlagGzip == FlagGzip {
if err = c.gr.Reset(&c.ir); err != nil {
return
}
if err = c.decode(&c.gr, item, v); err != nil {
err = pkgerr.WithStack(err)
return
}
err = c.gr.Close()
} else {
err = c.decode(&c.ir, item, v)
}
err = pkgerr.WithStack(err)
return
}
func (c *conn) WithContext(ctx context.Context) Conn {
// FIXME: implement WithContext
return c
}
func (c *conn) encode(item *Item) (data []byte, err error) {
if (item.Flags | _flagEncoding) == _flagEncoding {
if item.Value == nil {
return nil, ErrItem
}
} else if item.Object == nil {
return nil, ErrItem
}
// encoding
switch {
case item.Flags&FlagGOB == FlagGOB:
c.edb.Reset()
if err = gob.NewEncoder(&c.edb).Encode(item.Object); err != nil {
return
}
data = c.edb.Bytes()
case item.Flags&FlagProtobuf == FlagProtobuf:
c.edb.Reset()
c.ped.SetBuf(c.edb.Bytes())
pb, ok := item.Object.(proto.Message)
if !ok {
err = ErrItemObject
return
}
if err = c.ped.Marshal(pb); err != nil {
return
}
data = c.ped.Bytes()
case item.Flags&FlagJSON == FlagJSON:
c.edb.Reset()
if err = c.je.Encode(item.Object); err != nil {
return
}
data = c.edb.Bytes()
default:
data = item.Value
}
// compress
if item.Flags&FlagGzip == FlagGzip {
c.cb.Reset()
c.gw.Reset(&c.cb)
if _, err = c.gw.Write(data); err != nil {
return
}
if err = c.gw.Close(); err != nil {
return
}
data = c.cb.Bytes()
}
if len(data) > 8000000 {
err = ErrValueSize
}
return
}
func (c *conn) decode(rd io.Reader, item *Item, v interface{}) (err error) {
var data []byte
switch {
case item.Flags&FlagGOB == FlagGOB:
err = gob.NewDecoder(rd).Decode(v)
case item.Flags&FlagJSON == FlagJSON:
c.jr.Reset(rd)
err = c.jd.Decode(v)
default:
data = item.Value
if item.Flags&FlagGzip == FlagGzip {
c.edb.Reset()
if _, err = io.Copy(&c.edb, rd); err != nil {
return
}
data = c.edb.Bytes()
}
if item.Flags&FlagProtobuf == FlagProtobuf {
m, ok := v.(proto.Message)
if !ok {
err = ErrItemObject
return
}
c.ped.SetBuf(data)
err = c.ped.Unmarshal(m)
} else {
switch v.(type) {
case *[]byte:
d := v.(*[]byte)
*d = data
case *string:
d := v.(*string)
*d = string(data)
case interface{}:
err = json.Unmarshal(data, v)
}
}
}
return
}
func legalKey(key string) bool {
if len(key) > 250 || len(key) == 0 {
return false
}
for i := 0; i < len(key); i++ {
if key[i] <= ' ' || key[i] == 0x7f {
return false
}
}
return true
}

@ -0,0 +1,76 @@
package memcache
import (
"errors"
"fmt"
"strings"
pkgerr "github.com/pkg/errors"
)
var (
// ErrNotFound not found
ErrNotFound = errors.New("memcache: key not found")
// ErrExists exists
ErrExists = errors.New("memcache: key exists")
// ErrNotStored not stored
ErrNotStored = errors.New("memcache: key not stored")
// ErrCASConflict means that a CompareAndSwap call failed due to the
// cached value being modified between the Get and the CompareAndSwap.
// If the cached value was simply evicted rather than replaced,
// ErrNotStored will be returned instead.
ErrCASConflict = errors.New("memcache: compare-and-swap conflict")
// ErrPoolExhausted is returned from a pool connection method (Store, Get,
// Delete, IncrDecr, Err) when the maximum number of database connections
// in the pool has been reached.
ErrPoolExhausted = errors.New("memcache: connection pool exhausted")
// ErrPoolClosed pool closed
ErrPoolClosed = errors.New("memcache: connection pool closed")
// ErrConnClosed conn closed
ErrConnClosed = errors.New("memcache: connection closed")
// ErrMalformedKey is returned when an invalid key is used.
// Keys must be at maximum 250 bytes long and not
// contain whitespace or control characters.
ErrMalformedKey = errors.New("memcache: malformed key is too long or contains invalid characters")
// ErrValueSize item value size must less than 1mb
ErrValueSize = errors.New("memcache: item value size must not greater than 1mb")
// ErrStat stat error for monitor
ErrStat = errors.New("memcache unexpected errors")
// ErrItem item nil.
ErrItem = errors.New("memcache: item object nil")
// ErrItemObject object type Assertion failed
ErrItemObject = errors.New("memcache: item object protobuf type assertion failed")
)
type protocolError string
func (pe protocolError) Error() string {
return fmt.Sprintf("memcache: %s (possible server error or unsupported concurrent read by application)", string(pe))
}
func formatErr(err error) string {
e := pkgerr.Cause(err)
switch e {
case ErrNotFound, ErrExists, ErrNotStored, nil:
return ""
default:
es := e.Error()
switch {
case strings.HasPrefix(es, "read"):
return "read timeout"
case strings.HasPrefix(es, "dial"):
return "dial timeout"
case strings.HasPrefix(es, "write"):
return "write timeout"
case strings.Contains(es, "EOF"):
return "eof"
case strings.Contains(es, "reset"):
return "reset"
case strings.Contains(es, "broken"):
return "broken pipe"
default:
return "unexpected err"
}
}
}

@ -0,0 +1,136 @@
package memcache
import (
"context"
)
// Error represents an error returned in a command reply.
type Error string
func (err Error) Error() string { return string(err) }
const (
// Flag, 15(encoding) bit+ 17(compress) bit
// FlagRAW default flag.
FlagRAW = uint32(0)
// FlagGOB gob encoding.
FlagGOB = uint32(1) << 0
// FlagJSON json encoding.
FlagJSON = uint32(1) << 1
// FlagProtobuf protobuf
FlagProtobuf = uint32(1) << 2
_flagEncoding = uint32(0xFFFF8000)
// FlagGzip gzip compress.
FlagGzip = uint32(1) << 15
// left mv 31??? not work!!!
flagLargeValue = uint32(1) << 30
)
// Item is an reply to be got or stored in a memcached server.
type Item struct {
// Key is the Item's key (250 bytes maximum).
Key string
// Value is the Item's value.
Value []byte
// Object is the Item's object for use codec.
Object interface{}
// Flags are server-opaque flags whose semantics are entirely
// up to the app.
Flags uint32
// Expiration is the cache expiration time, in seconds: either a relative
// time from now (up to 1 month), or an absolute Unix epoch time.
// Zero means the Item has no expiration time.
Expiration int32
// Compare and swap ID.
cas uint64
}
// Conn represents a connection to a Memcache server.
// Command Reference: https://github.com/memcached/memcached/wiki/Commands
type Conn interface {
// Close closes the connection.
Close() error
// Err returns a non-nil value if the connection is broken. The returned
// value is either the first non-nil value returned from the underlying
// network connection or a protocol parsing error. Applications should
// close broken connections.
Err() error
// Add writes the given item, if no value already exists for its key.
// ErrNotStored is returned if that condition is not met.
Add(item *Item) error
// Set writes the given item, unconditionally.
Set(item *Item) error
// Replace writes the given item, but only if the server *does* already
// hold data for this key.
Replace(item *Item) error
// Get sends a command to the server for gets data.
Get(key string) (*Item, error)
// GetMulti is a batch version of Get. The returned map from keys to items
// may have fewer elements than the input slice, due to memcache cache
// misses. Each key must be at most 250 bytes in length.
// If no error is returned, the returned map will also be non-nil.
GetMulti(keys []string) (map[string]*Item, error)
// Delete deletes the item with the provided key.
// The error ErrCacheMiss is returned if the item didn't already exist in
// the cache.
Delete(key string) error
// Increment atomically increments key by delta. The return value is the
// new value after being incremented or an error. If the value didn't exist
// in memcached the error is ErrCacheMiss. The value in memcached must be
// an decimal number, or an error will be returned.
// On 64-bit overflow, the new value wraps around.
Increment(key string, delta uint64) (newValue uint64, err error)
// Decrement atomically decrements key by delta. The return value is the
// new value after being decremented or an error. If the value didn't exist
// in memcached the error is ErrCacheMiss. The value in memcached must be
// an decimal number, or an error will be returned. On underflow, the new
// value is capped at zero and does not wrap around.
Decrement(key string, delta uint64) (newValue uint64, err error)
// CompareAndSwap writes the given item that was previously returned by
// Get, if the value was neither modified or evicted between the Get and
// the CompareAndSwap calls. The item's Key should not change between calls
// but all other item fields may differ. ErrCASConflict is returned if the
// value was modified in between the calls.
// ErrNotStored is returned if the value was evicted in between the calls.
CompareAndSwap(item *Item) error
// Touch updates the expiry for the given key. The seconds parameter is
// either a Unix timestamp or, if seconds is less than 1 month, the number
// of seconds into the future at which time the item will expire.
//ErrCacheMiss is returned if the key is not in the cache. The key must be
// at most 250 bytes in length.
Touch(key string, seconds int32) (err error)
// Scan converts value read from the memcache into the following
// common Go types and special types:
//
// *string
// *[]byte
// *interface{}
//
Scan(item *Item, v interface{}) (err error)
// WithContext return a Conn with its context changed to ctx
// the context controls the entire lifetime of Conn before you change it
// NOTE: this method is not thread-safe
WithContext(ctx context.Context) Conn
}

@ -0,0 +1,59 @@
package memcache
import (
"context"
)
// MockErr for unit test.
type MockErr struct {
Error error
}
var _ Conn = MockErr{}
// MockWith return a mock conn.
func MockWith(err error) MockErr {
return MockErr{Error: err}
}
// Err .
func (m MockErr) Err() error { return m.Error }
// Close .
func (m MockErr) Close() error { return m.Error }
// Add .
func (m MockErr) Add(item *Item) error { return m.Error }
// Set .
func (m MockErr) Set(item *Item) error { return m.Error }
// Replace .
func (m MockErr) Replace(item *Item) error { return m.Error }
// CompareAndSwap .
func (m MockErr) CompareAndSwap(item *Item) error { return m.Error }
// Get .
func (m MockErr) Get(key string) (*Item, error) { return nil, m.Error }
// GetMulti .
func (m MockErr) GetMulti(keys []string) (map[string]*Item, error) { return nil, m.Error }
// Touch .
func (m MockErr) Touch(key string, timeout int32) error { return m.Error }
// Delete .
func (m MockErr) Delete(key string) error { return m.Error }
// Increment .
func (m MockErr) Increment(key string, delta uint64) (uint64, error) { return 0, m.Error }
// Decrement .
func (m MockErr) Decrement(key string, delta uint64) (uint64, error) { return 0, m.Error }
// Scan .
func (m MockErr) Scan(item *Item, v interface{}) error { return m.Error }
// WithContext .
func (m MockErr) WithContext(ctx context.Context) Conn { return m }

@ -0,0 +1,197 @@
package memcache
import (
"context"
"io"
"time"
"github.com/bilibili/kratos/pkg/container/pool"
"github.com/bilibili/kratos/pkg/stat"
xtime "github.com/bilibili/kratos/pkg/time"
)
var stats = stat.Cache
// Config memcache config.
type Config struct {
*pool.Config
Name string // memcache name, for trace
Proto string
Addr string
DialTimeout xtime.Duration
ReadTimeout xtime.Duration
WriteTimeout xtime.Duration
}
// Pool memcache connection pool struct.
type Pool struct {
p pool.Pool
c *Config
}
// NewPool new a memcache conn pool.
func NewPool(c *Config) (p *Pool) {
if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 {
panic("must config memcache timeout")
}
p1 := pool.NewList(c.Config)
cnop := DialConnectTimeout(time.Duration(c.DialTimeout))
rdop := DialReadTimeout(time.Duration(c.ReadTimeout))
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout))
p1.New = func(ctx context.Context) (io.Closer, error) {
conn, err := Dial(c.Proto, c.Addr, cnop, rdop, wrop)
return &traceConn{Conn: conn, address: c.Addr}, err
}
p = &Pool{p: p1, c: c}
return
}
// Get gets a connection. The application must close the returned connection.
// This method always returns a valid connection so that applications can defer
// error handling to the first use of the connection. If there is an error
// getting an underlying connection, then the connection Err, Do, Send, Flush
// and Receive methods return that error.
func (p *Pool) Get(ctx context.Context) Conn {
c, err := p.p.Get(ctx)
if err != nil {
return errorConnection{err}
}
c1, _ := c.(Conn)
return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx}
}
// Close release the resources used by the pool.
func (p *Pool) Close() error {
return p.p.Close()
}
type pooledConnection struct {
p *Pool
c Conn
ctx context.Context
}
func pstat(key string, t time.Time, err error) {
stats.Timing(key, int64(time.Since(t)/time.Millisecond))
if err != nil {
if msg := formatErr(err); msg != "" {
stats.Incr("memcache", msg)
}
}
}
func (pc *pooledConnection) Close() error {
c := pc.c
if _, ok := c.(errorConnection); ok {
return nil
}
pc.c = errorConnection{ErrConnClosed}
pc.p.p.Put(context.Background(), c, c.Err() != nil)
return nil
}
func (pc *pooledConnection) Err() error {
return pc.c.Err()
}
func (pc *pooledConnection) Set(item *Item) (err error) {
now := time.Now()
err = pc.c.Set(item)
pstat("memcache:set", now, err)
return
}
func (pc *pooledConnection) Add(item *Item) (err error) {
now := time.Now()
err = pc.c.Add(item)
pstat("memcache:add", now, err)
return
}
func (pc *pooledConnection) Replace(item *Item) (err error) {
now := time.Now()
err = pc.c.Replace(item)
pstat("memcache:replace", now, err)
return
}
func (pc *pooledConnection) CompareAndSwap(item *Item) (err error) {
now := time.Now()
err = pc.c.CompareAndSwap(item)
pstat("memcache:cas", now, err)
return
}
func (pc *pooledConnection) Get(key string) (r *Item, err error) {
now := time.Now()
r, err = pc.c.Get(key)
pstat("memcache:get", now, err)
return
}
func (pc *pooledConnection) GetMulti(keys []string) (res map[string]*Item, err error) {
// if keys is empty slice returns empty map direct
if len(keys) == 0 {
return make(map[string]*Item), nil
}
now := time.Now()
res, err = pc.c.GetMulti(keys)
pstat("memcache:gets", now, err)
return
}
func (pc *pooledConnection) Touch(key string, timeout int32) (err error) {
now := time.Now()
err = pc.c.Touch(key, timeout)
pstat("memcache:touch", now, err)
return
}
func (pc *pooledConnection) Scan(item *Item, v interface{}) error {
return pc.c.Scan(item, v)
}
func (pc *pooledConnection) WithContext(ctx context.Context) Conn {
// TODO: set context
pc.ctx = ctx
return pc
}
func (pc *pooledConnection) Delete(key string) (err error) {
now := time.Now()
err = pc.c.Delete(key)
pstat("memcache:delete", now, err)
return
}
func (pc *pooledConnection) Increment(key string, delta uint64) (newValue uint64, err error) {
now := time.Now()
newValue, err = pc.c.Increment(key, delta)
pstat("memcache:increment", now, err)
return
}
func (pc *pooledConnection) Decrement(key string, delta uint64) (newValue uint64, err error) {
now := time.Now()
newValue, err = pc.c.Decrement(key, delta)
pstat("memcache:decrement", now, err)
return
}
type errorConnection struct{ err error }
func (ec errorConnection) Err() error { return ec.err }
func (ec errorConnection) Close() error { return ec.err }
func (ec errorConnection) Add(item *Item) error { return ec.err }
func (ec errorConnection) Set(item *Item) error { return ec.err }
func (ec errorConnection) Replace(item *Item) error { return ec.err }
func (ec errorConnection) CompareAndSwap(item *Item) error { return ec.err }
func (ec errorConnection) Get(key string) (*Item, error) { return nil, ec.err }
func (ec errorConnection) GetMulti(keys []string) (map[string]*Item, error) { return nil, ec.err }
func (ec errorConnection) Touch(key string, timeout int32) error { return ec.err }
func (ec errorConnection) Delete(key string) error { return ec.err }
func (ec errorConnection) Increment(key string, delta uint64) (uint64, error) { return 0, ec.err }
func (ec errorConnection) Decrement(key string, delta uint64) (uint64, error) { return 0, ec.err }
func (ec errorConnection) Scan(item *Item, v interface{}) error { return ec.err }
func (ec errorConnection) WithContext(ctx context.Context) Conn { return ec }

@ -0,0 +1,109 @@
package memcache
import (
"context"
"strconv"
"strings"
"time"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/trace"
)
const (
_traceFamily = "memcache"
_traceSpanKind = "client"
_traceComponentName = "library/cache/memcache"
_tracePeerService = "memcache"
_slowLogDuration = time.Millisecond * 250
)
type traceConn struct {
Conn
ctx context.Context
address string
}
func (t *traceConn) setTrace(action, statement string) func(error) error {
now := time.Now()
parent, ok := trace.FromContext(t.ctx)
if !ok {
return func(err error) error { return err }
}
span := parent.Fork(_traceFamily, "Memcache:"+action)
span.SetTag(
trace.String(trace.TagSpanKind, _traceSpanKind),
trace.String(trace.TagComponent, _traceComponentName),
trace.String(trace.TagPeerService, _tracePeerService),
trace.String(trace.TagPeerAddress, t.address),
trace.String(trace.TagDBStatement, action+" "+statement),
)
return func(err error) error {
span.Finish(&err)
t := time.Since(now)
if t > _slowLogDuration {
log.Warn("%s slow log action: %s key: %s time: %v", _traceFamily, action, statement, t)
}
return err
}
}
func (t *traceConn) WithContext(ctx context.Context) Conn {
t.ctx = ctx
t.Conn = t.Conn.WithContext(ctx)
return t
}
func (t *traceConn) Add(item *Item) error {
finishFn := t.setTrace("Add", item.Key)
return finishFn(t.Conn.Add(item))
}
func (t *traceConn) Set(item *Item) error {
finishFn := t.setTrace("Set", item.Key)
return finishFn(t.Conn.Set(item))
}
func (t *traceConn) Replace(item *Item) error {
finishFn := t.setTrace("Replace", item.Key)
return finishFn(t.Conn.Replace(item))
}
func (t *traceConn) Get(key string) (*Item, error) {
finishFn := t.setTrace("Get", key)
item, err := t.Conn.Get(key)
return item, finishFn(err)
}
func (t *traceConn) GetMulti(keys []string) (map[string]*Item, error) {
finishFn := t.setTrace("GetMulti", strings.Join(keys, " "))
items, err := t.Conn.GetMulti(keys)
return items, finishFn(err)
}
func (t *traceConn) Delete(key string) error {
finishFn := t.setTrace("Delete", key)
return finishFn(t.Conn.Delete(key))
}
func (t *traceConn) Increment(key string, delta uint64) (newValue uint64, err error) {
finishFn := t.setTrace("Increment", key+" "+strconv.FormatUint(delta, 10))
newValue, err = t.Conn.Increment(key, delta)
return newValue, finishFn(err)
}
func (t *traceConn) Decrement(key string, delta uint64) (newValue uint64, err error) {
finishFn := t.setTrace("Decrement", key+" "+strconv.FormatUint(delta, 10))
newValue, err = t.Conn.Decrement(key, delta)
return newValue, finishFn(err)
}
func (t *traceConn) CompareAndSwap(item *Item) error {
finishFn := t.setTrace("CompareAndSwap", item.Key)
return finishFn(t.Conn.CompareAndSwap(item))
}
func (t *traceConn) Touch(key string, seconds int32) (err error) {
finishFn := t.setTrace("Touch", key+" "+strconv.Itoa(int(seconds)))
return finishFn(t.Conn.Touch(key, seconds))
}

@ -0,0 +1,32 @@
package memcache
import (
"github.com/gogo/protobuf/proto"
)
// RawItem item with FlagRAW flag.
//
// Expiration is the cache expiration time, in seconds: either a relative
// time from now (up to 1 month), or an absolute Unix epoch time.
// Zero means the Item has no expiration time.
func RawItem(key string, data []byte, flags uint32, expiration int32) *Item {
return &Item{Key: key, Flags: flags | FlagRAW, Value: data, Expiration: expiration}
}
// JSONItem item with FlagJSON flag.
//
// Expiration is the cache expiration time, in seconds: either a relative
// time from now (up to 1 month), or an absolute Unix epoch time.
// Zero means the Item has no expiration time.
func JSONItem(key string, v interface{}, flags uint32, expiration int32) *Item {
return &Item{Key: key, Flags: flags | FlagJSON, Object: v, Expiration: expiration}
}
// ProtobufItem item with FlagProtobuf flag.
//
// Expiration is the cache expiration time, in seconds: either a relative
// time from now (up to 1 month), or an absolute Unix epoch time.
// Zero means the Item has no expiration time.
func ProtobufItem(key string, message proto.Message, flags uint32, expiration int32) *Item {
return &Item{Key: key, Flags: flags | FlagProtobuf, Object: message, Expiration: expiration}
}

@ -0,0 +1,7 @@
# cache/redis
##### 项目简介
1. 提供redis接口
#### 使用方式
请参考doc.go

@ -0,0 +1,57 @@
// Copyright 2014 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"strings"
)
// redis state
const (
WatchState = 1 << iota
MultiState
SubscribeState
MonitorState
)
// CommandInfo command info.
type CommandInfo struct {
Set, Clear int
}
var commandInfos = map[string]CommandInfo{
"WATCH": {Set: WatchState},
"UNWATCH": {Clear: WatchState},
"MULTI": {Set: MultiState},
"EXEC": {Clear: WatchState | MultiState},
"DISCARD": {Clear: WatchState | MultiState},
"PSUBSCRIBE": {Set: SubscribeState},
"SUBSCRIBE": {Set: SubscribeState},
"MONITOR": {Set: MonitorState},
}
func init() {
for n, ci := range commandInfos {
commandInfos[strings.ToLower(n)] = ci
}
}
// LookupCommandInfo get command info.
func LookupCommandInfo(commandName string) CommandInfo {
if ci, ok := commandInfos[commandName]; ok {
return ci
}
return commandInfos[strings.ToUpper(commandName)]
}

@ -0,0 +1,597 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
"net/url"
"regexp"
"strconv"
"sync"
"time"
"github.com/bilibili/kratos/pkg/stat"
"github.com/pkg/errors"
)
var stats = stat.Cache
// conn is the low-level implementation of Conn
type conn struct {
// Shared
mu sync.Mutex
pending int
err error
conn net.Conn
// Read
readTimeout time.Duration
br *bufio.Reader
// Write
writeTimeout time.Duration
bw *bufio.Writer
// Scratch space for formatting argument length.
// '*' or '$', length, "\r\n"
lenScratch [32]byte
// Scratch space for formatting integers and floats.
numScratch [40]byte
// stat func,default prom
stat func(string, *error) func()
}
func statfunc(cmd string, err *error) func() {
now := time.Now()
return func() {
stats.Timing(fmt.Sprintf("redis:%s", cmd), int64(time.Since(now)/time.Millisecond))
if err != nil {
if msg := formatErr(*err); msg != "" {
stats.Incr("redis", msg)
}
}
}
}
// DialTimeout acts like Dial but takes timeouts for establishing the
// connection to the server, writing a command and reading a reply.
//
// Deprecated: Use Dial with options instead.
func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
return Dial(network, address,
DialConnectTimeout(connectTimeout),
DialReadTimeout(readTimeout),
DialWriteTimeout(writeTimeout))
}
// DialOption specifies an option for dialing a Redis server.
type DialOption struct {
f func(*dialOptions)
}
type dialOptions struct {
readTimeout time.Duration
writeTimeout time.Duration
dial func(network, addr string) (net.Conn, error)
db int
password string
stat func(string, *error) func()
}
// DialStats specifies stat func for stats.default statfunc.
func DialStats(fn func(string, *error) func()) DialOption {
return DialOption{func(do *dialOptions) {
do.stat = fn
}}
}
// DialReadTimeout specifies the timeout for reading a single command reply.
func DialReadTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.readTimeout = d
}}
}
// DialWriteTimeout specifies the timeout for writing a single command.
func DialWriteTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.writeTimeout = d
}}
}
// DialConnectTimeout specifies the timeout for connecting to the Redis server.
func DialConnectTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
dialer := net.Dialer{Timeout: d}
do.dial = dialer.Dial
}}
}
// DialNetDial specifies a custom dial function for creating TCP
// connections. If this option is left out, then net.Dial is
// used. DialNetDial overrides DialConnectTimeout.
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) {
do.dial = dial
}}
}
// DialDatabase specifies the database to select when dialing a connection.
func DialDatabase(db int) DialOption {
return DialOption{func(do *dialOptions) {
do.db = db
}}
}
// DialPassword specifies the password to use when connecting to
// the Redis server.
func DialPassword(password string) DialOption {
return DialOption{func(do *dialOptions) {
do.password = password
}}
}
// Dial connects to the Redis server at the given network and
// address using the specified options.
func Dial(network, address string, options ...DialOption) (Conn, error) {
do := dialOptions{
dial: net.Dial,
}
for _, option := range options {
option.f(&do)
}
netConn, err := do.dial(network, address)
if err != nil {
return nil, errors.WithStack(err)
}
c := &conn{
conn: netConn,
bw: bufio.NewWriter(netConn),
br: bufio.NewReader(netConn),
readTimeout: do.readTimeout,
writeTimeout: do.writeTimeout,
stat: statfunc,
}
if do.password != "" {
if _, err := c.Do("AUTH", do.password); err != nil {
netConn.Close()
return nil, errors.WithStack(err)
}
}
if do.db != 0 {
if _, err := c.Do("SELECT", do.db); err != nil {
netConn.Close()
return nil, errors.WithStack(err)
}
}
if do.stat != nil {
c.stat = do.stat
}
return c, nil
}
var pathDBRegexp = regexp.MustCompile(`/(\d+)\z`)
// DialURL connects to a Redis server at the given URL using the Redis
// URI scheme. URLs should follow the draft IANA specification for the
// scheme (https://www.iana.org/assignments/uri-schemes/prov/redis).
func DialURL(rawurl string, options ...DialOption) (Conn, error) {
u, err := url.Parse(rawurl)
if err != nil {
return nil, errors.WithStack(err)
}
if u.Scheme != "redis" {
return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme)
}
// As per the IANA draft spec, the host defaults to localhost and
// the port defaults to 6379.
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
// assume port is missing
host = u.Host
port = "6379"
}
if host == "" {
host = "localhost"
}
address := net.JoinHostPort(host, port)
if u.User != nil {
password, isSet := u.User.Password()
if isSet {
options = append(options, DialPassword(password))
}
}
match := pathDBRegexp.FindStringSubmatch(u.Path)
if len(match) == 2 {
db, err := strconv.Atoi(match[1])
if err != nil {
return nil, errors.Errorf("invalid database: %s", u.Path[1:])
}
if db != 0 {
options = append(options, DialDatabase(db))
}
} else if u.Path != "" {
return nil, errors.Errorf("invalid database: %s", u.Path[1:])
}
return Dial("tcp", address, options...)
}
// NewConn new a redis conn.
func NewConn(c *Config) (cn Conn, err error) {
cnop := DialConnectTimeout(time.Duration(c.DialTimeout))
rdop := DialReadTimeout(time.Duration(c.ReadTimeout))
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout))
auop := DialPassword(c.Auth)
// new conn
cn, err = Dial(c.Proto, c.Addr, cnop, rdop, wrop, auop)
return
}
func (c *conn) Close() error {
c.mu.Lock()
err := c.err
if c.err == nil {
c.err = errors.New("redigo: closed")
err = c.conn.Close()
}
c.mu.Unlock()
return err
}
func (c *conn) fatal(err error) error {
c.mu.Lock()
if c.err == nil {
c.err = err
// Close connection to force errors on subsequent calls and to unblock
// other reader or writer.
c.conn.Close()
}
c.mu.Unlock()
return errors.WithStack(c.err)
}
func (c *conn) Err() error {
c.mu.Lock()
err := c.err
c.mu.Unlock()
return err
}
func (c *conn) writeLen(prefix byte, n int) error {
c.lenScratch[len(c.lenScratch)-1] = '\n'
c.lenScratch[len(c.lenScratch)-2] = '\r'
i := len(c.lenScratch) - 3
for {
c.lenScratch[i] = byte('0' + n%10)
i--
n = n / 10
if n == 0 {
break
}
}
c.lenScratch[i] = prefix
_, err := c.bw.Write(c.lenScratch[i:])
return errors.WithStack(err)
}
func (c *conn) writeString(s string) error {
c.writeLen('$', len(s))
c.bw.WriteString(s)
_, err := c.bw.WriteString("\r\n")
return errors.WithStack(err)
}
func (c *conn) writeBytes(p []byte) error {
c.writeLen('$', len(p))
c.bw.Write(p)
_, err := c.bw.WriteString("\r\n")
return errors.WithStack(err)
}
func (c *conn) writeInt64(n int64) error {
return errors.WithStack(c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10)))
}
func (c *conn) writeFloat64(n float64) error {
return errors.WithStack(c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64)))
}
func (c *conn) writeCommand(cmd string, args []interface{}) (err error) {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
c.writeLen('*', 1+len(args))
err = c.writeString(cmd)
for _, arg := range args {
if err != nil {
break
}
switch arg := arg.(type) {
case string:
err = c.writeString(arg)
case []byte:
err = c.writeBytes(arg)
case int:
err = c.writeInt64(int64(arg))
case int64:
err = c.writeInt64(arg)
case float64:
err = c.writeFloat64(arg)
case bool:
if arg {
err = c.writeString("1")
} else {
err = c.writeString("0")
}
case nil:
err = c.writeString("")
default:
var buf bytes.Buffer
fmt.Fprint(&buf, arg)
err = errors.WithStack(c.writeBytes(buf.Bytes()))
}
}
return err
}
type protocolError string
func (pe protocolError) Error() string {
return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", string(pe))
}
func (c *conn) readLine() ([]byte, error) {
p, err := c.br.ReadSlice('\n')
if err == bufio.ErrBufferFull {
return nil, errors.WithStack(protocolError("long response line"))
}
if err != nil {
return nil, err
}
i := len(p) - 2
if i < 0 || p[i] != '\r' {
return nil, errors.WithStack(protocolError("bad response line terminator"))
}
return p[:i], nil
}
// parseLen parses bulk string and array lengths.
func parseLen(p []byte) (int, error) {
if len(p) == 0 {
return -1, errors.WithStack(protocolError("malformed length"))
}
if p[0] == '-' && len(p) == 2 && p[1] == '1' {
// handle $-1 and $-1 null replies.
return -1, nil
}
var n int
for _, b := range p {
n *= 10
if b < '0' || b > '9' {
return -1, errors.WithStack(protocolError("illegal bytes in length"))
}
n += int(b - '0')
}
return n, nil
}
// parseInt parses an integer reply.
func parseInt(p []byte) (interface{}, error) {
if len(p) == 0 {
return 0, errors.WithStack(protocolError("malformed integer"))
}
var negate bool
if p[0] == '-' {
negate = true
p = p[1:]
if len(p) == 0 {
return 0, errors.WithStack(protocolError("malformed integer"))
}
}
var n int64
for _, b := range p {
n *= 10
if b < '0' || b > '9' {
return 0, errors.WithStack(protocolError("illegal bytes in length"))
}
n += int64(b - '0')
}
if negate {
n = -n
}
return n, nil
}
var (
okReply interface{} = "OK"
pongReply interface{} = "PONG"
)
func (c *conn) readReply() (interface{}, error) {
line, err := c.readLine()
if err != nil {
return nil, err
}
if len(line) == 0 {
return nil, errors.WithStack(protocolError("short response line"))
}
switch line[0] {
case '+':
switch {
case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
// Avoid allocation for frequent "+OK" response.
return okReply, nil
case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
// Avoid allocation in PING command benchmarks :)
return pongReply, nil
default:
return string(line[1:]), nil
}
case '-':
return Error(string(line[1:])), nil
case ':':
return parseInt(line[1:])
case '$':
n, err := parseLen(line[1:])
if n < 0 || err != nil {
return nil, err
}
p := make([]byte, n)
_, err = io.ReadFull(c.br, p)
if err != nil {
return nil, errors.WithStack(err)
}
if line1, err := c.readLine(); err != nil {
return nil, err
} else if len(line1) != 0 {
return nil, errors.WithStack(protocolError("bad bulk string format"))
}
return p, nil
case '*':
n, err := parseLen(line[1:])
if n < 0 || err != nil {
return nil, err
}
r := make([]interface{}, n)
for i := range r {
r[i], err = c.readReply()
if err != nil {
return nil, err
}
}
return r, nil
}
return nil, errors.WithStack(protocolError("unexpected response line"))
}
func (c *conn) Send(cmd string, args ...interface{}) (err error) {
c.mu.Lock()
c.pending++
c.mu.Unlock()
if err = c.writeCommand(cmd, args); err != nil {
c.fatal(err)
}
return err
}
func (c *conn) Flush() (err error) {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if err = c.bw.Flush(); err != nil {
c.fatal(err)
}
return err
}
func (c *conn) Receive() (reply interface{}, err error) {
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
if reply, err = c.readReply(); err != nil {
return nil, c.fatal(err)
}
// When using pub/sub, the number of receives can be greater than the
// number of sends. To enable normal use of the connection after
// unsubscribing from all channels, we do not decrement pending to a
// negative value.
//
// The pending field is decremented after the reply is read to handle the
// case where Receive is called before Send.
c.mu.Lock()
if c.pending > 0 {
c.pending--
}
c.mu.Unlock()
if err, ok := reply.(Error); ok {
return nil, err
}
return
}
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
c.mu.Lock()
pending := c.pending
c.pending = 0
c.mu.Unlock()
if cmd == "" && pending == 0 {
return nil, nil
}
var err error
defer c.stat(cmd, &err)()
if cmd != "" {
err = c.writeCommand(cmd, args)
}
if err == nil {
err = errors.WithStack(c.bw.Flush())
}
if err != nil {
return nil, c.fatal(err)
}
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
if cmd == "" {
reply := make([]interface{}, pending)
for i := range reply {
var r interface{}
r, err = c.readReply()
if err != nil {
break
}
reply[i] = r
}
if err != nil {
return nil, c.fatal(err)
}
return reply, nil
}
var reply interface{}
for i := 0; i <= pending; i++ {
var e error
if reply, e = c.readReply(); e != nil {
return nil, c.fatal(e)
}
if e, ok := reply.(Error); ok && err == nil {
err = e
}
}
return reply, err
}
// WithContext FIXME: implement WithContext
func (c *conn) WithContext(ctx context.Context) Conn { return c }

@ -0,0 +1,169 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// Package redis is a client for the Redis database.
//
// The Redigo FAQ (https://github.com/garyburd/redigo/wiki/FAQ) contains more
// documentation about this package.
//
// Connections
//
// The Conn interface is the primary interface for working with Redis.
// Applications create connections by calling the Dial, DialWithTimeout or
// NewConn functions. In the future, functions will be added for creating
// sharded and other types of connections.
//
// The application must call the connection Close method when the application
// is done with the connection.
//
// Executing Commands
//
// The Conn interface has a generic method for executing Redis commands:
//
// Do(commandName string, args ...interface{}) (reply interface{}, err error)
//
// The Redis command reference (http://redis.io/commands) lists the available
// commands. An example of using the Redis APPEND command is:
//
// n, err := conn.Do("APPEND", "key", "value")
//
// The Do method converts command arguments to binary strings for transmission
// to the server as follows:
//
// Go Type Conversion
// []byte Sent as is
// string Sent as is
// int, int64 strconv.FormatInt(v)
// float64 strconv.FormatFloat(v, 'g', -1, 64)
// bool true -> "1", false -> "0"
// nil ""
// all other types fmt.Print(v)
//
// Redis command reply types are represented using the following Go types:
//
// Redis type Go type
// error redis.Error
// integer int64
// simple string string
// bulk string []byte or nil if value not present.
// array []interface{} or nil if value not present.
//
// Use type assertions or the reply helper functions to convert from
// interface{} to the specific Go type for the command result.
//
// Pipelining
//
// Connections support pipelining using the Send, Flush and Receive methods.
//
// Send(commandName string, args ...interface{}) error
// Flush() error
// Receive() (reply interface{}, err error)
//
// Send writes the command to the connection's output buffer. Flush flushes the
// connection's output buffer to the server. Receive reads a single reply from
// the server. The following example shows a simple pipeline.
//
// c.Send("SET", "foo", "bar")
// c.Send("GET", "foo")
// c.Flush()
// c.Receive() // reply from SET
// v, err = c.Receive() // reply from GET
//
// The Do method combines the functionality of the Send, Flush and Receive
// methods. The Do method starts by writing the command and flushing the output
// buffer. Next, the Do method receives all pending replies including the reply
// for the command just sent by Do. If any of the received replies is an error,
// then Do returns the error. If there are no errors, then Do returns the last
// reply. If the command argument to the Do method is "", then the Do method
// will flush the output buffer and receive pending replies without sending a
// command.
//
// Use the Send and Do methods to implement pipelined transactions.
//
// c.Send("MULTI")
// c.Send("INCR", "foo")
// c.Send("INCR", "bar")
// r, err := c.Do("EXEC")
// fmt.Println(r) // prints [1, 1]
//
// Concurrency
//
// Connections do not support concurrent calls to the write methods (Send,
// Flush) or concurrent calls to the read method (Receive). Connections do
// allow a concurrent reader and writer.
//
// Because the Do method combines the functionality of Send, Flush and Receive,
// the Do method cannot be called concurrently with the other methods.
//
// For full concurrent access to Redis, use the thread-safe Pool to get and
// release connections from within a goroutine.
//
// Publish and Subscribe
//
// Use the Send, Flush and Receive methods to implement Pub/Sub subscribers.
//
// c.Send("SUBSCRIBE", "example")
// c.Flush()
// for {
// reply, err := c.Receive()
// if err != nil {
// return err
// }
// // process pushed message
// }
//
// The PubSubConn type wraps a Conn with convenience methods for implementing
// subscribers. The Subscribe, PSubscribe, Unsubscribe and PUnsubscribe methods
// send and flush a subscription management command. The receive method
// converts a pushed message to convenient types for use in a type switch.
//
// psc := redis.PubSubConn{c}
// psc.Subscribe("example")
// for {
// switch v := psc.Receive().(type) {
// case redis.Message:
// fmt.Printf("%s: message: %s\n", v.Channel, v.Data)
// case redis.Subscription:
// fmt.Printf("%s: %s %d\n", v.Channel, v.Kind, v.Count)
// case error:
// return v
// }
// }
//
// Reply Helpers
//
// The Bool, Int, Bytes, String, Strings and Values functions convert a reply
// to a value of a specific type. To allow convenient wrapping of calls to the
// connection Do and Receive methods, the functions take a second argument of
// type error. If the error is non-nil, then the helper function returns the
// error. If the error is nil, the function converts the reply to the specified
// type:
//
// exists, err := redis.Bool(c.Do("EXISTS", "foo"))
// if err != nil {
// // handle error return from c.Do or type conversion error.
// }
//
// The Scan function converts elements of a array reply to Go types:
//
// var value1 int
// var value2 string
// reply, err := redis.Values(c.Do("MGET", "key1", "key2"))
// if err != nil {
// // handle error
// }
// if _, err := redis.Scan(reply, &value1, &value2); err != nil {
// // handle error
// }
package redis

@ -0,0 +1,33 @@
package redis
import (
"strings"
pkgerr "github.com/pkg/errors"
)
func formatErr(err error) string {
e := pkgerr.Cause(err)
switch e {
case ErrNil, nil:
return ""
default:
es := e.Error()
switch {
case strings.HasPrefix(es, "read"):
return "read timeout"
case strings.HasPrefix(es, "dial"):
return "dial timeout"
case strings.HasPrefix(es, "write"):
return "write timeout"
case strings.Contains(es, "EOF"):
return "eof"
case strings.Contains(es, "reset"):
return "reset"
case strings.Contains(es, "broken"):
return "broken pipe"
default:
return "unexpected err"
}
}
}

@ -0,0 +1,117 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bytes"
"fmt"
"log"
)
// NewLoggingConn returns a logging wrapper around a connection.
func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn {
if prefix != "" {
prefix = prefix + "."
}
return &loggingConn{conn, logger, prefix}
}
type loggingConn struct {
Conn
logger *log.Logger
prefix string
}
func (c *loggingConn) Close() error {
err := c.Conn.Close()
var buf bytes.Buffer
fmt.Fprintf(&buf, "%sClose() -> (%v)", c.prefix, err)
c.logger.Output(2, buf.String())
return err
}
func (c *loggingConn) printValue(buf *bytes.Buffer, v interface{}) {
const chop = 32
switch v := v.(type) {
case []byte:
if len(v) > chop {
fmt.Fprintf(buf, "%q...", v[:chop])
} else {
fmt.Fprintf(buf, "%q", v)
}
case string:
if len(v) > chop {
fmt.Fprintf(buf, "%q...", v[:chop])
} else {
fmt.Fprintf(buf, "%q", v)
}
case []interface{}:
if len(v) == 0 {
buf.WriteString("[]")
} else {
sep := "["
fin := "]"
if len(v) > chop {
v = v[:chop]
fin = "...]"
}
for _, vv := range v {
buf.WriteString(sep)
c.printValue(buf, vv)
sep = ", "
}
buf.WriteString(fin)
}
default:
fmt.Fprint(buf, v)
}
}
func (c *loggingConn) print(method, commandName string, args []interface{}, reply interface{}, err error) {
var buf bytes.Buffer
fmt.Fprintf(&buf, "%s%s(", c.prefix, method)
if method != "Receive" {
buf.WriteString(commandName)
for _, arg := range args {
buf.WriteString(", ")
c.printValue(&buf, arg)
}
}
buf.WriteString(") -> (")
if method != "Send" {
c.printValue(&buf, reply)
buf.WriteString(", ")
}
fmt.Fprintf(&buf, "%v)", err)
c.logger.Output(3, buf.String())
}
func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, error) {
reply, err := c.Conn.Do(commandName, args...)
c.print("Do", commandName, args, reply, err)
return reply, err
}
func (c *loggingConn) Send(commandName string, args ...interface{}) error {
err := c.Conn.Send(commandName, args...)
c.print("Send", commandName, args, nil, err)
return err
}
func (c *loggingConn) Receive() (interface{}, error) {
reply, err := c.Conn.Receive()
c.print("Receive", "", nil, reply, err)
return reply, err
}

@ -0,0 +1,36 @@
package redis
import (
"context"
)
// MockErr for unit test.
type MockErr struct {
Error error
}
// MockWith return a mock conn.
func MockWith(err error) MockErr {
return MockErr{Error: err}
}
// Err .
func (m MockErr) Err() error { return m.Error }
// Close .
func (m MockErr) Close() error { return m.Error }
// Do .
func (m MockErr) Do(commandName string, args ...interface{}) (interface{}, error) { return nil, m.Error }
// Send .
func (m MockErr) Send(commandName string, args ...interface{}) error { return m.Error }
// Flush .
func (m MockErr) Flush() error { return m.Error }
// Receive .
func (m MockErr) Receive() (interface{}, error) { return nil, m.Error }
// WithContext .
func (m MockErr) WithContext(context.Context) Conn { return m }

@ -0,0 +1,218 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha1"
"errors"
"io"
"strconv"
"sync"
"time"
"github.com/bilibili/kratos/pkg/container/pool"
"github.com/bilibili/kratos/pkg/net/trace"
xtime "github.com/bilibili/kratos/pkg/time"
)
var beginTime, _ = time.Parse("2006-01-02 15:04:05", "2006-01-02 15:04:05")
var (
errConnClosed = errors.New("redigo: connection closed")
)
// Pool .
type Pool struct {
*pool.Slice
// config
c *Config
}
// Config client settings.
type Config struct {
*pool.Config
Name string // redis name, for trace
Proto string
Addr string
Auth string
DialTimeout xtime.Duration
ReadTimeout xtime.Duration
WriteTimeout xtime.Duration
}
// NewPool creates a new pool.
func NewPool(c *Config, options ...DialOption) (p *Pool) {
if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 {
panic("must config redis timeout")
}
p1 := pool.NewSlice(c.Config)
cnop := DialConnectTimeout(time.Duration(c.DialTimeout))
options = append(options, cnop)
rdop := DialReadTimeout(time.Duration(c.ReadTimeout))
options = append(options, rdop)
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout))
options = append(options, wrop)
auop := DialPassword(c.Auth)
options = append(options, auop)
// new pool
p1.New = func(ctx context.Context) (io.Closer, error) {
conn, err := Dial(c.Proto, c.Addr, options...)
if err != nil {
return nil, err
}
return &traceConn{Conn: conn, connTags: []trace.Tag{trace.TagString(trace.TagPeerAddress, c.Addr)}}, nil
}
p = &Pool{Slice: p1, c: c}
return
}
// Get gets a connection. The application must close the returned connection.
// This method always returns a valid connection so that applications can defer
// error handling to the first use of the connection. If there is an error
// getting an underlying connection, then the connection Err, Do, Send, Flush
// and Receive methods return that error.
func (p *Pool) Get(ctx context.Context) Conn {
c, err := p.Slice.Get(ctx)
if err != nil {
return errorConnection{err}
}
c1, _ := c.(Conn)
return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx, now: beginTime}
}
// Close releases the resources used by the pool.
func (p *Pool) Close() error {
return p.Slice.Close()
}
type pooledConnection struct {
p *Pool
c Conn
state int
now time.Time
cmds []string
ctx context.Context
}
var (
sentinel []byte
sentinelOnce sync.Once
)
func initSentinel() {
p := make([]byte, 64)
if _, err := rand.Read(p); err == nil {
sentinel = p
} else {
h := sha1.New()
io.WriteString(h, "Oops, rand failed. Use time instead.")
io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10))
sentinel = h.Sum(nil)
}
}
func (pc *pooledConnection) Close() error {
c := pc.c
if _, ok := c.(errorConnection); ok {
return nil
}
pc.c = errorConnection{errConnClosed}
if pc.state&MultiState != 0 {
c.Send("DISCARD")
pc.state &^= (MultiState | WatchState)
} else if pc.state&WatchState != 0 {
c.Send("UNWATCH")
pc.state &^= WatchState
}
if pc.state&SubscribeState != 0 {
c.Send("UNSUBSCRIBE")
c.Send("PUNSUBSCRIBE")
// To detect the end of the message stream, ask the server to echo
// a sentinel value and read until we see that value.
sentinelOnce.Do(initSentinel)
c.Send("ECHO", sentinel)
c.Flush()
for {
p, err := c.Receive()
if err != nil {
break
}
if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
pc.state &^= SubscribeState
break
}
}
}
_, err := c.Do("")
pc.p.Slice.Put(context.Background(), c, pc.state != 0 || c.Err() != nil)
return err
}
func (pc *pooledConnection) Err() error {
return pc.c.Err()
}
func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
ci := LookupCommandInfo(commandName)
pc.state = (pc.state | ci.Set) &^ ci.Clear
reply, err = pc.c.Do(commandName, args...)
return
}
func (pc *pooledConnection) Send(commandName string, args ...interface{}) (err error) {
ci := LookupCommandInfo(commandName)
pc.state = (pc.state | ci.Set) &^ ci.Clear
if pc.now.Equal(beginTime) {
// mark first send time
pc.now = time.Now()
}
pc.cmds = append(pc.cmds, commandName)
return pc.c.Send(commandName, args...)
}
func (pc *pooledConnection) Flush() error {
return pc.c.Flush()
}
func (pc *pooledConnection) Receive() (reply interface{}, err error) {
reply, err = pc.c.Receive()
if len(pc.cmds) > 0 {
pc.cmds = pc.cmds[1:]
}
return
}
func (pc *pooledConnection) WithContext(ctx context.Context) Conn {
pc.ctx = ctx
return pc
}
type errorConnection struct{ err error }
func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) {
return nil, ec.err
}
func (ec errorConnection) Send(string, ...interface{}) error { return ec.err }
func (ec errorConnection) Err() error { return ec.err }
func (ec errorConnection) Close() error { return ec.err }
func (ec errorConnection) Flush() error { return ec.err }
func (ec errorConnection) Receive() (interface{}, error) { return nil, ec.err }
func (ec errorConnection) WithContext(context.Context) Conn { return ec }

@ -0,0 +1,152 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"errors"
pkgerr "github.com/pkg/errors"
)
var (
errPubSub = errors.New("redigo: unknown pubsub notification")
)
// Subscription represents a subscribe or unsubscribe notification.
type Subscription struct {
// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe"
Kind string
// The channel that was changed.
Channel string
// The current number of subscriptions for connection.
Count int
}
// Message represents a message notification.
type Message struct {
// The originating channel.
Channel string
// The message data.
Data []byte
}
// PMessage represents a pmessage notification.
type PMessage struct {
// The matched pattern.
Pattern string
// The originating channel.
Channel string
// The message data.
Data []byte
}
// Pong represents a pubsub pong notification.
type Pong struct {
Data string
}
// PubSubConn wraps a Conn with convenience methods for subscribers.
type PubSubConn struct {
Conn Conn
}
// Close closes the connection.
func (c PubSubConn) Close() error {
return c.Conn.Close()
}
// Subscribe subscribes the connection to the specified channels.
func (c PubSubConn) Subscribe(channel ...interface{}) error {
c.Conn.Send("SUBSCRIBE", channel...)
return c.Conn.Flush()
}
// PSubscribe subscribes the connection to the given patterns.
func (c PubSubConn) PSubscribe(channel ...interface{}) error {
c.Conn.Send("PSUBSCRIBE", channel...)
return c.Conn.Flush()
}
// Unsubscribe unsubscribes the connection from the given channels, or from all
// of them if none is given.
func (c PubSubConn) Unsubscribe(channel ...interface{}) error {
c.Conn.Send("UNSUBSCRIBE", channel...)
return c.Conn.Flush()
}
// PUnsubscribe unsubscribes the connection from the given patterns, or from all
// of them if none is given.
func (c PubSubConn) PUnsubscribe(channel ...interface{}) error {
c.Conn.Send("PUNSUBSCRIBE", channel...)
return c.Conn.Flush()
}
// Ping sends a PING to the server with the specified data.
func (c PubSubConn) Ping(data string) error {
c.Conn.Send("PING", data)
return c.Conn.Flush()
}
// Receive returns a pushed message as a Subscription, Message, PMessage, Pong
// or error. The return value is intended to be used directly in a type switch
// as illustrated in the PubSubConn example.
func (c PubSubConn) Receive() interface{} {
reply, err := Values(c.Conn.Receive())
if err != nil {
return err
}
var kind string
reply, err = Scan(reply, &kind)
if err != nil {
return err
}
switch kind {
case "message":
var m Message
if _, err := Scan(reply, &m.Channel, &m.Data); err != nil {
return err
}
return m
case "pmessage":
var pm PMessage
if _, err := Scan(reply, &pm.Pattern, &pm.Channel, &pm.Data); err != nil {
return err
}
return pm
case "subscribe", "psubscribe", "unsubscribe", "punsubscribe":
s := Subscription{Kind: kind}
if _, err := Scan(reply, &s.Channel, &s.Count); err != nil {
return err
}
return s
case "pong":
var p Pong
if _, err := Scan(reply, &p.Data); err != nil {
return err
}
return p
}
return pkgerr.WithStack(errPubSub)
}

@ -0,0 +1,51 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"context"
)
// Error represents an error returned in a command reply.
type Error string
func (err Error) Error() string { return string(err) }
// Conn represents a connection to a Redis server.
type Conn interface {
// Close closes the connection.
Close() error
// Err returns a non-nil value if the connection is broken. The returned
// value is either the first non-nil value returned from the underlying
// network connection or a protocol parsing error. Applications should
// close broken connections.
Err() error
// Do sends a command to the server and returns the received reply.
Do(commandName string, args ...interface{}) (reply interface{}, err error)
// Send writes the command to the client's output buffer.
Send(commandName string, args ...interface{}) error
// Flush flushes the output buffer to the Redis server.
Flush() error
// Receive receives a single reply from the Redis server
Receive() (reply interface{}, err error)
// WithContext
WithContext(ctx context.Context) Conn
}

@ -0,0 +1,409 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"errors"
"strconv"
pkgerr "github.com/pkg/errors"
)
// ErrNil indicates that a reply value is nil.
var ErrNil = errors.New("redigo: nil returned")
// Int is a helper that converts a command reply to an integer. If err is not
// equal to nil, then Int returns 0, err. Otherwise, Int converts the
// reply to an int as follows:
//
// Reply type Result
// integer int(reply), nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Int(reply interface{}, err error) (int, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case int64:
x := int(reply)
if int64(x) != reply {
return 0, pkgerr.WithStack(strconv.ErrRange)
}
return x, nil
case []byte:
n, err := strconv.ParseInt(string(reply), 10, 0)
return int(n), pkgerr.WithStack(err)
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, pkgerr.Errorf("redigo: unexpected type for Int, got type %T", reply)
}
// Int64 is a helper that converts a command reply to 64 bit integer. If err is
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
// reply to an int64 as follows:
//
// Reply type Result
// integer reply, nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Int64(reply interface{}, err error) (int64, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case int64:
return reply, nil
case []byte:
n, err := strconv.ParseInt(string(reply), 10, 64)
return n, pkgerr.WithStack(err)
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, pkgerr.Errorf("redigo: unexpected type for Int64, got type %T", reply)
}
var errNegativeInt = errors.New("redigo: unexpected value for Uint64")
// Uint64 is a helper that converts a command reply to 64 bit integer. If err is
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
// reply to an int64 as follows:
//
// Reply type Result
// integer reply, nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Uint64(reply interface{}, err error) (uint64, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case int64:
if reply < 0 {
return 0, pkgerr.WithStack(errNegativeInt)
}
return uint64(reply), nil
case []byte:
n, err := strconv.ParseUint(string(reply), 10, 64)
return n, err
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, pkgerr.Errorf("redigo: unexpected type for Uint64, got type %T", reply)
}
// Float64 is a helper that converts a command reply to 64 bit float. If err is
// not equal to nil, then Float64 returns 0, err. Otherwise, Float64 converts
// the reply to an int as follows:
//
// Reply type Result
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Float64(reply interface{}, err error) (float64, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case []byte:
n, err := strconv.ParseFloat(string(reply), 64)
return n, pkgerr.WithStack(err)
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, pkgerr.Errorf("redigo: unexpected type for Float64, got type %T", reply)
}
// String is a helper that converts a command reply to a string. If err is not
// equal to nil, then String returns "", err. Otherwise String converts the
// reply to a string as follows:
//
// Reply type Result
// bulk string string(reply), nil
// simple string reply, nil
// nil "", ErrNil
// other "", error
func String(reply interface{}, err error) (string, error) {
if err != nil {
return "", err
}
switch reply := reply.(type) {
case []byte:
return string(reply), nil
case string:
return reply, nil
case nil:
return "", ErrNil
case Error:
return "", reply
}
return "", pkgerr.Errorf("redigo: unexpected type for String, got type %T", reply)
}
// Bytes is a helper that converts a command reply to a slice of bytes. If err
// is not equal to nil, then Bytes returns nil, err. Otherwise Bytes converts
// the reply to a slice of bytes as follows:
//
// Reply type Result
// bulk string reply, nil
// simple string []byte(reply), nil
// nil nil, ErrNil
// other nil, error
func Bytes(reply interface{}, err error) ([]byte, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []byte:
return reply, nil
case string:
return []byte(reply), nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, pkgerr.Errorf("redigo: unexpected type for Bytes, got type %T", reply)
}
// Bool is a helper that converts a command reply to a boolean. If err is not
// equal to nil, then Bool returns false, err. Otherwise Bool converts the
// reply to boolean as follows:
//
// Reply type Result
// integer value != 0, nil
// bulk string strconv.ParseBool(reply)
// nil false, ErrNil
// other false, error
func Bool(reply interface{}, err error) (bool, error) {
if err != nil {
return false, err
}
switch reply := reply.(type) {
case int64:
return reply != 0, nil
case []byte:
b, e := strconv.ParseBool(string(reply))
return b, pkgerr.WithStack(e)
case nil:
return false, ErrNil
case Error:
return false, reply
}
return false, pkgerr.Errorf("redigo: unexpected type for Bool, got type %T", reply)
}
// MultiBulk is a helper that converts an array command reply to a []interface{}.
//
// Deprecated: Use Values instead.
func MultiBulk(reply interface{}, err error) ([]interface{}, error) { return Values(reply, err) }
// Values is a helper that converts an array command reply to a []interface{}.
// If err is not equal to nil, then Values returns nil, err. Otherwise, Values
// converts the reply as follows:
//
// Reply type Result
// array reply, nil
// nil nil, ErrNil
// other nil, error
func Values(reply interface{}, err error) ([]interface{}, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []interface{}:
return reply, nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, pkgerr.Errorf("redigo: unexpected type for Values, got type %T", reply)
}
// Strings is a helper that converts an array command reply to a []string. If
// err is not equal to nil, then Strings returns nil, err. Nil array items are
// converted to "" in the output slice. Strings returns an error if an array
// item is not a bulk string or nil.
func Strings(reply interface{}, err error) ([]string, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []interface{}:
result := make([]string, len(reply))
for i := range reply {
if reply[i] == nil {
continue
}
p, ok := reply[i].([]byte)
if !ok {
return nil, pkgerr.Errorf("redigo: unexpected element type for Strings, got type %T", reply[i])
}
result[i] = string(p)
}
return result, nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, pkgerr.Errorf("redigo: unexpected type for Strings, got type %T", reply)
}
// ByteSlices is a helper that converts an array command reply to a [][]byte.
// If err is not equal to nil, then ByteSlices returns nil, err. Nil array
// items are stay nil. ByteSlices returns an error if an array item is not a
// bulk string or nil.
func ByteSlices(reply interface{}, err error) ([][]byte, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []interface{}:
result := make([][]byte, len(reply))
for i := range reply {
if reply[i] == nil {
continue
}
p, ok := reply[i].([]byte)
if !ok {
return nil, pkgerr.Errorf("redigo: unexpected element type for ByteSlices, got type %T", reply[i])
}
result[i] = p
}
return result, nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, pkgerr.Errorf("redigo: unexpected type for ByteSlices, got type %T", reply)
}
// Ints is a helper that converts an array command reply to a []int. If
// err is not equal to nil, then Ints returns nil, err.
func Ints(reply interface{}, err error) ([]int, error) {
var ints []int
values, err := Values(reply, err)
if err != nil {
return ints, err
}
if err := ScanSlice(values, &ints); err != nil {
return ints, err
}
return ints, nil
}
// Int64s is a helper that converts an array command reply to a []int64. If
// err is not equal to nil, then Int64s returns nil, err.
func Int64s(reply interface{}, err error) ([]int64, error) {
var int64s []int64
values, err := Values(reply, err)
if err != nil {
return int64s, err
}
if err := ScanSlice(values, &int64s); err != nil {
return int64s, err
}
return int64s, nil
}
// StringMap is a helper that converts an array of strings (alternating key, value)
// into a map[string]string. The HGETALL and CONFIG GET commands return replies in this format.
// Requires an even number of values in result.
func StringMap(result interface{}, err error) (map[string]string, error) {
values, err := Values(result, err)
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, pkgerr.New("redigo: StringMap expects even number of values result")
}
m := make(map[string]string, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, okKey := values[i].([]byte)
value, okValue := values[i+1].([]byte)
if !okKey || !okValue {
return nil, pkgerr.New("redigo: ScanMap key not a bulk string value")
}
m[string(key)] = string(value)
}
return m, nil
}
// IntMap is a helper that converts an array of strings (alternating key, value)
// into a map[string]int. The HGETALL commands return replies in this format.
// Requires an even number of values in result.
func IntMap(result interface{}, err error) (map[string]int, error) {
values, err := Values(result, err)
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, pkgerr.New("redigo: IntMap expects even number of values result")
}
m := make(map[string]int, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, ok := values[i].([]byte)
if !ok {
return nil, pkgerr.New("redigo: ScanMap key not a bulk string value")
}
value, err := Int(values[i+1], nil)
if err != nil {
return nil, err
}
m[string(key)] = value
}
return m, nil
}
// Int64Map is a helper that converts an array of strings (alternating key, value)
// into a map[string]int64. The HGETALL commands return replies in this format.
// Requires an even number of values in result.
func Int64Map(result interface{}, err error) (map[string]int64, error) {
values, err := Values(result, err)
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, pkgerr.New("redigo: Int64Map expects even number of values result")
}
m := make(map[string]int64, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, ok := values[i].([]byte)
if !ok {
return nil, pkgerr.New("redigo: ScanMap key not a bulk string value")
}
value, err := Int64(values[i+1], nil)
if err != nil {
return nil, err
}
m[string(key)] = value
}
return m, nil
}

@ -0,0 +1,559 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
pkgerr "github.com/pkg/errors"
)
func ensureLen(d reflect.Value, n int) {
if n > d.Cap() {
d.Set(reflect.MakeSlice(d.Type(), n, n))
} else {
d.SetLen(n)
}
}
func cannotConvert(d reflect.Value, s interface{}) error {
var sname string
switch s.(type) {
case string:
sname = "Redis simple string"
case Error:
sname = "Redis error"
case int64:
sname = "Redis integer"
case []byte:
sname = "Redis bulk string"
case []interface{}:
sname = "Redis array"
default:
sname = reflect.TypeOf(s).String()
}
return pkgerr.Errorf("cannot convert from %s to %s", sname, d.Type())
}
func convertAssignBulkString(d reflect.Value, s []byte) (err error) {
switch d.Type().Kind() {
case reflect.Float32, reflect.Float64:
var x float64
x, err = strconv.ParseFloat(string(s), d.Type().Bits())
d.SetFloat(x)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var x int64
x, err = strconv.ParseInt(string(s), 10, d.Type().Bits())
d.SetInt(x)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
var x uint64
x, err = strconv.ParseUint(string(s), 10, d.Type().Bits())
d.SetUint(x)
case reflect.Bool:
var x bool
x, err = strconv.ParseBool(string(s))
d.SetBool(x)
case reflect.String:
d.SetString(string(s))
case reflect.Slice:
if d.Type().Elem().Kind() != reflect.Uint8 {
err = cannotConvert(d, s)
} else {
d.SetBytes(s)
}
default:
err = cannotConvert(d, s)
}
err = pkgerr.WithStack(err)
return
}
func convertAssignInt(d reflect.Value, s int64) (err error) {
switch d.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
d.SetInt(s)
if d.Int() != s {
err = strconv.ErrRange
d.SetInt(0)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if s < 0 {
err = strconv.ErrRange
} else {
x := uint64(s)
d.SetUint(x)
if d.Uint() != x {
err = strconv.ErrRange
d.SetUint(0)
}
}
case reflect.Bool:
d.SetBool(s != 0)
default:
err = cannotConvert(d, s)
}
err = pkgerr.WithStack(err)
return
}
func convertAssignValue(d reflect.Value, s interface{}) (err error) {
switch s := s.(type) {
case []byte:
err = convertAssignBulkString(d, s)
case int64:
err = convertAssignInt(d, s)
default:
err = cannotConvert(d, s)
}
return err
}
func convertAssignArray(d reflect.Value, s []interface{}) error {
if d.Type().Kind() != reflect.Slice {
return cannotConvert(d, s)
}
ensureLen(d, len(s))
for i := 0; i < len(s); i++ {
if err := convertAssignValue(d.Index(i), s[i]); err != nil {
return err
}
}
return nil
}
func convertAssign(d interface{}, s interface{}) (err error) {
// Handle the most common destination types using type switches and
// fall back to reflection for all other types.
switch s := s.(type) {
case nil:
// ingore
case []byte:
switch d := d.(type) {
case *string:
*d = string(s)
case *int:
*d, err = strconv.Atoi(string(s))
case *bool:
*d, err = strconv.ParseBool(string(s))
case *[]byte:
*d = s
case *interface{}:
*d = s
case nil:
// skip value
default:
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
err = cannotConvert(d, s)
} else {
err = convertAssignBulkString(d.Elem(), s)
}
}
case int64:
switch d := d.(type) {
case *int:
x := int(s)
if int64(x) != s {
err = strconv.ErrRange
x = 0
}
*d = x
case *bool:
*d = s != 0
case *interface{}:
*d = s
case nil:
// skip value
default:
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
err = cannotConvert(d, s)
} else {
err = convertAssignInt(d.Elem(), s)
}
}
case string:
switch d := d.(type) {
case *string:
*d = string(s)
default:
err = cannotConvert(reflect.ValueOf(d), s)
}
case []interface{}:
switch d := d.(type) {
case *[]interface{}:
*d = s
case *interface{}:
*d = s
case nil:
// skip value
default:
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
err = cannotConvert(d, s)
} else {
err = convertAssignArray(d.Elem(), s)
}
}
case Error:
err = s
default:
err = cannotConvert(reflect.ValueOf(d), s)
}
err = pkgerr.WithStack(err)
return
}
// Scan copies from src to the values pointed at by dest.
//
// The values pointed at by dest must be an integer, float, boolean, string,
// []byte, interface{} or slices of these types. Scan uses the standard strconv
// package to convert bulk strings to numeric and boolean types.
//
// If a dest value is nil, then the corresponding src value is skipped.
//
// If a src element is nil, then the corresponding dest value is not modified.
//
// To enable easy use of Scan in a loop, Scan returns the slice of src
// following the copied values.
func Scan(src []interface{}, dest ...interface{}) ([]interface{}, error) {
if len(src) < len(dest) {
return nil, pkgerr.New("redigo.Scan: array short")
}
var err error
for i, d := range dest {
err = convertAssign(d, src[i])
if err != nil {
err = fmt.Errorf("redigo.Scan: cannot assign to dest %d: %v", i, err)
break
}
}
return src[len(dest):], err
}
type fieldSpec struct {
name string
index []int
omitEmpty bool
}
type structSpec struct {
m map[string]*fieldSpec
l []*fieldSpec
}
func (ss *structSpec) fieldSpec(name []byte) *fieldSpec {
return ss.m[string(name)]
}
func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) {
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
switch {
case f.PkgPath != "" && !f.Anonymous:
// Ignore unexported fields.
case f.Anonymous:
// TODO: Handle pointers. Requires change to decoder and
// protection against infinite recursion.
if f.Type.Kind() == reflect.Struct {
compileStructSpec(f.Type, depth, append(index, i), ss)
}
default:
fs := &fieldSpec{name: f.Name}
tag := f.Tag.Get("redis")
p := strings.Split(tag, ",")
if len(p) > 0 {
if p[0] == "-" {
continue
}
if len(p[0]) > 0 {
fs.name = p[0]
}
for _, s := range p[1:] {
switch s {
case "omitempty":
fs.omitEmpty = true
default:
panic(fmt.Errorf("redigo: unknown field tag %s for type %s", s, t.Name()))
}
}
}
d, found := depth[fs.name]
if !found {
d = 1 << 30
}
switch {
case len(index) == d:
// At same depth, remove from result.
delete(ss.m, fs.name)
j := 0
for i1 := 0; i1 < len(ss.l); i1++ {
if fs.name != ss.l[i1].name {
ss.l[j] = ss.l[i1]
j++
}
}
ss.l = ss.l[:j]
case len(index) < d:
fs.index = make([]int, len(index)+1)
copy(fs.index, index)
fs.index[len(index)] = i
depth[fs.name] = len(index)
ss.m[fs.name] = fs
ss.l = append(ss.l, fs)
}
}
}
}
var (
structSpecMutex sync.RWMutex
structSpecCache = make(map[reflect.Type]*structSpec)
)
func structSpecForType(t reflect.Type) *structSpec {
structSpecMutex.RLock()
ss, found := structSpecCache[t]
structSpecMutex.RUnlock()
if found {
return ss
}
structSpecMutex.Lock()
defer structSpecMutex.Unlock()
ss, found = structSpecCache[t]
if found {
return ss
}
ss = &structSpec{m: make(map[string]*fieldSpec)}
compileStructSpec(t, make(map[string]int), nil, ss)
structSpecCache[t] = ss
return ss
}
var errScanStructValue = errors.New("redigo.ScanStruct: value must be non-nil pointer to a struct")
// ScanStruct scans alternating names and values from src to a struct. The
// HGETALL and CONFIG GET commands return replies in this format.
//
// ScanStruct uses exported field names to match values in the response. Use
// 'redis' field tag to override the name:
//
// Field int `redis:"myName"`
//
// Fields with the tag redis:"-" are ignored.
//
// Integer, float, boolean, string and []byte fields are supported. Scan uses the
// standard strconv package to convert bulk string values to numeric and
// boolean types.
//
// If a src element is nil, then the corresponding field is not modified.
func ScanStruct(src []interface{}, dest interface{}) error {
d := reflect.ValueOf(dest)
if d.Kind() != reflect.Ptr || d.IsNil() {
return pkgerr.WithStack(errScanStructValue)
}
d = d.Elem()
if d.Kind() != reflect.Struct {
return pkgerr.WithStack(errScanStructValue)
}
ss := structSpecForType(d.Type())
if len(src)%2 != 0 {
return pkgerr.New("redigo.ScanStruct: number of values not a multiple of 2")
}
for i := 0; i < len(src); i += 2 {
s := src[i+1]
if s == nil {
continue
}
name, ok := src[i].([]byte)
if !ok {
return pkgerr.Errorf("redigo.ScanStruct: key %d not a bulk string value", i)
}
fs := ss.fieldSpec(name)
if fs == nil {
continue
}
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {
return pkgerr.Errorf("redigo.ScanStruct: cannot assign field %s: %v", fs.name, err)
}
}
return nil
}
var (
errScanSliceValue = errors.New("redigo.ScanSlice: dest must be non-nil pointer to a struct")
)
// ScanSlice scans src to the slice pointed to by dest. The elements the dest
// slice must be integer, float, boolean, string, struct or pointer to struct
// values.
//
// Struct fields must be integer, float, boolean or string values. All struct
// fields are used unless a subset is specified using fieldNames.
func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error {
d := reflect.ValueOf(dest)
if d.Kind() != reflect.Ptr || d.IsNil() {
return pkgerr.WithStack(errScanSliceValue)
}
d = d.Elem()
if d.Kind() != reflect.Slice {
return pkgerr.WithStack(errScanSliceValue)
}
isPtr := false
t := d.Type().Elem()
if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
isPtr = true
t = t.Elem()
}
if t.Kind() != reflect.Struct {
ensureLen(d, len(src))
for i, s := range src {
if s == nil {
continue
}
if err := convertAssignValue(d.Index(i), s); err != nil {
return pkgerr.Errorf("redigo.ScanSlice: cannot assign element %d: %v", i, err)
}
}
return nil
}
ss := structSpecForType(t)
fss := ss.l
if len(fieldNames) > 0 {
fss = make([]*fieldSpec, len(fieldNames))
for i, name := range fieldNames {
fss[i] = ss.m[name]
if fss[i] == nil {
return pkgerr.Errorf("redigo.ScanSlice: ScanSlice bad field name %s", name)
}
}
}
if len(fss) == 0 {
return pkgerr.New("redigo.ScanSlice: no struct fields")
}
n := len(src) / len(fss)
if n*len(fss) != len(src) {
return pkgerr.New("redigo.ScanSlice: length not a multiple of struct field count")
}
ensureLen(d, n)
for i := 0; i < n; i++ {
d1 := d.Index(i)
if isPtr {
if d1.IsNil() {
d1.Set(reflect.New(t))
}
d1 = d1.Elem()
}
for j, fs := range fss {
s := src[i*len(fss)+j]
if s == nil {
continue
}
if err := convertAssignValue(d1.FieldByIndex(fs.index), s); err != nil {
return pkgerr.Errorf("redigo.ScanSlice: cannot assign element %d to field %s: %v", i*len(fss)+j, fs.name, err)
}
}
}
return nil
}
// Args is a helper for constructing command arguments from structured values.
type Args []interface{}
// Add returns the result of appending value to args.
func (args Args) Add(value ...interface{}) Args {
return append(args, value...)
}
// AddFlat returns the result of appending the flattened value of v to args.
//
// Maps are flattened by appending the alternating keys and map values to args.
//
// Slices are flattened by appending the slice elements to args.
//
// Structs are flattened by appending the alternating names and values of
// exported fields to args. If v is a nil struct pointer, then nothing is
// appended. The 'redis' field tag overrides struct field names. See ScanStruct
// for more information on the use of the 'redis' field tag.
//
// Other types are appended to args as is.
func (args Args) AddFlat(v interface{}) Args {
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Struct:
args = flattenStruct(args, rv)
case reflect.Slice:
for i := 0; i < rv.Len(); i++ {
args = append(args, rv.Index(i).Interface())
}
case reflect.Map:
for _, k := range rv.MapKeys() {
args = append(args, k.Interface(), rv.MapIndex(k).Interface())
}
case reflect.Ptr:
if rv.Type().Elem().Kind() == reflect.Struct {
if !rv.IsNil() {
args = flattenStruct(args, rv.Elem())
}
} else {
args = append(args, v)
}
default:
args = append(args, v)
}
return args
}
func flattenStruct(args Args, v reflect.Value) Args {
ss := structSpecForType(v.Type())
for _, fs := range ss.l {
fv := v.FieldByIndex(fs.index)
if fs.omitEmpty {
var empty = false
switch fv.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
empty = fv.Len() == 0
case reflect.Bool:
empty = !fv.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
empty = fv.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
empty = fv.Uint() == 0
case reflect.Float32, reflect.Float64:
empty = fv.Float() == 0
case reflect.Interface, reflect.Ptr:
empty = fv.IsNil()
}
if empty {
continue
}
}
args = append(args, fs.name, fv.Interface())
}
return args
}

@ -0,0 +1,86 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"crypto/sha1"
"encoding/hex"
"io"
"strings"
)
// Script encapsulates the source, hash and key count for a Lua script. See
// http://redis.io/commands/eval for information on scripts in Redis.
type Script struct {
keyCount int
src string
hash string
}
// NewScript returns a new script object. If keyCount is greater than or equal
// to zero, then the count is automatically inserted in the EVAL command
// argument list. If keyCount is less than zero, then the application supplies
// the count as the first value in the keysAndArgs argument to the Do, Send and
// SendHash methods.
func NewScript(keyCount int, src string) *Script {
h := sha1.New()
io.WriteString(h, src)
return &Script{keyCount, src, hex.EncodeToString(h.Sum(nil))}
}
func (s *Script) args(spec string, keysAndArgs []interface{}) []interface{} {
var args []interface{}
if s.keyCount < 0 {
args = make([]interface{}, 1+len(keysAndArgs))
args[0] = spec
copy(args[1:], keysAndArgs)
} else {
args = make([]interface{}, 2+len(keysAndArgs))
args[0] = spec
args[1] = s.keyCount
copy(args[2:], keysAndArgs)
}
return args
}
// Do evaluates the script. Under the covers, Do optimistically evaluates the
// script using the EVALSHA command. If the command fails because the script is
// not loaded, then Do evaluates the script using the EVAL command (thus
// causing the script to load).
func (s *Script) Do(c Conn, keysAndArgs ...interface{}) (interface{}, error) {
v, err := c.Do("EVALSHA", s.args(s.hash, keysAndArgs)...)
if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") {
v, err = c.Do("EVAL", s.args(s.src, keysAndArgs)...)
}
return v, err
}
// SendHash evaluates the script without waiting for the reply. The script is
// evaluated with the EVALSHA command. The application must ensure that the
// script is loaded by a previous call to Send, Do or Load methods.
func (s *Script) SendHash(c Conn, keysAndArgs ...interface{}) error {
return c.Send("EVALSHA", s.args(s.hash, keysAndArgs)...)
}
// Send evaluates the script without waiting for the reply.
func (s *Script) Send(c Conn, keysAndArgs ...interface{}) error {
return c.Send("EVAL", s.args(s.src, keysAndArgs)...)
}
// Load loads the script without evaluating it.
func (s *Script) Load(c Conn) error {
_, err := c.Do("SCRIPT", "LOAD", s.src)
return err
}

@ -0,0 +1,142 @@
package redis
import (
"context"
"fmt"
"time"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/trace"
)
const (
_traceComponentName = "pkg/cache/redis"
_tracePeerService = "redis"
_traceSpanKind = "client"
_slowLogDuration = time.Millisecond * 250
)
var _internalTags = []trace.Tag{
trace.TagString(trace.TagSpanKind, _traceSpanKind),
trace.TagString(trace.TagComponent, _traceComponentName),
trace.TagString(trace.TagPeerService, _tracePeerService),
}
type traceConn struct {
// tr for pipeline, if tr != nil meaning on pipeline
tr trace.Trace
ctx context.Context
// connTag include e.g. ip,port
connTags []trace.Tag
// origin redis conn
Conn
pending int
}
func (t *traceConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
statement := getStatement(commandName, args...)
defer slowLog(statement, time.Now())
root, ok := trace.FromContext(t.ctx)
// NOTE: ignored empty commandName
// current sdk will Do empty command after pipeline finished
if !ok || commandName == "" {
return t.Conn.Do(commandName, args...)
}
tr := root.Fork("", "Redis:"+commandName)
tr.SetTag(_internalTags...)
tr.SetTag(t.connTags...)
tr.SetTag(trace.TagString(trace.TagDBStatement, statement))
reply, err = t.Conn.Do(commandName, args...)
tr.Finish(&err)
return
}
func (t *traceConn) Send(commandName string, args ...interface{}) error {
statement := getStatement(commandName, args...)
defer slowLog(statement, time.Now())
t.pending++
root, ok := trace.FromContext(t.ctx)
if !ok {
return t.Conn.Send(commandName, args...)
}
if t.tr == nil {
t.tr = root.Fork("", "Redis:Pipeline")
t.tr.SetTag(_internalTags...)
t.tr.SetTag(t.connTags...)
}
t.tr.SetLog(
trace.Log(trace.LogEvent, "Send"),
trace.Log("db.statement", statement),
)
err := t.Conn.Send(commandName, args...)
if err != nil {
t.tr.SetTag(trace.TagBool(trace.TagError, true))
t.tr.SetLog(
trace.Log(trace.LogEvent, "Send Fail"),
trace.Log(trace.LogMessage, err.Error()),
)
}
return err
}
func (t *traceConn) Flush() error {
defer slowLog("Flush", time.Now())
if t.tr == nil {
return t.Conn.Flush()
}
t.tr.SetLog(trace.Log(trace.LogEvent, "Flush"))
err := t.Conn.Flush()
if err != nil {
t.tr.SetTag(trace.TagBool(trace.TagError, true))
t.tr.SetLog(
trace.Log(trace.LogEvent, "Flush Fail"),
trace.Log(trace.LogMessage, err.Error()),
)
}
return err
}
func (t *traceConn) Receive() (reply interface{}, err error) {
defer slowLog("Receive", time.Now())
if t.tr == nil {
return t.Conn.Receive()
}
t.tr.SetLog(trace.Log(trace.LogEvent, "Receive"))
reply, err = t.Conn.Receive()
if err != nil {
t.tr.SetTag(trace.TagBool(trace.TagError, true))
t.tr.SetLog(
trace.Log(trace.LogEvent, "Receive Fail"),
trace.Log(trace.LogMessage, err.Error()),
)
}
if t.pending > 0 {
t.pending--
}
if t.pending == 0 {
t.tr.Finish(nil)
t.tr = nil
}
return reply, err
}
func (t *traceConn) WithContext(ctx context.Context) Conn {
t.ctx = ctx
return t
}
func slowLog(statement string, now time.Time) {
du := time.Since(now)
if du > _slowLogDuration {
log.Warn("%s slow log statement: %s time: %v", _tracePeerService, statement, du)
}
}
func getStatement(commandName string, args ...interface{}) (res string) {
res = commandName
if len(args) > 0 {
res = fmt.Sprintf("%s %v", commandName, args[0])
}
return
}

@ -0,0 +1,40 @@
### database/hbase
### 项目简介
Hbase Client,进行封装加入了链路追踪和统计。
### usage
```go
package main
import (
"context"
"fmt"
"github.com/bilibili/kratos/pkg/database/hbase"
)
func main() {
config := &hbase.Config{Zookeeper: &hbase.ZKConfig{Addrs: []string{"localhost"}}}
client := hbase.NewClient(config)
values := map[string]map[string][]byte{"name": {"firstname": []byte("hello"), "lastname": []byte("world")}}
ctx := context.Background()
_, err := client.PutStr(ctx, "user", "user1", values)
if err != nil {
panic(err)
}
result, err := client.GetStr(ctx, "user", "user1")
if err != nil {
panic(err)
}
fmt.Printf("%v", result)
}
```
##### 依赖包
1.[gohbase](https://github.com/tsuna/gohbase)

@ -0,0 +1,23 @@
package hbase
import (
xtime "github.com/bilibili/kratos/pkg/time"
)
// ZKConfig Server&Client settings.
type ZKConfig struct {
Root string
Addrs []string
Timeout xtime.Duration
}
// Config hbase config
type Config struct {
Zookeeper *ZKConfig
RPCQueueSize int
FlushInterval xtime.Duration
EffectiveUser string
RegionLookupTimeout xtime.Duration
RegionReadTimeout xtime.Duration
TestRowKey string
}

@ -0,0 +1,297 @@
package hbase
import (
"context"
"io"
"strings"
"time"
"github.com/tsuna/gohbase"
"github.com/tsuna/gohbase/hrpc"
"github.com/bilibili/kratos/pkg/log"
)
// HookFunc hook function call before every method and hook return function will call after finish.
type HookFunc func(ctx context.Context, call hrpc.Call, customName string) func(err error)
// Client hbase client.
type Client struct {
hc gohbase.Client
addr string
config *Config
hooks []HookFunc
}
// AddHook add hook function.
func (c *Client) AddHook(hookFn HookFunc) {
c.hooks = append(c.hooks, hookFn)
}
func (c *Client) invokeHook(ctx context.Context, call hrpc.Call, customName string) func(error) {
finishHooks := make([]func(error), 0, len(c.hooks))
for _, fn := range c.hooks {
finishHooks = append(finishHooks, fn(ctx, call, customName))
}
return func(err error) {
for _, fn := range finishHooks {
fn(err)
}
}
}
// NewClient new a hbase client.
func NewClient(config *Config, options ...gohbase.Option) *Client {
rawcli := NewRawClient(config, options...)
rawcli.AddHook(NewSlowLogHook(250 * time.Millisecond))
rawcli.AddHook(MetricsHook(nil))
rawcli.AddHook(TraceHook("database/hbase", strings.Join(config.Zookeeper.Addrs, ",")))
return rawcli
}
// NewRawClient new a hbase client without prometheus metrics and dapper trace hook.
func NewRawClient(config *Config, options ...gohbase.Option) *Client {
zkquorum := strings.Join(config.Zookeeper.Addrs, ",")
if config.Zookeeper.Root != "" {
options = append(options, gohbase.ZookeeperRoot(config.Zookeeper.Root))
}
if config.Zookeeper.Timeout != 0 {
options = append(options, gohbase.ZookeeperTimeout(time.Duration(config.Zookeeper.Timeout)))
}
if config.RPCQueueSize != 0 {
log.Warn("RPCQueueSize configuration be ignored")
}
// force RpcQueueSize = 1, don't change it !!! it has reason (゜-゜)つロ
options = append(options, gohbase.RpcQueueSize(1))
if config.FlushInterval != 0 {
options = append(options, gohbase.FlushInterval(time.Duration(config.FlushInterval)))
}
if config.EffectiveUser != "" {
options = append(options, gohbase.EffectiveUser(config.EffectiveUser))
}
if config.RegionLookupTimeout != 0 {
options = append(options, gohbase.RegionLookupTimeout(time.Duration(config.RegionLookupTimeout)))
}
if config.RegionReadTimeout != 0 {
options = append(options, gohbase.RegionReadTimeout(time.Duration(config.RegionReadTimeout)))
}
hc := gohbase.NewClient(zkquorum, options...)
return &Client{
hc: hc,
addr: zkquorum,
config: config,
}
}
// ScanAll do scan command and return all result
// NOTE: if err != nil the results is safe for range operate even not result found
func (c *Client) ScanAll(ctx context.Context, table []byte, options ...func(hrpc.Call) error) (results []*hrpc.Result, err error) {
cursor, err := c.Scan(ctx, table, options...)
if err != nil {
return nil, err
}
for {
result, err := cursor.Next()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
results = append(results, result)
}
return results, nil
}
type scanTrace struct {
hrpc.Scanner
finishHook func(error)
}
func (s *scanTrace) Next() (*hrpc.Result, error) {
result, err := s.Scanner.Next()
if err != nil {
s.finishHook(err)
}
return result, err
}
func (s *scanTrace) Close() error {
err := s.Scanner.Close()
s.finishHook(err)
return err
}
// Scan do a scan command.
func (c *Client) Scan(ctx context.Context, table []byte, options ...func(hrpc.Call) error) (scanner hrpc.Scanner, err error) {
var scan *hrpc.Scan
scan, err = hrpc.NewScan(ctx, table, options...)
if err != nil {
return nil, err
}
st := &scanTrace{}
st.finishHook = c.invokeHook(ctx, scan, "Scan")
st.Scanner = c.hc.Scan(scan)
return st, nil
}
// ScanStr scan string
func (c *Client) ScanStr(ctx context.Context, table string, options ...func(hrpc.Call) error) (hrpc.Scanner, error) {
return c.Scan(ctx, []byte(table), options...)
}
// ScanStrAll scan string
// NOTE: if err != nil the results is safe for range operate even not result found
func (c *Client) ScanStrAll(ctx context.Context, table string, options ...func(hrpc.Call) error) ([]*hrpc.Result, error) {
return c.ScanAll(ctx, []byte(table), options...)
}
// ScanRange get a scanner for the given table and key range.
// The range is half-open, i.e. [startRow; stopRow[ -- stopRow is not
// included in the range.
func (c *Client) ScanRange(ctx context.Context, table, startRow, stopRow []byte, options ...func(hrpc.Call) error) (scanner hrpc.Scanner, err error) {
var scan *hrpc.Scan
scan, err = hrpc.NewScanRange(ctx, table, startRow, stopRow, options...)
if err != nil {
return nil, err
}
st := &scanTrace{}
st.finishHook = c.invokeHook(ctx, scan, "ScanRange")
st.Scanner = c.hc.Scan(scan)
return st, nil
}
// ScanRangeStr get a scanner for the given table and key range.
// The range is half-open, i.e. [startRow; stopRow[ -- stopRow is not
// included in the range.
func (c *Client) ScanRangeStr(ctx context.Context, table, startRow, stopRow string, options ...func(hrpc.Call) error) (hrpc.Scanner, error) {
return c.ScanRange(ctx, []byte(table), []byte(startRow), []byte(stopRow), options...)
}
// Get get result for the given table and row key.
// NOTE: if err != nil then result != nil, if result not exists result.Cells length is 0
func (c *Client) Get(ctx context.Context, table, key []byte, options ...func(hrpc.Call) error) (result *hrpc.Result, err error) {
var get *hrpc.Get
get, err = hrpc.NewGet(ctx, table, key, options...)
if err != nil {
return nil, err
}
finishHook := c.invokeHook(ctx, get, "GET")
result, err = c.hc.Get(get)
finishHook(err)
return
}
// GetStr do a get command.
// NOTE: if err != nil then result != nil, if result not exists result.Cells length is 0
func (c *Client) GetStr(ctx context.Context, table, key string, options ...func(hrpc.Call) error) (result *hrpc.Result, err error) {
return c.Get(ctx, []byte(table), []byte(key), options...)
}
// PutStr insert the given family-column-values in the given row key of the given table.
func (c *Client) PutStr(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) {
put, err := hrpc.NewPutStr(ctx, table, key, values, options...)
if err != nil {
return nil, err
}
finishHook := c.invokeHook(ctx, put, "PUT")
result, err := c.hc.Put(put)
finishHook(err)
return result, err
}
// Delete is used to perform Delete operations on a single row.
// To delete entire row, values should be nil.
//
// To delete specific families, qualifiers map should be nil:
// map[string]map[string][]byte{
// "cf1": nil,
// "cf2": nil,
// }
//
// To delete specific qualifiers:
// map[string]map[string][]byte{
// "cf": map[string][]byte{
// "q1": nil,
// "q2": nil,
// },
// }
//
// To delete all versions before and at a timestamp, pass hrpc.Timestamp() option.
// By default all versions will be removed.
//
// To delete only a specific version at a timestamp, pass hrpc.DeleteOneVersion() option
// along with a timestamp. For delete specific qualifiers request, if timestamp is not
// passed, only the latest version will be removed. For delete specific families request,
// the timestamp should be passed or it will have no effect as it's an expensive
// operation to perform.
func (c *Client) Delete(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) {
del, err := hrpc.NewDelStr(ctx, table, key, values, options...)
if err != nil {
return nil, err
}
finishHook := c.invokeHook(ctx, del, "Delete")
result, err := c.hc.Delete(del)
finishHook(err)
return result, err
}
// Append do a append command.
func (c *Client) Append(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) {
appd, err := hrpc.NewAppStr(ctx, table, key, values, options...)
if err != nil {
return nil, err
}
finishHook := c.invokeHook(ctx, appd, "Append")
result, err := c.hc.Append(appd)
finishHook(err)
return result, err
}
// Increment the given values in HBase under the given table and key.
func (c *Client) Increment(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (int64, error) {
increment, err := hrpc.NewIncStr(ctx, table, key, values, options...)
if err != nil {
return 0, err
}
finishHook := c.invokeHook(ctx, increment, "Increment")
result, err := c.hc.Increment(increment)
finishHook(err)
return result, err
}
// IncrementSingle increment the given value by amount in HBase under the given table, key, family and qualifier.
func (c *Client) IncrementSingle(ctx context.Context, table string, key string, family string, qualifier string, amount int64, options ...func(hrpc.Call) error) (int64, error) {
increment, err := hrpc.NewIncStrSingle(ctx, table, key, family, qualifier, amount, options...)
if err != nil {
return 0, err
}
finishHook := c.invokeHook(ctx, increment, "IncrementSingle")
result, err := c.hc.Increment(increment)
finishHook(err)
return result, err
}
// Ping ping.
func (c *Client) Ping(ctx context.Context) (err error) {
testRowKey := "test"
if c.config.TestRowKey != "" {
testRowKey = c.config.TestRowKey
}
values := map[string]map[string][]byte{"test": map[string][]byte{"test": []byte("test")}}
_, err = c.PutStr(ctx, "test", testRowKey, values)
return
}
// Close close client.
func (c *Client) Close() error {
c.hc.Close()
return nil
}

@ -0,0 +1,48 @@
package hbase
import (
"context"
"io"
"time"
"github.com/tsuna/gohbase"
"github.com/tsuna/gohbase/hrpc"
"github.com/bilibili/kratos/pkg/stat"
)
func codeFromErr(err error) string {
code := "unknown_error"
switch err {
case gohbase.ErrClientClosed:
code = "client_closed"
case gohbase.ErrConnotFindRegion:
code = "connot_find_region"
case gohbase.TableNotFound:
code = "table_not_found"
case gohbase.ErrRegionUnavailable:
code = "region_unavailable"
}
return code
}
// MetricsHook if stats is nil use stat.DB as default.
func MetricsHook(stats stat.Stat) HookFunc {
if stats == nil {
stats = stat.DB
}
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) {
now := time.Now()
if customName == "" {
customName = call.Name()
}
method := "hbase:" + customName
return func(err error) {
durationMs := int64(time.Since(now) / time.Millisecond)
stats.Timing(method, durationMs)
if err != nil && err != io.EOF {
stats.Incr(method, codeFromErr(err))
}
}
}
}

@ -0,0 +1,24 @@
package hbase
import (
"context"
"time"
"github.com/tsuna/gohbase/hrpc"
"github.com/bilibili/kratos/pkg/log"
)
// NewSlowLogHook log slow operation.
func NewSlowLogHook(threshold time.Duration) HookFunc {
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) {
start := time.Now()
return func(error) {
duration := time.Since(start)
if duration < threshold {
return
}
log.Warn("hbase slow log: %s %s %s time: %s", customName, call.Table(), call.Key(), duration)
}
}
}

@ -0,0 +1,40 @@
package hbase
import (
"context"
"io"
"github.com/tsuna/gohbase/hrpc"
"github.com/bilibili/kratos/pkg/net/trace"
)
// TraceHook create new hbase trace hook.
func TraceHook(component, instance string) HookFunc {
var internalTags []trace.Tag
internalTags = append(internalTags, trace.TagString(trace.TagComponent, component))
internalTags = append(internalTags, trace.TagString(trace.TagDBInstance, instance))
internalTags = append(internalTags, trace.TagString(trace.TagPeerService, "hbase"))
internalTags = append(internalTags, trace.TagString(trace.TagSpanKind, "client"))
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) {
noop := func(error) {}
root, ok := trace.FromContext(ctx)
if !ok {
return noop
}
if customName == "" {
customName = call.Name()
}
span := root.Fork("", "Hbase:"+customName)
span.SetTag(internalTags...)
statement := string(call.Table()) + " " + string(call.Key())
span.SetTag(trace.TagString(trace.TagDBStatement, statement))
return func(err error) {
if err == io.EOF {
// reset error for trace.
err = nil
}
span.Finish(&err)
}
}
}

@ -0,0 +1,9 @@
#### database/sql
##### 项目简介
MySQL数据库驱动,进行封装加入了链路追踪和统计。
如果需要SQL级别的超时管理 可以在业务代码里面使用context.WithDeadline实现 推荐超时配置放到application.toml里面 方便热加载
##### 依赖包
1. [Go-MySQL-Driver](https://github.com/go-sql-driver/mysql)

@ -0,0 +1,40 @@
package sql
import (
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
"github.com/bilibili/kratos/pkg/stat"
"github.com/bilibili/kratos/pkg/time"
// database driver
_ "github.com/go-sql-driver/mysql"
)
var stats = stat.DB
// Config mysql config.
type Config struct {
Addr string // for trace
DSN string // write data source name.
ReadDSN []string // read data source name.
Active int // pool
Idle int // pool
IdleTimeout time.Duration // connect max life time.
QueryTimeout time.Duration // query sql timeout
ExecTimeout time.Duration // execute sql timeout
TranTimeout time.Duration // transaction sql timeout
Breaker *breaker.Config // breaker
}
// NewMySQL new db and retry connection when has error.
func NewMySQL(c *Config) (db *DB) {
if c.QueryTimeout == 0 || c.ExecTimeout == 0 || c.TranTimeout == 0 {
panic("mysql must be set query/execute/transction timeout")
}
db, err := Open(c)
if err != nil {
log.Error("open mysql error(%v)", err)
panic(err)
}
return
}

@ -0,0 +1,678 @@
package sql
import (
"context"
"database/sql"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/bilibili/kratos/pkg/ecode"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
"github.com/bilibili/kratos/pkg/net/trace"
"github.com/pkg/errors"
)
const (
_family = "sql_client"
_slowLogDuration = time.Millisecond * 250
)
var (
// ErrStmtNil prepared stmt error
ErrStmtNil = errors.New("sql: prepare failed and stmt nil")
// ErrNoMaster is returned by Master when call master multiple times.
ErrNoMaster = errors.New("sql: no master instance")
// ErrNoRows is returned by Scan when QueryRow doesn't return a row.
// In such a case, QueryRow returns a placeholder *Row value that defers
// this error until a Scan.
ErrNoRows = sql.ErrNoRows
// ErrTxDone transaction done.
ErrTxDone = sql.ErrTxDone
)
// DB database.
type DB struct {
write *conn
read []*conn
idx int64
master *DB
}
// conn database connection
type conn struct {
*sql.DB
breaker breaker.Breaker
conf *Config
}
// Tx transaction.
type Tx struct {
db *conn
tx *sql.Tx
t trace.Trace
c context.Context
cancel func()
}
// Row row.
type Row struct {
err error
*sql.Row
db *conn
query string
args []interface{}
t trace.Trace
cancel func()
}
// Scan copies the columns from the matched row into the values pointed at by dest.
func (r *Row) Scan(dest ...interface{}) (err error) {
defer slowLog(fmt.Sprintf("Scan query(%s) args(%+v)", r.query, r.args), time.Now())
if r.t != nil {
defer r.t.Finish(&err)
}
if r.err != nil {
err = r.err
} else if r.Row == nil {
err = ErrStmtNil
}
if err != nil {
return
}
err = r.Row.Scan(dest...)
if r.cancel != nil {
r.cancel()
}
r.db.onBreaker(&err)
if err != ErrNoRows {
err = errors.Wrapf(err, "query %s args %+v", r.query, r.args)
}
return
}
// Rows rows.
type Rows struct {
*sql.Rows
cancel func()
}
// Close closes the Rows, preventing further enumeration. If Next is called
// and returns false and there are no further result sets,
// the Rows are closed automatically and it will suffice to check the
// result of Err. Close is idempotent and does not affect the result of Err.
func (rs *Rows) Close() (err error) {
err = errors.WithStack(rs.Rows.Close())
if rs.cancel != nil {
rs.cancel()
}
return
}
// Stmt prepared stmt.
type Stmt struct {
db *conn
tx bool
query string
stmt atomic.Value
t trace.Trace
}
// Open opens a database specified by its database driver name and a
// driver-specific data source name, usually consisting of at least a database
// name and connection information.
func Open(c *Config) (*DB, error) {
db := new(DB)
d, err := connect(c, c.DSN)
if err != nil {
return nil, err
}
brkGroup := breaker.NewGroup(c.Breaker)
brk := brkGroup.Get(c.Addr)
w := &conn{DB: d, breaker: brk, conf: c}
rs := make([]*conn, 0, len(c.ReadDSN))
for _, rd := range c.ReadDSN {
d, err := connect(c, rd)
if err != nil {
return nil, err
}
brk := brkGroup.Get(parseDSNAddr(rd))
r := &conn{DB: d, breaker: brk, conf: c}
rs = append(rs, r)
}
db.write = w
db.read = rs
db.master = &DB{write: db.write}
return db, nil
}
func connect(c *Config, dataSourceName string) (*sql.DB, error) {
d, err := sql.Open("mysql", dataSourceName)
if err != nil {
err = errors.WithStack(err)
return nil, err
}
d.SetMaxOpenConns(c.Active)
d.SetMaxIdleConns(c.Idle)
d.SetConnMaxLifetime(time.Duration(c.IdleTimeout))
return d, nil
}
// Begin starts a transaction. The isolation level is dependent on the driver.
func (db *DB) Begin(c context.Context) (tx *Tx, err error) {
return db.write.begin(c)
}
// Exec executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
func (db *DB) Exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) {
return db.write.exec(c, query, args...)
}
// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the returned
// statement. The caller must call the statement's Close method when the
// statement is no longer needed.
func (db *DB) Prepare(query string) (*Stmt, error) {
return db.write.prepare(query)
}
// Prepared creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the returned
// statement. The caller must call the statement's Close method when the
// statement is no longer needed.
func (db *DB) Prepared(query string) (stmt *Stmt) {
return db.write.prepared(query)
}
// Query executes a query that returns rows, typically a SELECT. The args are
// for any placeholder parameters in the query.
func (db *DB) Query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) {
idx := db.readIndex()
for i := range db.read {
if rows, err = db.read[(idx+i)%len(db.read)].query(c, query, args...); !ecode.ServiceUnavailable.Equal(err) {
return
}
}
return db.write.query(c, query, args...)
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until Row's
// Scan method is called.
func (db *DB) QueryRow(c context.Context, query string, args ...interface{}) *Row {
idx := db.readIndex()
for i := range db.read {
if row := db.read[(idx+i)%len(db.read)].queryRow(c, query, args...); !ecode.ServiceUnavailable.Equal(row.err) {
return row
}
}
return db.write.queryRow(c, query, args...)
}
func (db *DB) readIndex() int {
if len(db.read) == 0 {
return 0
}
v := atomic.AddInt64(&db.idx, 1)
return int(v) % len(db.read)
}
// Close closes the write and read database, releasing any open resources.
func (db *DB) Close() (err error) {
if e := db.write.Close(); e != nil {
err = errors.WithStack(e)
}
for _, rd := range db.read {
if e := rd.Close(); e != nil {
err = errors.WithStack(e)
}
}
return
}
// Ping verifies a connection to the database is still alive, establishing a
// connection if necessary.
func (db *DB) Ping(c context.Context) (err error) {
if err = db.write.ping(c); err != nil {
return
}
for _, rd := range db.read {
if err = rd.ping(c); err != nil {
return
}
}
return
}
// Master return *DB instance direct use master conn
// use this *DB instance only when you have some reason need to get result without any delay.
func (db *DB) Master() *DB {
if db.master == nil {
panic(ErrNoMaster)
}
return db.master
}
func (db *conn) onBreaker(err *error) {
if err != nil && *err != nil && *err != sql.ErrNoRows && *err != sql.ErrTxDone {
db.breaker.MarkFailed()
} else {
db.breaker.MarkSuccess()
}
}
func (db *conn) begin(c context.Context) (tx *Tx, err error) {
now := time.Now()
defer slowLog("Begin", now)
t, ok := trace.FromContext(c)
if ok {
t = t.Fork(_family, "begin")
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, ""))
defer func() {
if err != nil {
t.Finish(&err)
}
}()
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("mysql:begin", "breaker")
return
}
_, c, cancel := db.conf.TranTimeout.Shrink(c)
rtx, err := db.BeginTx(c, nil)
stats.Timing("mysql:begin", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.WithStack(err)
cancel()
return
}
tx = &Tx{tx: rtx, t: t, db: db, c: c, cancel: cancel}
return
}
func (db *conn) exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", query, args), now)
if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "exec")
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query))
defer t.Finish(&err)
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("mysql:exec", "breaker")
return
}
_, c, cancel := db.conf.ExecTimeout.Shrink(c)
res, err = db.ExecContext(c, query, args...)
cancel()
db.onBreaker(&err)
stats.Timing("mysql:exec", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "exec:%s, args:%+v", query, args)
}
return
}
func (db *conn) ping(c context.Context) (err error) {
now := time.Now()
defer slowLog("Ping", now)
if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "ping")
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, ""))
defer t.Finish(&err)
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("mysql:ping", "breaker")
return
}
_, c, cancel := db.conf.ExecTimeout.Shrink(c)
err = db.PingContext(c)
cancel()
db.onBreaker(&err)
stats.Timing("mysql:ping", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.WithStack(err)
}
return
}
func (db *conn) prepare(query string) (*Stmt, error) {
defer slowLog(fmt.Sprintf("Prepare query(%s)", query), time.Now())
stmt, err := db.Prepare(query)
if err != nil {
err = errors.Wrapf(err, "prepare %s", query)
return nil, err
}
st := &Stmt{query: query, db: db}
st.stmt.Store(stmt)
return st, nil
}
func (db *conn) prepared(query string) (stmt *Stmt) {
defer slowLog(fmt.Sprintf("Prepared query(%s)", query), time.Now())
stmt = &Stmt{query: query, db: db}
s, err := db.Prepare(query)
if err == nil {
stmt.stmt.Store(s)
return
}
go func() {
for {
s, err := db.Prepare(query)
if err != nil {
time.Sleep(time.Second)
continue
}
stmt.stmt.Store(s)
return
}
}()
return
}
func (db *conn) query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", query, args), now)
if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "query")
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query))
defer t.Finish(&err)
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("mysql:query", "breaker")
return
}
_, c, cancel := db.conf.QueryTimeout.Shrink(c)
rs, err := db.DB.QueryContext(c, query, args...)
db.onBreaker(&err)
stats.Timing("mysql:query", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "query:%s, args:%+v", query, args)
cancel()
return
}
rows = &Rows{Rows: rs, cancel: cancel}
return
}
func (db *conn) queryRow(c context.Context, query string, args ...interface{}) *Row {
now := time.Now()
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", query, args), now)
t, ok := trace.FromContext(c)
if ok {
t = t.Fork(_family, "queryrow")
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query))
}
if err := db.breaker.Allow(); err != nil {
stats.Incr("mysql:queryrow", "breaker")
return &Row{db: db, t: t, err: err}
}
_, c, cancel := db.conf.QueryTimeout.Shrink(c)
r := db.DB.QueryRowContext(c, query, args...)
stats.Timing("mysql:queryrow", int64(time.Since(now)/time.Millisecond))
return &Row{db: db, Row: r, query: query, args: args, t: t, cancel: cancel}
}
// Close closes the statement.
func (s *Stmt) Close() (err error) {
if s == nil {
err = ErrStmtNil
return
}
stmt, ok := s.stmt.Load().(*sql.Stmt)
if ok {
err = errors.WithStack(stmt.Close())
}
return
}
// Exec executes a prepared statement with the given arguments and returns a
// Result summarizing the effect of the statement.
func (s *Stmt) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) {
if s == nil {
err = ErrStmtNil
return
}
now := time.Now()
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", s.query, args), now)
if s.tx {
if s.t != nil {
s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
}
} else if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "exec")
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query))
defer t.Finish(&err)
}
if err = s.db.breaker.Allow(); err != nil {
stats.Incr("mysql:stmt:exec", "breaker")
return
}
stmt, ok := s.stmt.Load().(*sql.Stmt)
if !ok {
err = ErrStmtNil
return
}
_, c, cancel := s.db.conf.ExecTimeout.Shrink(c)
res, err = stmt.ExecContext(c, args...)
cancel()
s.db.onBreaker(&err)
stats.Timing("mysql:stmt:exec", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "exec:%s, args:%+v", s.query, args)
}
return
}
// Query executes a prepared query statement with the given arguments and
// returns the query results as a *Rows.
func (s *Stmt) Query(c context.Context, args ...interface{}) (rows *Rows, err error) {
if s == nil {
err = ErrStmtNil
return
}
now := time.Now()
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", s.query, args), now)
if s.tx {
if s.t != nil {
s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
}
} else if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "query")
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query))
defer t.Finish(&err)
}
if err = s.db.breaker.Allow(); err != nil {
stats.Incr("mysql:stmt:query", "breaker")
return
}
stmt, ok := s.stmt.Load().(*sql.Stmt)
if !ok {
err = ErrStmtNil
return
}
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c)
rs, err := stmt.QueryContext(c, args...)
s.db.onBreaker(&err)
stats.Timing("mysql:stmt:query", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "query:%s, args:%+v", s.query, args)
cancel()
return
}
rows = &Rows{Rows: rs, cancel: cancel}
return
}
// QueryRow executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
func (s *Stmt) QueryRow(c context.Context, args ...interface{}) (row *Row) {
now := time.Now()
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", s.query, args), now)
row = &Row{db: s.db, query: s.query, args: args}
if s == nil {
row.err = ErrStmtNil
return
}
if s.tx {
if s.t != nil {
s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
}
} else if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "queryrow")
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query))
row.t = t
}
if row.err = s.db.breaker.Allow(); row.err != nil {
stats.Incr("mysql:stmt:queryrow", "breaker")
return
}
stmt, ok := s.stmt.Load().(*sql.Stmt)
if !ok {
return
}
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c)
row.Row = stmt.QueryRowContext(c, args...)
row.cancel = cancel
stats.Timing("mysql:stmt:queryrow", int64(time.Since(now)/time.Millisecond))
return
}
// Commit commits the transaction.
func (tx *Tx) Commit() (err error) {
err = tx.tx.Commit()
tx.cancel()
tx.db.onBreaker(&err)
if tx.t != nil {
tx.t.Finish(&err)
}
if err != nil {
err = errors.WithStack(err)
}
return
}
// Rollback aborts the transaction.
func (tx *Tx) Rollback() (err error) {
err = tx.tx.Rollback()
tx.cancel()
tx.db.onBreaker(&err)
if tx.t != nil {
tx.t.Finish(&err)
}
if err != nil {
err = errors.WithStack(err)
}
return
}
// Exec executes a query that doesn't return rows. For example: an INSERT and
// UPDATE.
func (tx *Tx) Exec(query string, args ...interface{}) (res sql.Result, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", query, args), now)
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("exec %s", query)))
}
res, err = tx.tx.ExecContext(tx.c, query, args...)
stats.Timing("mysql:tx:exec", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "exec:%s, args:%+v", query, args)
}
return
}
// Query executes a query that returns rows, typically a SELECT.
func (tx *Tx) Query(query string, args ...interface{}) (rows *Rows, err error) {
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("query %s", query)))
}
now := time.Now()
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", query, args), now)
defer func() {
stats.Timing("mysql:tx:query", int64(time.Since(now)/time.Millisecond))
}()
rs, err := tx.tx.QueryContext(tx.c, query, args...)
if err == nil {
rows = &Rows{Rows: rs}
} else {
err = errors.Wrapf(err, "query:%s, args:%+v", query, args)
}
return
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until Row's
// Scan method is called.
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("queryrow %s", query)))
}
now := time.Now()
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", query, args), now)
defer func() {
stats.Timing("mysql:tx:queryrow", int64(time.Since(now)/time.Millisecond))
}()
r := tx.tx.QueryRowContext(tx.c, query, args...)
return &Row{Row: r, db: tx.db, query: query, args: args}
}
// Stmt returns a transaction-specific prepared statement from an existing statement.
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
as, ok := stmt.stmt.Load().(*sql.Stmt)
if !ok {
return nil
}
ts := tx.tx.StmtContext(tx.c, as)
st := &Stmt{query: stmt.query, tx: true, t: tx.t, db: tx.db}
st.stmt.Store(ts)
return st
}
// Prepare creates a prepared statement for use within a transaction.
// The returned statement operates within the transaction and can no longer be
// used once the transaction has been committed or rolled back.
// To use an existing prepared statement on this transaction, see Tx.Stmt.
func (tx *Tx) Prepare(query string) (*Stmt, error) {
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("prepare %s", query)))
}
defer slowLog(fmt.Sprintf("Prepare query(%s)", query), time.Now())
stmt, err := tx.tx.Prepare(query)
if err != nil {
err = errors.Wrapf(err, "prepare %s", query)
return nil, err
}
st := &Stmt{query: query, tx: true, t: tx.t, db: tx.db}
st.stmt.Store(stmt)
return st, nil
}
// parseDSNAddr parse dsn name and return addr.
func parseDSNAddr(dsn string) (addr string) {
if dsn == "" {
return
}
part0 := strings.Split(dsn, "@")
if len(part0) > 1 {
part1 := strings.Split(part0[1], "?")
if len(part1) > 0 {
addr = part1[0]
}
}
return
}
func slowLog(statement string, now time.Time) {
du := time.Since(now)
if du > _slowLogDuration {
log.Warn("%s slow log statement: %s time: %v", _family, statement, du)
}
}

@ -0,0 +1,14 @@
#### database/tidb
##### 项目简介
TiDB数据库驱动 对mysql驱动进行封装
##### 功能
1. 支持discovery服务发现 多节点直连
2. 支持通过lvs单一地址连接
3. 支持prepare绑定多个节点
4. 支持动态增减节点负载均衡
5. 日志区分运行节点
##### 依赖包
1.[Go-MySQL-Driver](https://github.com/go-sql-driver/mysql)

@ -0,0 +1,58 @@
package tidb
import (
"context"
"fmt"
"strings"
"time"
"github.com/bilibili/kratos/pkg/conf/env"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/naming"
"github.com/bilibili/kratos/pkg/naming/discovery"
)
var _schema = "tidb://"
func (db *DB) nodeList() (nodes []string) {
var (
insInfo *naming.InstancesInfo
insMap map[string][]*naming.Instance
ins []*naming.Instance
ok bool
)
if insInfo, ok = db.dis.Fetch(context.Background()); !ok {
return
}
insMap = insInfo.Instances
if ins, ok = insMap[env.Zone]; !ok || len(ins) == 0 {
return
}
for _, in := range ins {
for _, addr := range in.Addrs {
if strings.HasPrefix(addr, _schema) {
addr = strings.Replace(addr, _schema, "", -1)
nodes = append(nodes, addr)
}
}
}
log.Info("tidb get %s instances(%v)", db.appid, nodes)
return
}
func (db *DB) disc() (nodes []string) {
db.dis = discovery.Build(db.appid)
e := db.dis.Watch()
select {
case <-e:
nodes = db.nodeList()
case <-time.After(10 * time.Second):
panic("tidb init discovery err")
}
if len(nodes) == 0 {
panic(fmt.Sprintf("tidb %s no instance", db.appid))
}
go db.nodeproc(e)
log.Info("init tidb discvoery info successfully")
return
}

@ -0,0 +1,82 @@
package tidb
import (
"time"
"github.com/bilibili/kratos/pkg/log"
)
func (db *DB) nodeproc(e <-chan struct{}) {
if db.dis == nil {
return
}
for {
<-e
nodes := db.nodeList()
if len(nodes) == 0 {
continue
}
cm := make(map[string]*conn)
var conns []*conn
for _, conn := range db.conns {
cm[conn.addr] = conn
}
for _, node := range nodes {
if cm[node] != nil {
conns = append(conns, cm[node])
continue
}
c, err := db.connectDSN(genDSN(db.conf.DSN, node))
if err == nil {
conns = append(conns, c)
} else {
log.Error("tidb: connect addr: %s err: %+v", node, err)
}
}
if len(conns) == 0 {
log.Error("tidb: no nodes ignore event")
continue
}
oldConns := db.conns
db.mutex.Lock()
db.conns = conns
db.mutex.Unlock()
log.Info("tidb: new nodes: %v", nodes)
var removedConn []*conn
for _, conn := range oldConns {
var exist bool
for _, c := range conns {
if c.addr == conn.addr {
exist = true
break
}
}
if !exist {
removedConn = append(removedConn, conn)
}
}
go db.closeConns(removedConn)
}
}
func (db *DB) closeConns(conns []*conn) {
if len(conns) == 0 {
return
}
du := db.conf.QueryTimeout
if db.conf.ExecTimeout > du {
du = db.conf.ExecTimeout
}
if db.conf.TranTimeout > du {
du = db.conf.TranTimeout
}
time.Sleep(time.Duration(du))
for _, conn := range conns {
err := conn.Close()
if err != nil {
log.Error("tidb: close removed conn: %s err: %v", conn.addr, err)
} else {
log.Info("tidb: close removed conn: %s", conn.addr)
}
}
}

@ -0,0 +1,739 @@
package tidb
import (
"context"
"database/sql"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/naming"
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
"github.com/bilibili/kratos/pkg/net/trace"
"github.com/go-sql-driver/mysql"
"github.com/pkg/errors"
)
const (
_family = "tidb_client"
_slowLogDuration = time.Millisecond * 250
)
var (
// ErrStmtNil prepared stmt error
ErrStmtNil = errors.New("sql: prepare failed and stmt nil")
// ErrNoRows is returned by Scan when QueryRow doesn't return a row.
// In such a case, QueryRow returns a placeholder *Row value that defers
// this error until a Scan.
ErrNoRows = sql.ErrNoRows
// ErrTxDone transaction done.
ErrTxDone = sql.ErrTxDone
)
// DB database.
type DB struct {
conf *Config
conns []*conn
idx int64
dis naming.Resolver
appid string
mutex sync.RWMutex
breakerGroup *breaker.Group
}
// conn database connection
type conn struct {
*sql.DB
breaker breaker.Breaker
conf *Config
addr string
}
// Tx transaction.
type Tx struct {
db *conn
tx *sql.Tx
t trace.Trace
c context.Context
cancel func()
}
// Row row.
type Row struct {
err error
*sql.Row
db *conn
query string
args []interface{}
t trace.Trace
cancel func()
}
// Scan copies the columns from the matched row into the values pointed at by dest.
func (r *Row) Scan(dest ...interface{}) (err error) {
defer slowLog(fmt.Sprintf("Scan addr: %s query(%s) args(%+v)", r.db.addr, r.query, r.args), time.Now())
if r.t != nil {
defer r.t.Finish(&err)
}
if r.err != nil {
err = r.err
} else if r.Row == nil {
err = ErrStmtNil
}
if err != nil {
return
}
err = r.Row.Scan(dest...)
if r.cancel != nil {
r.cancel()
}
r.db.onBreaker(&err)
if err != ErrNoRows {
err = errors.Wrapf(err, "addr: %s, query %s args %+v", r.db.addr, r.query, r.args)
}
return
}
// Rows rows.
type Rows struct {
*sql.Rows
cancel func()
}
// Close closes the Rows, preventing further enumeration. If Next is called
// and returns false and there are no further result sets,
// the Rows are closed automatically and it will suffice to check the
// result of Err. Close is idempotent and does not affect the result of Err.
func (rs *Rows) Close() (err error) {
err = errors.WithStack(rs.Rows.Close())
if rs.cancel != nil {
rs.cancel()
}
return
}
// Stmt prepared stmt.
type Stmt struct {
db *conn
tx bool
query string
stmt atomic.Value
t trace.Trace
}
// Stmts random prepared stmt.
type Stmts struct {
query string
sts map[string]*Stmt
mu sync.RWMutex
db *DB
}
// Open opens a database specified by its database driver name and a
// driver-specific data source name, usually consisting of at least a database
// name and connection information.
func Open(c *Config) (db *DB, err error) {
db = &DB{conf: c, breakerGroup: breaker.NewGroup(c.Breaker)}
cfg, err := mysql.ParseDSN(c.DSN)
if err != nil {
return
}
var dsns []string
if cfg.Net == "discovery" {
db.appid = cfg.Addr
for _, addr := range db.disc() {
dsns = append(dsns, genDSN(c.DSN, addr))
}
} else {
dsns = append(dsns, c.DSN)
}
cs := make([]*conn, 0, len(dsns))
for _, dsn := range dsns {
r, err := db.connectDSN(dsn)
if err != nil {
return db, err
}
cs = append(cs, r)
}
db.conns = cs
return
}
func (db *DB) connectDSN(dsn string) (c *conn, err error) {
d, err := connect(db.conf, dsn)
if err != nil {
return
}
addr := parseDSNAddr(dsn)
brk := db.breakerGroup.Get(addr)
c = &conn{DB: d, breaker: brk, conf: db.conf, addr: addr}
return
}
func connect(c *Config, dataSourceName string) (*sql.DB, error) {
d, err := sql.Open("mysql", dataSourceName)
if err != nil {
err = errors.WithStack(err)
return nil, err
}
d.SetMaxOpenConns(c.Active)
d.SetMaxIdleConns(c.Idle)
d.SetConnMaxLifetime(time.Duration(c.IdleTimeout))
return d, nil
}
func (db *DB) conn() (c *conn) {
db.mutex.RLock()
c = db.conns[db.index()]
db.mutex.RUnlock()
return
}
// Begin starts a transaction. The isolation level is dependent on the driver.
func (db *DB) Begin(c context.Context) (tx *Tx, err error) {
return db.conn().begin(c)
}
// Exec executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
func (db *DB) Exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) {
return db.conn().exec(c, query, args...)
}
// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the returned
// statement. The caller must call the statement's Close method when the
// statement is no longer needed.
func (db *DB) Prepare(query string) (*Stmt, error) {
return db.conn().prepare(query)
}
// Prepared creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the returned
// statement. The caller must call the statement's Close method when the
// statement is no longer needed.
func (db *DB) Prepared(query string) (s *Stmts) {
s = &Stmts{query: query, sts: make(map[string]*Stmt), db: db}
for _, c := range db.conns {
st := c.prepared(query)
s.mu.Lock()
s.sts[c.addr] = st
s.mu.Unlock()
}
return
}
// Query executes a query that returns rows, typically a SELECT. The args are
// for any placeholder parameters in the query.
func (db *DB) Query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) {
return db.conn().query(c, query, args...)
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until Row's
// Scan method is called.
func (db *DB) QueryRow(c context.Context, query string, args ...interface{}) *Row {
return db.conn().queryRow(c, query, args...)
}
func (db *DB) index() int {
if len(db.conns) == 1 {
return 0
}
v := atomic.AddInt64(&db.idx, 1)
return int(v) % len(db.conns)
}
// Close closes the databases, releasing any open resources.
func (db *DB) Close() (err error) {
db.mutex.RLock()
defer db.mutex.RUnlock()
for _, d := range db.conns {
if e := d.Close(); e != nil {
err = errors.WithStack(e)
}
}
return
}
// Ping verifies a connection to the database is still alive, establishing a
// connection if necessary.
func (db *DB) Ping(c context.Context) (err error) {
if err = db.conn().ping(c); err != nil {
return
}
return
}
func (db *conn) onBreaker(err *error) {
if err != nil && *err != nil && *err != sql.ErrNoRows && *err != sql.ErrTxDone {
db.breaker.MarkFailed()
} else {
db.breaker.MarkSuccess()
}
}
func (db *conn) begin(c context.Context) (tx *Tx, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Begin addr: %s", db.addr), now)
t, ok := trace.FromContext(c)
if ok {
t = t.Fork(_family, "begin")
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, ""))
defer func() {
if err != nil {
t.Finish(&err)
}
}()
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("tidb:begin", "breaker")
return
}
_, c, cancel := db.conf.TranTimeout.Shrink(c)
rtx, err := db.BeginTx(c, nil)
stats.Timing("tidb:begin", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.WithStack(err)
cancel()
return
}
tx = &Tx{tx: rtx, t: t, db: db, c: c, cancel: cancel}
return
}
func (db *conn) exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", db.addr, query, args), now)
if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "exec")
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query))
defer t.Finish(&err)
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("tidb:exec", "breaker")
return
}
_, c, cancel := db.conf.ExecTimeout.Shrink(c)
res, err = db.ExecContext(c, query, args...)
cancel()
db.onBreaker(&err)
stats.Timing("tidb:exec", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", db.addr, query, args)
}
return
}
func (db *conn) ping(c context.Context) (err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Ping addr: %s", db.addr), now)
if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "ping")
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, ""))
defer t.Finish(&err)
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("tidb:ping", "breaker")
return
}
_, c, cancel := db.conf.ExecTimeout.Shrink(c)
err = db.PingContext(c)
cancel()
db.onBreaker(&err)
stats.Timing("tidb:ping", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.WithStack(err)
}
return
}
func (db *conn) prepare(query string) (*Stmt, error) {
defer slowLog(fmt.Sprintf("Prepare addr: %s query(%s)", db.addr, query), time.Now())
stmt, err := db.Prepare(query)
if err != nil {
err = errors.Wrapf(err, "addr: %s prepare %s", db.addr, query)
return nil, err
}
st := &Stmt{query: query, db: db}
st.stmt.Store(stmt)
return st, nil
}
func (db *conn) prepared(query string) (stmt *Stmt) {
defer slowLog(fmt.Sprintf("Prepared addr: %s query(%s)", db.addr, query), time.Now())
stmt = &Stmt{query: query, db: db}
s, err := db.Prepare(query)
if err == nil {
stmt.stmt.Store(s)
return
}
return
}
func (db *conn) query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", db.addr, query, args), now)
if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "query")
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query))
defer t.Finish(&err)
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("tidb:query", "breaker")
return
}
_, c, cancel := db.conf.QueryTimeout.Shrink(c)
rs, err := db.DB.QueryContext(c, query, args...)
db.onBreaker(&err)
stats.Timing("tidb:query", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", db.addr, query, args)
cancel()
return
}
rows = &Rows{Rows: rs, cancel: cancel}
return
}
func (db *conn) queryRow(c context.Context, query string, args ...interface{}) *Row {
now := time.Now()
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", db.addr, query, args), now)
t, ok := trace.FromContext(c)
if ok {
t = t.Fork(_family, "queryrow")
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query))
}
if err := db.breaker.Allow(); err != nil {
stats.Incr("tidb:queryrow", "breaker")
return &Row{db: db, t: t, err: err}
}
_, c, cancel := db.conf.QueryTimeout.Shrink(c)
r := db.DB.QueryRowContext(c, query, args...)
stats.Timing("tidb:queryrow", int64(time.Since(now)/time.Millisecond))
return &Row{db: db, Row: r, query: query, args: args, t: t, cancel: cancel}
}
// Close closes the statement.
func (s *Stmt) Close() (err error) {
stmt, ok := s.stmt.Load().(*sql.Stmt)
if ok {
err = errors.WithStack(stmt.Close())
}
return
}
func (s *Stmt) prepare() (st *sql.Stmt) {
var ok bool
if st, ok = s.stmt.Load().(*sql.Stmt); ok {
return
}
var err error
if st, err = s.db.Prepare(s.query); err == nil {
s.stmt.Store(st)
}
return
}
// Exec executes a prepared statement with the given arguments and returns a
// Result summarizing the effect of the statement.
func (s *Stmt) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now)
if s.tx {
if s.t != nil {
s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
}
} else if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "exec")
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query))
defer t.Finish(&err)
}
if err = s.db.breaker.Allow(); err != nil {
stats.Incr("tidb:stmt:exec", "breaker")
return
}
stmt := s.prepare()
if stmt == nil {
err = ErrStmtNil
return
}
_, c, cancel := s.db.conf.ExecTimeout.Shrink(c)
res, err = stmt.ExecContext(c, args...)
cancel()
s.db.onBreaker(&err)
stats.Timing("tidb:stmt:exec", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", s.db.addr, s.query, args)
}
return
}
// Query executes a prepared query statement with the given arguments and
// returns the query results as a *Rows.
func (s *Stmt) Query(c context.Context, args ...interface{}) (rows *Rows, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now)
if s.tx {
if s.t != nil {
s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
}
} else if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "query")
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query))
defer t.Finish(&err)
}
if err = s.db.breaker.Allow(); err != nil {
stats.Incr("tidb:stmt:query", "breaker")
return
}
stmt := s.prepare()
if stmt == nil {
err = ErrStmtNil
return
}
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c)
rs, err := stmt.QueryContext(c, args...)
s.db.onBreaker(&err)
stats.Timing("tidb:stmt:query", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", s.db.addr, s.query, args)
cancel()
return
}
rows = &Rows{Rows: rs, cancel: cancel}
return
}
// QueryRow executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
func (s *Stmt) QueryRow(c context.Context, args ...interface{}) (row *Row) {
now := time.Now()
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now)
row = &Row{db: s.db, query: s.query, args: args}
if s.tx {
if s.t != nil {
s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
}
} else if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "queryrow")
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query))
row.t = t
}
if row.err = s.db.breaker.Allow(); row.err != nil {
stats.Incr("tidb:stmt:queryrow", "breaker")
return
}
stmt := s.prepare()
if stmt == nil {
return
}
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c)
row.Row = stmt.QueryRowContext(c, args...)
row.cancel = cancel
stats.Timing("tidb:stmt:queryrow", int64(time.Since(now)/time.Millisecond))
return
}
func (s *Stmts) prepare(conn *conn) (st *Stmt) {
if conn == nil {
conn = s.db.conn()
}
s.mu.RLock()
st = s.sts[conn.addr]
s.mu.RUnlock()
if st == nil {
st = conn.prepared(s.query)
s.mu.Lock()
s.sts[conn.addr] = st
s.mu.Unlock()
}
return
}
// Exec executes a prepared statement with the given arguments and returns a
// Result summarizing the effect of the statement.
func (s *Stmts) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) {
return s.prepare(nil).Exec(c, args...)
}
// Query executes a prepared query statement with the given arguments and
// returns the query results as a *Rows.
func (s *Stmts) Query(c context.Context, args ...interface{}) (rows *Rows, err error) {
return s.prepare(nil).Query(c, args...)
}
// QueryRow executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
func (s *Stmts) QueryRow(c context.Context, args ...interface{}) (row *Row) {
return s.prepare(nil).QueryRow(c, args...)
}
// Close closes the statement.
func (s *Stmts) Close() (err error) {
for _, st := range s.sts {
if err = errors.WithStack(st.Close()); err != nil {
return
}
}
return
}
// Commit commits the transaction.
func (tx *Tx) Commit() (err error) {
err = tx.tx.Commit()
tx.cancel()
tx.db.onBreaker(&err)
if tx.t != nil {
tx.t.Finish(&err)
}
if err != nil {
err = errors.WithStack(err)
}
return
}
// Rollback aborts the transaction.
func (tx *Tx) Rollback() (err error) {
err = tx.tx.Rollback()
tx.cancel()
tx.db.onBreaker(&err)
if tx.t != nil {
tx.t.Finish(&err)
}
if err != nil {
err = errors.WithStack(err)
}
return
}
// Exec executes a query that doesn't return rows. For example: an INSERT and
// UPDATE.
func (tx *Tx) Exec(query string, args ...interface{}) (res sql.Result, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now)
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("exec %s", query)))
}
res, err = tx.tx.ExecContext(tx.c, query, args...)
stats.Timing("tidb:tx:exec", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", tx.db.addr, query, args)
}
return
}
// Query executes a query that returns rows, typically a SELECT.
func (tx *Tx) Query(query string, args ...interface{}) (rows *Rows, err error) {
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("query %s", query)))
}
now := time.Now()
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now)
defer func() {
stats.Timing("tidb:tx:query", int64(time.Since(now)/time.Millisecond))
}()
rs, err := tx.tx.QueryContext(tx.c, query, args...)
if err == nil {
rows = &Rows{Rows: rs}
} else {
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", tx.db.addr, query, args)
}
return
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until Row's
// Scan method is called.
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("queryrow %s", query)))
}
now := time.Now()
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now)
defer func() {
stats.Timing("tidb:tx:queryrow", int64(time.Since(now)/time.Millisecond))
}()
r := tx.tx.QueryRowContext(tx.c, query, args...)
return &Row{Row: r, db: tx.db, query: query, args: args}
}
// Stmt returns a transaction-specific prepared statement from an existing statement.
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
if stmt == nil {
return nil
}
as, ok := stmt.stmt.Load().(*sql.Stmt)
if !ok {
return nil
}
ts := tx.tx.StmtContext(tx.c, as)
st := &Stmt{query: stmt.query, tx: true, t: tx.t, db: tx.db}
st.stmt.Store(ts)
return st
}
// Stmts returns a transaction-specific prepared statement from an existing statement.
func (tx *Tx) Stmts(stmt *Stmts) *Stmt {
return tx.Stmt(stmt.prepare(tx.db))
}
// Prepare creates a prepared statement for use within a transaction.
// The returned statement operates within the transaction and can no longer be
// used once the transaction has been committed or rolled back.
// To use an existing prepared statement on this transaction, see Tx.Stmt.
func (tx *Tx) Prepare(query string) (*Stmt, error) {
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("prepare %s", query)))
}
defer slowLog(fmt.Sprintf("Prepare addr: %s query(%s)", tx.db.addr, query), time.Now())
stmt, err := tx.tx.Prepare(query)
if err != nil {
err = errors.Wrapf(err, "addr: %s prepare %s", tx.db.addr, query)
return nil, err
}
st := &Stmt{query: query, tx: true, t: tx.t, db: tx.db}
st.stmt.Store(stmt)
return st, nil
}
// parseDSNAddr parse dsn name and return addr.
func parseDSNAddr(dsn string) (addr string) {
if dsn == "" {
return
}
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return
}
addr = cfg.Addr
return
}
func genDSN(dsn, addr string) (res string) {
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return
}
cfg.Addr = addr
cfg.Net = "tcp"
res = cfg.FormatDSN()
return
}
func slowLog(statement string, now time.Time) {
du := time.Since(now)
if du > _slowLogDuration {
log.Warn("%s slow log statement: %s time: %v", _family, statement, du)
}
}

@ -0,0 +1,38 @@
package tidb
import (
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
"github.com/bilibili/kratos/pkg/stat"
"github.com/bilibili/kratos/pkg/time"
// database driver
_ "github.com/go-sql-driver/mysql"
)
var stats = stat.DB
// Config mysql config.
type Config struct {
DSN string // dsn
Active int // pool
Idle int // pool
IdleTimeout time.Duration // connect max life time.
QueryTimeout time.Duration // query sql timeout
ExecTimeout time.Duration // execute sql timeout
TranTimeout time.Duration // transaction sql timeout
Breaker *breaker.Config // breaker
}
// NewTiDB new db and retry connection when has error.
func NewTiDB(c *Config) (db *DB) {
if c.QueryTimeout == 0 || c.ExecTimeout == 0 || c.TranTimeout == 0 {
panic("tidb must be set query/execute/transction timeout")
}
db, err := Open(c)
if err != nil {
log.Error("open tidb error(%v)", err)
panic(err)
}
return
}

@ -0,0 +1,85 @@
package binding
import (
"net/http"
"strings"
"gopkg.in/go-playground/validator.v9"
)
// MIME
const (
MIMEJSON = "application/json"
MIMEHTML = "text/html"
MIMEXML = "application/xml"
MIMEXML2 = "text/xml"
MIMEPlain = "text/plain"
MIMEPOSTForm = "application/x-www-form-urlencoded"
MIMEMultipartPOSTForm = "multipart/form-data"
)
// Binding http binding request interface.
type Binding interface {
Name() string
Bind(*http.Request, interface{}) error
}
// StructValidator http validator interface.
type StructValidator interface {
// ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
// If the received type is not a struct, any validation should be skipped and nil must be returned.
// If the received type is a struct or pointer to a struct, the validation should be performed.
// If the struct is not valid or the validation itself fails, a descriptive error should be returned.
// Otherwise nil must be returned.
ValidateStruct(interface{}) error
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
// NOTE: if the key already exists, the previous validation function will be replaced.
// NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
RegisterValidation(string, validator.Func) error
}
// Validator default validator.
var Validator StructValidator = &defaultValidator{}
// Binding
var (
JSON = jsonBinding{}
XML = xmlBinding{}
Form = formBinding{}
Query = queryBinding{}
FormPost = formPostBinding{}
FormMultipart = formMultipartBinding{}
)
// Default get by binding type by method and contexttype.
func Default(method, contentType string) Binding {
if method == "GET" {
return Form
}
contentType = stripContentTypeParam(contentType)
switch contentType {
case MIMEJSON:
return JSON
case MIMEXML, MIMEXML2:
return XML
default: //case MIMEPOSTForm, MIMEMultipartPOSTForm:
return Form
}
}
func validate(obj interface{}) error {
if Validator == nil {
return nil
}
return Validator.ValidateStruct(obj)
}
func stripContentTypeParam(contentType string) string {
i := strings.Index(contentType, ";")
if i != -1 {
contentType = contentType[:i]
}
return contentType
}

@ -0,0 +1,342 @@
package binding
import (
"bytes"
"mime/multipart"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
type FooStruct struct {
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" validate:"required"`
}
type FooBarStruct struct {
FooStruct
Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" validate:"required"`
Slice []string `form:"slice" validate:"max=10"`
}
type ComplexDefaultStruct struct {
Int int `form:"int" default:"999"`
String string `form:"string" default:"default-string"`
Bool bool `form:"bool" default:"false"`
Int64Slice []int64 `form:"int64_slice,split" default:"1,2,3,4"`
Int8Slice []int8 `form:"int8_slice,split" default:"1,2,3,4"`
}
type Int8SliceStruct struct {
State []int8 `form:"state,split"`
}
type Int64SliceStruct struct {
State []int64 `form:"state,split"`
}
type StringSliceStruct struct {
State []string `form:"state,split"`
}
func TestBindingDefault(t *testing.T) {
assert.Equal(t, Default("GET", ""), Form)
assert.Equal(t, Default("GET", MIMEJSON), Form)
assert.Equal(t, Default("GET", MIMEJSON+"; charset=utf-8"), Form)
assert.Equal(t, Default("POST", MIMEJSON), JSON)
assert.Equal(t, Default("PUT", MIMEJSON), JSON)
assert.Equal(t, Default("POST", MIMEJSON+"; charset=utf-8"), JSON)
assert.Equal(t, Default("PUT", MIMEJSON+"; charset=utf-8"), JSON)
assert.Equal(t, Default("POST", MIMEXML), XML)
assert.Equal(t, Default("PUT", MIMEXML2), XML)
assert.Equal(t, Default("POST", MIMEPOSTForm), Form)
assert.Equal(t, Default("PUT", MIMEPOSTForm), Form)
assert.Equal(t, Default("POST", MIMEPOSTForm+"; charset=utf-8"), Form)
assert.Equal(t, Default("PUT", MIMEPOSTForm+"; charset=utf-8"), Form)
assert.Equal(t, Default("POST", MIMEMultipartPOSTForm), Form)
assert.Equal(t, Default("PUT", MIMEMultipartPOSTForm), Form)
}
func TestStripContentType(t *testing.T) {
c1 := "application/vnd.mozilla.xul+xml"
c2 := "application/vnd.mozilla.xul+xml; charset=utf-8"
assert.Equal(t, stripContentTypeParam(c1), c1)
assert.Equal(t, stripContentTypeParam(c2), "application/vnd.mozilla.xul+xml")
}
func TestBindInt8Form(t *testing.T) {
params := "state=1,2,3"
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q := new(Int8SliceStruct)
Form.Bind(req, q)
assert.EqualValues(t, []int8{1, 2, 3}, q.State)
params = "state=1,2,3,256"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(Int8SliceStruct)
assert.Error(t, Form.Bind(req, q))
params = "state="
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(Int8SliceStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Len(t, q.State, 0)
params = "state=1,,2"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(Int8SliceStruct)
assert.NoError(t, Form.Bind(req, q))
assert.EqualValues(t, []int8{1, 2}, q.State)
}
func TestBindInt64Form(t *testing.T) {
params := "state=1,2,3"
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q := new(Int64SliceStruct)
Form.Bind(req, q)
assert.EqualValues(t, []int64{1, 2, 3}, q.State)
params = "state="
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(Int64SliceStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Len(t, q.State, 0)
}
func TestBindStringForm(t *testing.T) {
params := "state=1,2,3"
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q := new(StringSliceStruct)
Form.Bind(req, q)
assert.EqualValues(t, []string{"1", "2", "3"}, q.State)
params = "state="
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(StringSliceStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Len(t, q.State, 0)
params = "state=p,,p"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(StringSliceStruct)
Form.Bind(req, q)
assert.EqualValues(t, []string{"p", "p"}, q.State)
}
func TestBindingJSON(t *testing.T) {
testBodyBinding(t,
JSON, "json",
"/", "/",
`{"foo": "bar"}`, `{"bar": "foo"}`)
}
func TestBindingForm(t *testing.T) {
testFormBinding(t, "POST",
"/", "/",
"foo=bar&bar=foo&slice=a&slice=b", "bar2=foo")
}
func TestBindingForm2(t *testing.T) {
testFormBinding(t, "GET",
"/?foo=bar&bar=foo", "/?bar2=foo",
"", "")
}
func TestBindingQuery(t *testing.T) {
testQueryBinding(t, "POST",
"/?foo=bar&bar=foo", "/",
"foo=unused", "bar2=foo")
}
func TestBindingQuery2(t *testing.T) {
testQueryBinding(t, "GET",
"/?foo=bar&bar=foo", "/?bar2=foo",
"foo=unused", "")
}
func TestBindingXML(t *testing.T) {
testBodyBinding(t,
XML, "xml",
"/", "/",
"<map><foo>bar</foo></map>", "<map><bar>foo</bar></map>")
}
func createFormPostRequest() *http.Request {
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo"))
req.Header.Set("Content-Type", MIMEPOSTForm)
return req
}
func createFormMultipartRequest() *http.Request {
boundary := "--testboundary"
body := new(bytes.Buffer)
mw := multipart.NewWriter(body)
defer mw.Close()
mw.SetBoundary(boundary)
mw.WriteField("foo", "bar")
mw.WriteField("bar", "foo")
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body)
req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary)
return req
}
func TestBindingFormPost(t *testing.T) {
req := createFormPostRequest()
var obj FooBarStruct
FormPost.Bind(req, &obj)
assert.Equal(t, obj.Foo, "bar")
assert.Equal(t, obj.Bar, "foo")
}
func TestBindingFormMultipart(t *testing.T) {
req := createFormMultipartRequest()
var obj FooBarStruct
FormMultipart.Bind(req, &obj)
assert.Equal(t, obj.Foo, "bar")
assert.Equal(t, obj.Bar, "foo")
}
func TestValidationFails(t *testing.T) {
var obj FooStruct
req := requestWithBody("POST", "/", `{"bar": "foo"}`)
err := JSON.Bind(req, &obj)
assert.Error(t, err)
}
func TestValidationDisabled(t *testing.T) {
backup := Validator
Validator = nil
defer func() { Validator = backup }()
var obj FooStruct
req := requestWithBody("POST", "/", `{"bar": "foo"}`)
err := JSON.Bind(req, &obj)
assert.NoError(t, err)
}
func TestExistsSucceeds(t *testing.T) {
type HogeStruct struct {
Hoge *int `json:"hoge" binding:"exists"`
}
var obj HogeStruct
req := requestWithBody("POST", "/", `{"hoge": 0}`)
err := JSON.Bind(req, &obj)
assert.NoError(t, err)
}
func TestFormDefaultValue(t *testing.T) {
params := "int=333&string=hello&bool=true&int64_slice=5,6,7,8&int8_slice=5,6,7,8"
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q := new(ComplexDefaultStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Equal(t, 333, q.Int)
assert.Equal(t, "hello", q.String)
assert.Equal(t, true, q.Bool)
assert.EqualValues(t, []int64{5, 6, 7, 8}, q.Int64Slice)
assert.EqualValues(t, []int8{5, 6, 7, 8}, q.Int8Slice)
params = "string=hello&bool=false"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(ComplexDefaultStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Equal(t, 999, q.Int)
assert.Equal(t, "hello", q.String)
assert.Equal(t, false, q.Bool)
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
params = "strings=hello"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(ComplexDefaultStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Equal(t, 999, q.Int)
assert.Equal(t, "default-string", q.String)
assert.Equal(t, false, q.Bool)
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
params = "int=&string=&bool=true&int64_slice=&int8_slice="
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(ComplexDefaultStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Equal(t, 999, q.Int)
assert.Equal(t, "default-string", q.String)
assert.Equal(t, true, q.Bool)
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
}
func testFormBinding(t *testing.T, method, path, badPath, body, badBody string) {
b := Form
assert.Equal(t, b.Name(), "form")
obj := FooBarStruct{}
req := requestWithBody(method, path, body)
if method == "POST" {
req.Header.Add("Content-Type", MIMEPOSTForm)
}
err := b.Bind(req, &obj)
assert.NoError(t, err)
assert.Equal(t, obj.Foo, "bar")
assert.Equal(t, obj.Bar, "foo")
obj = FooBarStruct{}
req = requestWithBody(method, badPath, badBody)
err = JSON.Bind(req, &obj)
assert.Error(t, err)
}
func testQueryBinding(t *testing.T, method, path, badPath, body, badBody string) {
b := Query
assert.Equal(t, b.Name(), "query")
obj := FooBarStruct{}
req := requestWithBody(method, path, body)
if method == "POST" {
req.Header.Add("Content-Type", MIMEPOSTForm)
}
err := b.Bind(req, &obj)
assert.NoError(t, err)
assert.Equal(t, obj.Foo, "bar")
assert.Equal(t, obj.Bar, "foo")
}
func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
assert.Equal(t, b.Name(), name)
obj := FooStruct{}
req := requestWithBody("POST", path, body)
err := b.Bind(req, &obj)
assert.NoError(t, err)
assert.Equal(t, obj.Foo, "bar")
obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
assert.Error(t, err)
}
func requestWithBody(method, path, body string) (req *http.Request) {
req, _ = http.NewRequest(method, path, bytes.NewBufferString(body))
return
}
func BenchmarkBindingForm(b *testing.B) {
req := requestWithBody("POST", "/", "foo=bar&bar=foo&slice=a&slice=b&slice=c&slice=w")
req.Header.Add("Content-Type", MIMEPOSTForm)
f := Form
for i := 0; i < b.N; i++ {
obj := FooBarStruct{}
f.Bind(req, &obj)
}
}

@ -0,0 +1,45 @@
package binding
import (
"reflect"
"sync"
"gopkg.in/go-playground/validator.v9"
)
type defaultValidator struct {
once sync.Once
validate *validator.Validate
}
var _ StructValidator = &defaultValidator{}
func (v *defaultValidator) ValidateStruct(obj interface{}) error {
if kindOfData(obj) == reflect.Struct {
v.lazyinit()
if err := v.validate.Struct(obj); err != nil {
return err
}
}
return nil
}
func (v *defaultValidator) RegisterValidation(key string, fn validator.Func) error {
v.lazyinit()
return v.validate.RegisterValidation(key, fn)
}
func (v *defaultValidator) lazyinit() {
v.once.Do(func() {
v.validate = validator.New()
})
}
func kindOfData(data interface{}) reflect.Kind {
value := reflect.ValueOf(data)
valueType := value.Kind()
if valueType == reflect.Ptr {
valueType = value.Elem().Kind()
}
return valueType
}

@ -0,0 +1,113 @@
// Code generated by protoc-gen-go.
// source: test.proto
// DO NOT EDIT!
/*
Package example is a generated protocol buffer package.
It is generated from these files:
test.proto
It has these top-level messages:
Test
*/
package example
import proto "github.com/golang/protobuf/proto"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = math.Inf
type FOO int32
const (
FOO_X FOO = 17
)
var FOO_name = map[int32]string{
17: "X",
}
var FOO_value = map[string]int32{
"X": 17,
}
func (x FOO) Enum() *FOO {
p := new(FOO)
*p = x
return p
}
func (x FOO) String() string {
return proto.EnumName(FOO_name, int32(x))
}
func (x *FOO) UnmarshalJSON(data []byte) error {
value, err := proto.UnmarshalJSONEnum(FOO_value, data, "FOO")
if err != nil {
return err
}
*x = FOO(value)
return nil
}
type Test struct {
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
func (m *Test) Reset() { *m = Test{} }
func (m *Test) String() string { return proto.CompactTextString(m) }
func (*Test) ProtoMessage() {}
const Default_Test_Type int32 = 77
func (m *Test) GetLabel() string {
if m != nil && m.Label != nil {
return *m.Label
}
return ""
}
func (m *Test) GetType() int32 {
if m != nil && m.Type != nil {
return *m.Type
}
return Default_Test_Type
}
func (m *Test) GetReps() []int64 {
if m != nil {
return m.Reps
}
return nil
}
func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
if m != nil {
return m.Optionalgroup
}
return nil
}
type Test_OptionalGroup struct {
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
func (*Test_OptionalGroup) ProtoMessage() {}
func (m *Test_OptionalGroup) GetRequiredField() string {
if m != nil && m.RequiredField != nil {
return *m.RequiredField
}
return ""
}
func init() {
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
}

@ -0,0 +1,12 @@
package example;
enum FOO {X=17;};
message Test {
required string label = 1;
optional int32 type = 2[default=77];
repeated int64 reps = 3;
optional group OptionalGroup = 4{
required string RequiredField = 5;
}
}

@ -0,0 +1,36 @@
package binding
import (
"fmt"
"log"
"net/http"
)
type Arg struct {
Max int64 `form:"max" validate:"max=10"`
Min int64 `form:"min" validate:"min=2"`
Range int64 `form:"range" validate:"min=1,max=10"`
// use split option to split arg 1,2,3 into slice [1 2 3]
// otherwise slice type with parse url.Values (eg:a=b&a=c) default.
Slice []int64 `form:"slice,split" validate:"min=1"`
}
func ExampleBinding() {
req := initHTTP("max=9&min=3&range=3&slice=1,2,3")
arg := new(Arg)
if err := Form.Bind(req, arg); err != nil {
log.Fatal(err)
}
fmt.Printf("arg.Max %d\narg.Min %d\narg.Range %d\narg.Slice %v", arg.Max, arg.Min, arg.Range, arg.Slice)
// Output:
// arg.Max 9
// arg.Min 3
// arg.Range 3
// arg.Slice [1 2 3]
}
func initHTTP(params string) (req *http.Request) {
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
req.ParseForm()
return
}

@ -0,0 +1,55 @@
package binding
import (
"net/http"
"github.com/pkg/errors"
)
const defaultMemory = 32 * 1024 * 1024
type formBinding struct{}
type formPostBinding struct{}
type formMultipartBinding struct{}
func (f formBinding) Name() string {
return "form"
}
func (f formBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return errors.WithStack(err)
}
if err := mapForm(obj, req.Form); err != nil {
return err
}
return validate(obj)
}
func (f formPostBinding) Name() string {
return "form-urlencoded"
}
func (f formPostBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return errors.WithStack(err)
}
if err := mapForm(obj, req.PostForm); err != nil {
return err
}
return validate(obj)
}
func (f formMultipartBinding) Name() string {
return "multipart/form-data"
}
func (f formMultipartBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseMultipartForm(defaultMemory); err != nil {
return errors.WithStack(err)
}
if err := mapForm(obj, req.MultipartForm.Value); err != nil {
return err
}
return validate(obj)
}

@ -0,0 +1,276 @@
package binding
import (
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/pkg/errors"
)
// scache struct reflect type cache.
var scache = &cache{
data: make(map[reflect.Type]*sinfo),
}
type cache struct {
data map[reflect.Type]*sinfo
mutex sync.RWMutex
}
func (c *cache) get(obj reflect.Type) (s *sinfo) {
var ok bool
c.mutex.RLock()
if s, ok = c.data[obj]; !ok {
c.mutex.RUnlock()
s = c.set(obj)
return
}
c.mutex.RUnlock()
return
}
func (c *cache) set(obj reflect.Type) (s *sinfo) {
s = new(sinfo)
tp := obj.Elem()
for i := 0; i < tp.NumField(); i++ {
fd := new(field)
fd.tp = tp.Field(i)
tag := fd.tp.Tag.Get("form")
fd.name, fd.option = parseTag(tag)
if defV := fd.tp.Tag.Get("default"); defV != "" {
dv := reflect.New(fd.tp.Type).Elem()
setWithProperType(fd.tp.Type.Kind(), []string{defV}, dv, fd.option)
fd.hasDefault = true
fd.defaultValue = dv
}
s.field = append(s.field, fd)
}
c.mutex.Lock()
c.data[obj] = s
c.mutex.Unlock()
return
}
type sinfo struct {
field []*field
}
type field struct {
tp reflect.StructField
name string
option tagOptions
hasDefault bool // if field had default value
defaultValue reflect.Value // field default value
}
func mapForm(ptr interface{}, form map[string][]string) error {
sinfo := scache.get(reflect.TypeOf(ptr))
val := reflect.ValueOf(ptr).Elem()
for i, fd := range sinfo.field {
typeField := fd.tp
structField := val.Field(i)
if !structField.CanSet() {
continue
}
structFieldKind := structField.Kind()
inputFieldName := fd.name
if inputFieldName == "" {
inputFieldName = typeField.Name
// if "form" tag is nil, we inspect if the field is a struct.
// this would not make sense for JSON parsing but it does for a form
// since data is flatten
if structFieldKind == reflect.Struct {
err := mapForm(structField.Addr().Interface(), form)
if err != nil {
return err
}
continue
}
}
inputValue, exists := form[inputFieldName]
if !exists {
// Set the field as default value when the input value is not exist
if fd.hasDefault {
structField.Set(fd.defaultValue)
}
continue
}
// Set the field as default value when the input value is empty
if fd.hasDefault && inputValue[0] == "" {
structField.Set(fd.defaultValue)
continue
}
if _, isTime := structField.Interface().(time.Time); isTime {
if err := setTimeField(inputValue[0], typeField, structField); err != nil {
return err
}
continue
}
if err := setWithProperType(typeField.Type.Kind(), inputValue, structField, fd.option); err != nil {
return err
}
}
return nil
}
func setWithProperType(valueKind reflect.Kind, val []string, structField reflect.Value, option tagOptions) error {
switch valueKind {
case reflect.Int:
return setIntField(val[0], 0, structField)
case reflect.Int8:
return setIntField(val[0], 8, structField)
case reflect.Int16:
return setIntField(val[0], 16, structField)
case reflect.Int32:
return setIntField(val[0], 32, structField)
case reflect.Int64:
return setIntField(val[0], 64, structField)
case reflect.Uint:
return setUintField(val[0], 0, structField)
case reflect.Uint8:
return setUintField(val[0], 8, structField)
case reflect.Uint16:
return setUintField(val[0], 16, structField)
case reflect.Uint32:
return setUintField(val[0], 32, structField)
case reflect.Uint64:
return setUintField(val[0], 64, structField)
case reflect.Bool:
return setBoolField(val[0], structField)
case reflect.Float32:
return setFloatField(val[0], 32, structField)
case reflect.Float64:
return setFloatField(val[0], 64, structField)
case reflect.String:
structField.SetString(val[0])
case reflect.Slice:
if option.Contains("split") {
val = strings.Split(val[0], ",")
}
filtered := filterEmpty(val)
switch structField.Type().Elem().Kind() {
case reflect.Int64:
valSli := make([]int64, 0, len(filtered))
for i := 0; i < len(filtered); i++ {
d, err := strconv.ParseInt(filtered[i], 10, 64)
if err != nil {
return err
}
valSli = append(valSli, d)
}
structField.Set(reflect.ValueOf(valSli))
case reflect.String:
valSli := make([]string, 0, len(filtered))
for i := 0; i < len(filtered); i++ {
valSli = append(valSli, filtered[i])
}
structField.Set(reflect.ValueOf(valSli))
default:
sliceOf := structField.Type().Elem().Kind()
numElems := len(filtered)
slice := reflect.MakeSlice(structField.Type(), len(filtered), len(filtered))
for i := 0; i < numElems; i++ {
if err := setWithProperType(sliceOf, filtered[i:], slice.Index(i), ""); err != nil {
return err
}
}
structField.Set(slice)
}
default:
return errors.New("Unknown type")
}
return nil
}
func setIntField(val string, bitSize int, field reflect.Value) error {
if val == "" {
val = "0"
}
intVal, err := strconv.ParseInt(val, 10, bitSize)
if err == nil {
field.SetInt(intVal)
}
return errors.WithStack(err)
}
func setUintField(val string, bitSize int, field reflect.Value) error {
if val == "" {
val = "0"
}
uintVal, err := strconv.ParseUint(val, 10, bitSize)
if err == nil {
field.SetUint(uintVal)
}
return errors.WithStack(err)
}
func setBoolField(val string, field reflect.Value) error {
if val == "" {
val = "false"
}
boolVal, err := strconv.ParseBool(val)
if err == nil {
field.SetBool(boolVal)
}
return nil
}
func setFloatField(val string, bitSize int, field reflect.Value) error {
if val == "" {
val = "0.0"
}
floatVal, err := strconv.ParseFloat(val, bitSize)
if err == nil {
field.SetFloat(floatVal)
}
return errors.WithStack(err)
}
func setTimeField(val string, structField reflect.StructField, value reflect.Value) error {
timeFormat := structField.Tag.Get("time_format")
if timeFormat == "" {
return errors.New("Blank time format")
}
if val == "" {
value.Set(reflect.ValueOf(time.Time{}))
return nil
}
l := time.Local
if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC {
l = time.UTC
}
if locTag := structField.Tag.Get("time_location"); locTag != "" {
loc, err := time.LoadLocation(locTag)
if err != nil {
return errors.WithStack(err)
}
l = loc
}
t, err := time.ParseInLocation(timeFormat, val, l)
if err != nil {
return errors.WithStack(err)
}
value.Set(reflect.ValueOf(t))
return nil
}
func filterEmpty(val []string) []string {
filtered := make([]string, 0, len(val))
for _, v := range val {
if v != "" {
filtered = append(filtered, v)
}
}
return filtered
}

@ -0,0 +1,22 @@
package binding
import (
"encoding/json"
"net/http"
"github.com/pkg/errors"
)
type jsonBinding struct{}
func (jsonBinding) Name() string {
return "json"
}
func (jsonBinding) Bind(req *http.Request, obj interface{}) error {
decoder := json.NewDecoder(req.Body)
if err := decoder.Decode(obj); err != nil {
return errors.WithStack(err)
}
return validate(obj)
}

@ -0,0 +1,19 @@
package binding
import (
"net/http"
)
type queryBinding struct{}
func (queryBinding) Name() string {
return "query"
}
func (queryBinding) Bind(req *http.Request, obj interface{}) error {
values := req.URL.Query()
if err := mapForm(obj, values); err != nil {
return err
}
return validate(obj)
}

@ -0,0 +1,44 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package binding
import (
"strings"
)
// tagOptions is the string following a comma in a struct field's "json"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string
// parseTag splits a struct field's json tag into its name and
// comma-separated options.
func parseTag(tag string) (string, tagOptions) {
if idx := strings.Index(tag, ","); idx != -1 {
return tag[:idx], tagOptions(tag[idx+1:])
}
return tag, tagOptions("")
}
// Contains reports whether a comma-separated list of options
// contains a particular substr flag. substr must be surrounded by a
// string boundary or commas.
func (o tagOptions) Contains(optionName string) bool {
if len(o) == 0 {
return false
}
s := string(o)
for s != "" {
var next string
i := strings.Index(s, ",")
if i >= 0 {
s, next = s[:i], s[i+1:]
}
if s == optionName {
return true
}
s = next
}
return false
}

@ -0,0 +1,209 @@
package binding
import (
"bytes"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type testInterface interface {
String() string
}
type substructNoValidation struct {
IString string
IInt int
}
type mapNoValidationSub map[string]substructNoValidation
type structNoValidationValues struct {
substructNoValidation
Boolean bool
Uinteger uint
Integer int
Integer8 int8
Integer16 int16
Integer32 int32
Integer64 int64
Uinteger8 uint8
Uinteger16 uint16
Uinteger32 uint32
Uinteger64 uint64
Float32 float32
Float64 float64
String string
Date time.Time
Struct substructNoValidation
InlinedStruct struct {
String []string
Integer int
}
IntSlice []int
IntPointerSlice []*int
StructPointerSlice []*substructNoValidation
StructSlice []substructNoValidation
InterfaceSlice []testInterface
UniversalInterface interface{}
CustomInterface testInterface
FloatMap map[string]float32
StructMap mapNoValidationSub
}
func createNoValidationValues() structNoValidationValues {
integer := 1
s := structNoValidationValues{
Boolean: true,
Uinteger: 1 << 29,
Integer: -10000,
Integer8: 120,
Integer16: -20000,
Integer32: 1 << 29,
Integer64: 1 << 61,
Uinteger8: 250,
Uinteger16: 50000,
Uinteger32: 1 << 31,
Uinteger64: 1 << 62,
Float32: 123.456,
Float64: 123.456789,
String: "text",
Date: time.Time{},
CustomInterface: &bytes.Buffer{},
Struct: substructNoValidation{},
IntSlice: []int{-3, -2, 1, 0, 1, 2, 3},
IntPointerSlice: []*int{&integer},
StructSlice: []substructNoValidation{},
UniversalInterface: 1.2,
FloatMap: map[string]float32{
"foo": 1.23,
"bar": 232.323,
},
StructMap: mapNoValidationSub{
"foo": substructNoValidation{},
"bar": substructNoValidation{},
},
// StructPointerSlice []noValidationSub
// InterfaceSlice []testInterface
}
s.InlinedStruct.Integer = 1000
s.InlinedStruct.String = []string{"first", "second"}
s.IString = "substring"
s.IInt = 987654
return s
}
func TestValidateNoValidationValues(t *testing.T) {
origin := createNoValidationValues()
test := createNoValidationValues()
empty := structNoValidationValues{}
assert.Nil(t, validate(test))
assert.Nil(t, validate(&test))
assert.Nil(t, validate(empty))
assert.Nil(t, validate(&empty))
assert.Equal(t, origin, test)
}
type structNoValidationPointer struct {
// substructNoValidation
Boolean bool
Uinteger *uint
Integer *int
Integer8 *int8
Integer16 *int16
Integer32 *int32
Integer64 *int64
Uinteger8 *uint8
Uinteger16 *uint16
Uinteger32 *uint32
Uinteger64 *uint64
Float32 *float32
Float64 *float64
String *string
Date *time.Time
Struct *substructNoValidation
IntSlice *[]int
IntPointerSlice *[]*int
StructPointerSlice *[]*substructNoValidation
StructSlice *[]substructNoValidation
InterfaceSlice *[]testInterface
FloatMap *map[string]float32
StructMap *mapNoValidationSub
}
func TestValidateNoValidationPointers(t *testing.T) {
//origin := createNoValidation_values()
//test := createNoValidation_values()
empty := structNoValidationPointer{}
//assert.Nil(t, validate(test))
//assert.Nil(t, validate(&test))
assert.Nil(t, validate(empty))
assert.Nil(t, validate(&empty))
//assert.Equal(t, origin, test)
}
type Object map[string]interface{}
func TestValidatePrimitives(t *testing.T) {
obj := Object{"foo": "bar", "bar": 1}
assert.NoError(t, validate(obj))
assert.NoError(t, validate(&obj))
assert.Equal(t, obj, Object{"foo": "bar", "bar": 1})
obj2 := []Object{{"foo": "bar", "bar": 1}, {"foo": "bar", "bar": 1}}
assert.NoError(t, validate(obj2))
assert.NoError(t, validate(&obj2))
nu := 10
assert.NoError(t, validate(nu))
assert.NoError(t, validate(&nu))
assert.Equal(t, nu, 10)
str := "value"
assert.NoError(t, validate(str))
assert.NoError(t, validate(&str))
assert.Equal(t, str, "value")
}
// structCustomValidation is a helper struct we use to check that
// custom validation can be registered on it.
// The `notone` binding directive is for custom validation and registered later.
// type structCustomValidation struct {
// Integer int `binding:"notone"`
// }
// notOne is a custom validator meant to be used with `validator.v8` library.
// The method signature for `v9` is significantly different and this function
// would need to be changed for tests to pass after upgrade.
// See https://github.com/gin-gonic/gin/pull/1015.
// func notOne(
// v *validator.Validate, topStruct reflect.Value, currentStructOrField reflect.Value,
// field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string,
// ) bool {
// if val, ok := field.Interface().(int); ok {
// return val != 1
// }
// return false
// }

@ -0,0 +1,22 @@
package binding
import (
"encoding/xml"
"net/http"
"github.com/pkg/errors"
)
type xmlBinding struct{}
func (xmlBinding) Name() string {
return "xml"
}
func (xmlBinding) Bind(req *http.Request, obj interface{}) error {
decoder := xml.NewDecoder(req.Body)
if err := decoder.Decode(obj); err != nil {
return errors.WithStack(err)
}
return validate(obj)
}

@ -0,0 +1,306 @@
package blademaster
import (
"context"
"math"
"net/http"
"strconv"
"github.com/bilibili/kratos/pkg/ecode"
"github.com/bilibili/kratos/pkg/net/http/blademaster/binding"
"github.com/bilibili/kratos/pkg/net/http/blademaster/render"
"github.com/gogo/protobuf/proto"
"github.com/gogo/protobuf/types"
"github.com/pkg/errors"
)
const (
_abortIndex int8 = math.MaxInt8 / 2
)
var (
_openParen = []byte("(")
_closeParen = []byte(")")
)
// Context is the most important part. It allows us to pass variables between
// middleware, manage the flow, validate the JSON of a request and render a
// JSON response for example.
type Context struct {
context.Context
Request *http.Request
Writer http.ResponseWriter
// flow control
index int8
handlers []HandlerFunc
// Keys is a key/value pair exclusively for the context of each request.
Keys map[string]interface{}
Error error
method string
engine *Engine
}
/************************************/
/*********** FLOW CONTROL ***********/
/************************************/
// Next should be used only inside middleware.
// It executes the pending handlers in the chain inside the calling handler.
// See example in godoc.
func (c *Context) Next() {
c.index++
s := int8(len(c.handlers))
for ; c.index < s; c.index++ {
// only check method on last handler, otherwise middlewares
// will never be effected if request method is not matched
if c.index == s-1 && c.method != c.Request.Method {
code := http.StatusMethodNotAllowed
c.Error = ecode.MethodNotAllowed
http.Error(c.Writer, http.StatusText(code), code)
return
}
c.handlers[c.index](c)
}
}
// Abort prevents pending handlers from being called. Note that this will not stop the current handler.
// Let's say you have an authorization middleware that validates that the current request is authorized.
// If the authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers
// for this request are not called.
func (c *Context) Abort() {
c.index = _abortIndex
}
// AbortWithStatus calls `Abort()` and writes the headers with the specified status code.
// For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401).
func (c *Context) AbortWithStatus(code int) {
c.Status(code)
c.Abort()
}
// IsAborted returns true if the current context was aborted.
func (c *Context) IsAborted() bool {
return c.index >= _abortIndex
}
/************************************/
/******** METADATA MANAGEMENT********/
/************************************/
// Set is used to store a new key/value pair exclusively for this context.
// It also lazy initializes c.Keys if it was not used previously.
func (c *Context) Set(key string, value interface{}) {
if c.Keys == nil {
c.Keys = make(map[string]interface{})
}
c.Keys[key] = value
}
// Get returns the value for the given key, ie: (value, true).
// If the value does not exists it returns (nil, false)
func (c *Context) Get(key string) (value interface{}, exists bool) {
value, exists = c.Keys[key]
return
}
/************************************/
/******** RESPONSE RENDERING ********/
/************************************/
// bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function.
func bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == 204:
return false
case status == 304:
return false
}
return true
}
// Status sets the HTTP response code.
func (c *Context) Status(code int) {
c.Writer.WriteHeader(code)
}
// Render http response with http code by a render instance.
func (c *Context) Render(code int, r render.Render) {
r.WriteContentType(c.Writer)
if code > 0 {
c.Status(code)
}
if !bodyAllowedForStatus(code) {
return
}
params := c.Request.Form
cb := params.Get("callback")
jsonp := cb != "" && params.Get("jsonp") == "jsonp"
if jsonp {
c.Writer.Write([]byte(cb))
c.Writer.Write(_openParen)
}
if err := r.Render(c.Writer); err != nil {
c.Error = err
return
}
if jsonp {
if _, err := c.Writer.Write(_closeParen); err != nil {
c.Error = errors.WithStack(err)
}
}
}
// JSON serializes the given struct as JSON into the response body.
// It also sets the Content-Type as "application/json".
func (c *Context) JSON(data interface{}, err error) {
code := http.StatusOK
c.Error = err
bcode := ecode.Cause(err)
// TODO app allow 5xx?
/*
if bcode.Code() == -500 {
code = http.StatusServiceUnavailable
}
*/
writeStatusCode(c.Writer, bcode.Code())
c.Render(code, render.JSON{
Code: bcode.Code(),
Message: bcode.Message(),
Data: data,
})
}
// JSONMap serializes the given map as map JSON into the response body.
// It also sets the Content-Type as "application/json".
func (c *Context) JSONMap(data map[string]interface{}, err error) {
code := http.StatusOK
c.Error = err
bcode := ecode.Cause(err)
// TODO app allow 5xx?
/*
if bcode.Code() == -500 {
code = http.StatusServiceUnavailable
}
*/
writeStatusCode(c.Writer, bcode.Code())
data["code"] = bcode.Code()
if _, ok := data["message"]; !ok {
data["message"] = bcode.Message()
}
c.Render(code, render.MapJSON(data))
}
// XML serializes the given struct as XML into the response body.
// It also sets the Content-Type as "application/xml".
func (c *Context) XML(data interface{}, err error) {
code := http.StatusOK
c.Error = err
bcode := ecode.Cause(err)
// TODO app allow 5xx?
/*
if bcode.Code() == -500 {
code = http.StatusServiceUnavailable
}
*/
writeStatusCode(c.Writer, bcode.Code())
c.Render(code, render.XML{
Code: bcode.Code(),
Message: bcode.Message(),
Data: data,
})
}
// Protobuf serializes the given struct as PB into the response body.
// It also sets the ContentType as "application/x-protobuf".
func (c *Context) Protobuf(data proto.Message, err error) {
var (
bytes []byte
)
code := http.StatusOK
c.Error = err
bcode := ecode.Cause(err)
any := new(types.Any)
if data != nil {
if bytes, err = proto.Marshal(data); err != nil {
c.Error = errors.WithStack(err)
return
}
any.TypeUrl = "type.googleapis.com/" + proto.MessageName(data)
any.Value = bytes
}
writeStatusCode(c.Writer, bcode.Code())
c.Render(code, render.PB{
Code: int64(bcode.Code()),
Message: bcode.Message(),
Data: any,
})
}
// Bytes writes some data into the body stream and updates the HTTP code.
func (c *Context) Bytes(code int, contentType string, data ...[]byte) {
c.Render(code, render.Data{
ContentType: contentType,
Data: data,
})
}
// String writes the given string into the response body.
func (c *Context) String(code int, format string, values ...interface{}) {
c.Render(code, render.String{Format: format, Data: values})
}
// Redirect returns a HTTP redirect to the specific location.
func (c *Context) Redirect(code int, location string) {
c.Render(-1, render.Redirect{
Code: code,
Location: location,
Request: c.Request,
})
}
// BindWith bind req arg with parser.
func (c *Context) BindWith(obj interface{}, b binding.Binding) error {
return c.mustBindWith(obj, b)
}
// Bind bind req arg with defult form binding.
func (c *Context) Bind(obj interface{}) error {
return c.mustBindWith(obj, binding.Form)
}
// mustBindWith binds the passed struct pointer using the specified binding engine.
// It will abort the request with HTTP 400 if any error ocurrs.
// See the binding package.
func (c *Context) mustBindWith(obj interface{}, b binding.Binding) (err error) {
if err = b.Bind(c.Request, obj); err != nil {
c.Error = ecode.RequestErr
c.Render(http.StatusOK, render.JSON{
Code: ecode.RequestErr.Code(),
Message: err.Error(),
Data: nil,
})
c.Abort()
}
return
}
func writeStatusCode(w http.ResponseWriter, ecode int) {
header := w.Header()
header.Set("kratos-status-code", strconv.FormatInt(int64(ecode), 10))
}

@ -0,0 +1,249 @@
package blademaster
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/bilibili/kratos/pkg/log"
"github.com/pkg/errors"
)
// CORSConfig represents all available options for the middleware.
type CORSConfig struct {
AllowAllOrigins bool
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
// If the special "*" value is present in the list, all origins will be allowed.
// Default value is []
AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It take the origin
// as argument and returns true if allowed or false otherwise. If this option is
// set, the content of AllowedOrigins is ignored.
AllowOriginFunc func(origin string) bool
// AllowedMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (GET and POST)
AllowMethods []string
// AllowedHeaders is list of non simple headers the client is allowed to use with
// cross-domain requests.
AllowHeaders []string
// AllowCredentials indicates whether the request can include user credentials like
// cookies, HTTP authentication or client side SSL certificates.
AllowCredentials bool
// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
// API specification
ExposeHeaders []string
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached
MaxAge time.Duration
}
type cors struct {
allowAllOrigins bool
allowCredentials bool
allowOriginFunc func(string) bool
allowOrigins []string
normalHeaders http.Header
preflightHeaders http.Header
}
type converter func(string) string
// Validate is check configuration of user defined.
func (c *CORSConfig) Validate() error {
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed")
}
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
return errors.New("conflict settings: all origins disabled")
}
for _, origin := range c.AllowOrigins {
if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
return errors.New("bad origin: origins must either be '*' or include http:// or https://")
}
}
return nil
}
// CORS returns the location middleware with default configuration.
func CORS(allowOriginHosts []string) HandlerFunc {
config := &CORSConfig{
AllowMethods: []string{"GET", "POST"},
AllowHeaders: []string{"Origin", "Content-Length", "Content-Type"},
AllowCredentials: true,
MaxAge: time.Duration(0),
AllowOriginFunc: func(origin string) bool {
for _, host := range allowOriginHosts {
if strings.HasSuffix(strings.ToLower(origin), host) {
return true
}
}
return false
},
}
return newCORS(config)
}
// newCORS returns the location middleware with user-defined custom configuration.
func newCORS(config *CORSConfig) HandlerFunc {
if err := config.Validate(); err != nil {
panic(err.Error())
}
cors := &cors{
allowOriginFunc: config.AllowOriginFunc,
allowAllOrigins: config.AllowAllOrigins,
allowCredentials: config.AllowCredentials,
allowOrigins: normalize(config.AllowOrigins),
normalHeaders: generateNormalHeaders(config),
preflightHeaders: generatePreflightHeaders(config),
}
return func(c *Context) {
cors.applyCORS(c)
}
}
func (cors *cors) applyCORS(c *Context) {
origin := c.Request.Header.Get("Origin")
if len(origin) == 0 {
// request is not a CORS request
return
}
if !cors.validateOrigin(origin) {
log.V(5).Info("The request's Origin header `%s` does not match any of allowed origins.", origin)
c.AbortWithStatus(http.StatusForbidden)
return
}
if c.Request.Method == "OPTIONS" {
cors.handlePreflight(c)
defer c.AbortWithStatus(200)
} else {
cors.handleNormal(c)
}
if !cors.allowAllOrigins {
header := c.Writer.Header()
header.Set("Access-Control-Allow-Origin", origin)
}
}
func (cors *cors) validateOrigin(origin string) bool {
if cors.allowAllOrigins {
return true
}
for _, value := range cors.allowOrigins {
if value == origin {
return true
}
}
if cors.allowOriginFunc != nil {
return cors.allowOriginFunc(origin)
}
return false
}
func (cors *cors) handlePreflight(c *Context) {
header := c.Writer.Header()
for key, value := range cors.preflightHeaders {
header[key] = value
}
}
func (cors *cors) handleNormal(c *Context) {
header := c.Writer.Header()
for key, value := range cors.normalHeaders {
header[key] = value
}
}
func generateNormalHeaders(c *CORSConfig) http.Header {
headers := make(http.Header)
if c.AllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
// backport support for early browsers
if len(c.AllowMethods) > 0 {
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
value := strings.Join(allowMethods, ",")
headers.Set("Access-Control-Allow-Methods", value)
}
if len(c.ExposeHeaders) > 0 {
exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey)
headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ","))
}
if c.AllowAllOrigins {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
headers.Set("Vary", "Origin")
}
return headers
}
func generatePreflightHeaders(c *CORSConfig) http.Header {
headers := make(http.Header)
if c.AllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
if len(c.AllowMethods) > 0 {
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
value := strings.Join(allowMethods, ",")
headers.Set("Access-Control-Allow-Methods", value)
}
if len(c.AllowHeaders) > 0 {
allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey)
value := strings.Join(allowHeaders, ",")
headers.Set("Access-Control-Allow-Headers", value)
}
if c.MaxAge > time.Duration(0) {
value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10)
headers.Set("Access-Control-Max-Age", value)
}
if c.AllowAllOrigins {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
// Always set Vary headers
// see https://github.com/rs/cors/issues/10,
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
headers.Add("Vary", "Origin")
headers.Add("Vary", "Access-Control-Request-Method")
headers.Add("Vary", "Access-Control-Request-Headers")
}
return headers
}
func normalize(values []string) []string {
if values == nil {
return nil
}
distinctMap := make(map[string]bool, len(values))
normalized := make([]string, 0, len(values))
for _, value := range values {
value = strings.TrimSpace(value)
value = strings.ToLower(value)
if _, seen := distinctMap[value]; !seen {
normalized = append(normalized, value)
distinctMap[value] = true
}
}
return normalized
}
func convert(s []string, c converter) []string {
var out []string
for _, i := range s {
out = append(out, c(i))
}
return out
}

@ -0,0 +1,64 @@
package blademaster
import (
"net/url"
"regexp"
"strings"
"github.com/bilibili/kratos/pkg/log"
)
func matchHostSuffix(suffix string) func(*url.URL) bool {
return func(uri *url.URL) bool {
return strings.HasSuffix(strings.ToLower(uri.Host), suffix)
}
}
func matchPattern(pattern *regexp.Regexp) func(*url.URL) bool {
return func(uri *url.URL) bool {
return pattern.MatchString(strings.ToLower(uri.String()))
}
}
// CSRF returns the csrf middleware to prevent invalid cross site request.
// Only referer is checked currently.
func CSRF(allowHosts []string, allowPattern []string) HandlerFunc {
validations := []func(*url.URL) bool{}
addHostSuffix := func(suffix string) {
validations = append(validations, matchHostSuffix(suffix))
}
addPattern := func(pattern string) {
validations = append(validations, matchPattern(regexp.MustCompile(pattern)))
}
for _, r := range allowHosts {
addHostSuffix(r)
}
for _, p := range allowPattern {
addPattern(p)
}
return func(c *Context) {
referer := c.Request.Header.Get("Referer")
if referer == "" {
log.V(5).Info("The request's Referer or Origin header is empty.")
c.AbortWithStatus(403)
return
}
illegal := true
if uri, err := url.Parse(referer); err == nil && uri.Host != "" {
for _, validate := range validations {
if validate(uri) {
illegal = false
break
}
}
}
if illegal {
log.V(5).Info("The request's Referer header `%s` does not match any of allowed referers.", referer)
c.AbortWithStatus(403)
return
}
}
}

@ -0,0 +1,69 @@
package blademaster
import (
"fmt"
"strconv"
"time"
"github.com/bilibili/kratos/pkg/ecode"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/metadata"
)
// Logger is logger middleware
func Logger() HandlerFunc {
const noUser = "no_user"
return func(c *Context) {
now := time.Now()
ip := metadata.String(c, metadata.RemoteIP)
req := c.Request
path := req.URL.Path
params := req.Form
var quota float64
if deadline, ok := c.Context.Deadline(); ok {
quota = time.Until(deadline).Seconds()
}
c.Next()
err := c.Error
cerr := ecode.Cause(err)
dt := time.Since(now)
caller := metadata.String(c, metadata.Caller)
if caller == "" {
caller = noUser
}
stats.Incr(caller, path[1:], strconv.FormatInt(int64(cerr.Code()), 10))
stats.Timing(caller, int64(dt/time.Millisecond), path[1:])
lf := log.Infov
errmsg := ""
isSlow := dt >= (time.Millisecond * 500)
if err != nil {
errmsg = err.Error()
lf = log.Errorv
if cerr.Code() > 0 {
lf = log.Warnv
}
} else {
if isSlow {
lf = log.Warnv
}
}
lf(c,
log.KVString("method", req.Method),
log.KVString("ip", ip),
log.KVString("user", caller),
log.KVString("path", path),
log.KVString("params", params.Encode()),
log.KVInt("ret", cerr.Code()),
log.KVString("msg", cerr.Message()),
log.KVString("stack", fmt.Sprintf("%+v", err)),
log.KVString("err", errmsg),
log.KVFloat64("timeout_quota", quota),
log.KVFloat64("ts", dt.Seconds()),
log.KVString("source", "http-access-log"),
)
}
}

@ -17,13 +17,14 @@ const (
_httpHeaderUser = "x1-bmspy-user" _httpHeaderUser = "x1-bmspy-user"
_httpHeaderColor = "x1-bmspy-color" _httpHeaderColor = "x1-bmspy-color"
_httpHeaderTimeout = "x1-bmspy-timeout" _httpHeaderTimeout = "x1-bmspy-timeout"
_httpHeaderMirror = "x1-bmspy-mirror"
_httpHeaderRemoteIP = "x-backend-bm-real-ip" _httpHeaderRemoteIP = "x-backend-bm-real-ip"
_httpHeaderRemoteIPPort = "x-backend-bm-real-ipport" _httpHeaderRemoteIPPort = "x-backend-bm-real-ipport"
) )
// mirror return true if x1-bilispy-mirror in http header and its value is 1 or true. // mirror return true if x-bmspy-mirror in http header and its value is 1 or true.
func mirror(req *http.Request) bool { func mirror(req *http.Request) bool {
mirrorStr := req.Header.Get("x1-bilispy-mirror") mirrorStr := req.Header.Get(_httpHeaderMirror)
if mirrorStr == "" { if mirrorStr == "" {
return false return false
} }
@ -79,7 +80,7 @@ func timeout(req *http.Request) time.Duration {
} }
// remoteIP implements a best effort algorithm to return the real client IP, it parses // remoteIP implements a best effort algorithm to return the real client IP, it parses
// X-BACKEND-BILI-REAL-IP or X-Real-IP or X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy. // x-backend-bm-real-ip or X-Real-IP or X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
// Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP. // Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP.
func remoteIP(req *http.Request) (remote string) { func remoteIP(req *http.Request) (remote string) {
if remote = req.Header.Get(_httpHeaderRemoteIP); remote != "" && remote != "null" { if remote = req.Header.Get(_httpHeaderRemoteIP); remote != "" && remote != "null" {

@ -0,0 +1,46 @@
package blademaster
import (
"flag"
"net/http"
"net/http/pprof"
"os"
"sync"
"github.com/bilibili/kratos/pkg/conf/dsn"
"github.com/pkg/errors"
)
var (
_perfOnce sync.Once
_perfDSN string
)
func init() {
v := os.Getenv("HTTP_PERF")
if v == "" {
v = "tcp://0.0.0.0:2333"
}
flag.StringVar(&_perfDSN, "http.perf", v, "listen http perf dsn, or use HTTP_PERF env variable.")
}
func startPerf() {
_perfOnce.Do(func() {
mux := http.NewServeMux()
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
go func() {
d, err := dsn.Parse(_perfDSN)
if err != nil {
panic(errors.Errorf("blademaster: http perf dsn must be tcp://$host:port, %s:error(%v)", _perfDSN, err))
}
if err := http.ListenAndServe(d.Host, mux); err != nil {
panic(errors.Errorf("blademaster: listen %s: error(%v)", d.Host, err))
}
}()
})
}

@ -0,0 +1,12 @@
package blademaster
import (
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func monitor() HandlerFunc {
return func(c *Context) {
h := promhttp.Handler()
h.ServeHTTP(c.Writer, c.Request)
}
}

@ -0,0 +1,32 @@
package blademaster
import (
"fmt"
"net/http/httputil"
"os"
"runtime"
"github.com/bilibili/kratos/pkg/log"
)
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
func Recovery() HandlerFunc {
return func(c *Context) {
defer func() {
var rawReq []byte
if err := recover(); err != nil {
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
if c.Request != nil {
rawReq, _ = httputil.DumpRequest(c.Request, false)
}
pl := fmt.Sprintf("http call panic: %s\n%v\n%s\n", string(rawReq), err, buf)
fmt.Fprintf(os.Stderr, pl)
log.Error(pl)
c.AbortWithStatus(500)
}
}()
c.Next()
}
}

@ -0,0 +1,30 @@
package render
import (
"net/http"
"github.com/pkg/errors"
)
// Data common bytes struct.
type Data struct {
ContentType string
Data [][]byte
}
// Render (Data) writes data with custom ContentType.
func (r Data) Render(w http.ResponseWriter) (err error) {
r.WriteContentType(w)
for _, d := range r.Data {
if _, err = w.Write(d); err != nil {
err = errors.WithStack(err)
return
}
}
return
}
// WriteContentType writes data with custom ContentType.
func (r Data) WriteContentType(w http.ResponseWriter) {
writeContentType(w, []string{r.ContentType})
}

@ -0,0 +1,58 @@
package render
import (
"encoding/json"
"net/http"
"github.com/pkg/errors"
)
var jsonContentType = []string{"application/json; charset=utf-8"}
// JSON common json struct.
type JSON struct {
Code int `json:"code"`
Message string `json:"message"`
TTL int `json:"ttl"`
Data interface{} `json:"data,omitempty"`
}
func writeJSON(w http.ResponseWriter, obj interface{}) (err error) {
var jsonBytes []byte
writeContentType(w, jsonContentType)
if jsonBytes, err = json.Marshal(obj); err != nil {
err = errors.WithStack(err)
return
}
if _, err = w.Write(jsonBytes); err != nil {
err = errors.WithStack(err)
}
return
}
// Render (JSON) writes data with json ContentType.
func (r JSON) Render(w http.ResponseWriter) error {
// FIXME(zhoujiahui): the TTL field will be configurable in the future
if r.TTL <= 0 {
r.TTL = 1
}
return writeJSON(w, r)
}
// WriteContentType write json ContentType.
func (r JSON) WriteContentType(w http.ResponseWriter) {
writeContentType(w, jsonContentType)
}
// MapJSON common map json struct.
type MapJSON map[string]interface{}
// Render (MapJSON) writes data with json ContentType.
func (m MapJSON) Render(w http.ResponseWriter) error {
return writeJSON(w, m)
}
// WriteContentType write json ContentType.
func (m MapJSON) WriteContentType(w http.ResponseWriter) {
writeContentType(w, jsonContentType)
}

@ -0,0 +1,38 @@
package render
import (
"net/http"
"github.com/gogo/protobuf/proto"
"github.com/pkg/errors"
)
var pbContentType = []string{"application/x-protobuf"}
// Render (PB) writes data with protobuf ContentType.
func (r PB) Render(w http.ResponseWriter) error {
if r.TTL <= 0 {
r.TTL = 1
}
return writePB(w, r)
}
// WriteContentType write protobuf ContentType.
func (r PB) WriteContentType(w http.ResponseWriter) {
writeContentType(w, pbContentType)
}
func writePB(w http.ResponseWriter, obj PB) (err error) {
var pbBytes []byte
writeContentType(w, pbContentType)
if pbBytes, err = proto.Marshal(&obj); err != nil {
err = errors.WithStack(err)
return
}
if _, err = w.Write(pbBytes); err != nil {
err = errors.WithStack(err)
}
return
}

@ -0,0 +1,26 @@
package render
import (
"net/http"
"github.com/pkg/errors"
)
// Redirect render for redirect to specified location.
type Redirect struct {
Code int
Request *http.Request
Location string
}
// Render (Redirect) redirect to specified location.
func (r Redirect) Render(w http.ResponseWriter) error {
if (r.Code < 300 || r.Code > 308) && r.Code != 201 {
return errors.Errorf("Cannot redirect with status code %d", r.Code)
}
http.Redirect(w, r.Request, r.Location, r.Code)
return nil
}
// WriteContentType noneContentType.
func (r Redirect) WriteContentType(http.ResponseWriter) {}

@ -0,0 +1,30 @@
package render
import (
"net/http"
)
// Render http reponse render.
type Render interface {
// Render render it to http response writer.
Render(http.ResponseWriter) error
// WriteContentType write content-type to http response writer.
WriteContentType(w http.ResponseWriter)
}
var (
_ Render = JSON{}
_ Render = MapJSON{}
_ Render = XML{}
_ Render = String{}
_ Render = Redirect{}
_ Render = Data{}
_ Render = PB{}
)
func writeContentType(w http.ResponseWriter, value []string) {
header := w.Header()
if val := header["Content-Type"]; len(val) == 0 {
header["Content-Type"] = value
}
}

@ -0,0 +1,89 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: pb.proto
/*
Package render is a generated protocol buffer package.
It is generated from these files:
pb.proto
It has these top-level messages:
PB
*/
package render
import proto "github.com/gogo/protobuf/proto"
import fmt "fmt"
import math "math"
import google_protobuf "github.com/gogo/protobuf/types"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package
type PB struct {
Code int64 `protobuf:"varint,1,opt,name=Code,proto3" json:"Code,omitempty"`
Message string `protobuf:"bytes,2,opt,name=Message,proto3" json:"Message,omitempty"`
TTL uint64 `protobuf:"varint,3,opt,name=TTL,proto3" json:"TTL,omitempty"`
Data *google_protobuf.Any `protobuf:"bytes,4,opt,name=Data" json:"Data,omitempty"`
}
func (m *PB) Reset() { *m = PB{} }
func (m *PB) String() string { return proto.CompactTextString(m) }
func (*PB) ProtoMessage() {}
func (*PB) Descriptor() ([]byte, []int) { return fileDescriptorPb, []int{0} }
func (m *PB) GetCode() int64 {
if m != nil {
return m.Code
}
return 0
}
func (m *PB) GetMessage() string {
if m != nil {
return m.Message
}
return ""
}
func (m *PB) GetTTL() uint64 {
if m != nil {
return m.TTL
}
return 0
}
func (m *PB) GetData() *google_protobuf.Any {
if m != nil {
return m.Data
}
return nil
}
func init() {
proto.RegisterType((*PB)(nil), "render.PB")
}
func init() { proto.RegisterFile("pb.proto", fileDescriptorPb) }
var fileDescriptorPb = []byte{
// 154 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x28, 0x48, 0xd2, 0x2b,
0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b, 0x4a, 0xcd, 0x4b, 0x49, 0x2d, 0x92, 0x92, 0x4c, 0xcf,
0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07, 0x8b, 0x26, 0x95, 0xa6, 0xe9, 0x27, 0xe6, 0x55, 0x42, 0x94,
0x28, 0xe5, 0x71, 0x31, 0x05, 0x38, 0x09, 0x09, 0x71, 0xb1, 0x38, 0xe7, 0xa7, 0xa4, 0x4a, 0x30,
0x2a, 0x30, 0x6a, 0x30, 0x07, 0x81, 0xd9, 0x42, 0x12, 0x5c, 0xec, 0xbe, 0xa9, 0xc5, 0xc5, 0x89,
0xe9, 0xa9, 0x12, 0x4c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 0xae, 0x90, 0x00, 0x17, 0x73, 0x48,
0x88, 0x8f, 0x04, 0xb3, 0x02, 0xa3, 0x06, 0x4b, 0x10, 0x88, 0x29, 0xa4, 0xc1, 0xc5, 0xe2, 0x92,
0x58, 0x92, 0x28, 0xc1, 0xa2, 0xc0, 0xa8, 0xc1, 0x6d, 0x24, 0xa2, 0x07, 0xb1, 0x4f, 0x0f, 0x66,
0x9f, 0x9e, 0x63, 0x5e, 0x65, 0x10, 0x58, 0x45, 0x12, 0x1b, 0x58, 0xcc, 0x18, 0x10, 0x00, 0x00,
0xff, 0xff, 0x7a, 0x92, 0x16, 0x71, 0xa5, 0x00, 0x00, 0x00,
}

@ -0,0 +1,14 @@
// use under command to generate pb.pb.go
// protoc --proto_path=.:$GOPATH/src/github.com/gogo/protobuf --gogo_out=Mgoogle/protobuf/any.proto=github.com/gogo/protobuf/types:. *.proto
syntax = "proto3";
package render;
import "google/protobuf/any.proto";
import "github.com/gogo/protobuf/gogoproto/gogo.proto";
message PB {
int64 Code = 1;
string Message = 2;
uint64 TTL = 3;
google.protobuf.Any Data = 4;
}

@ -0,0 +1,40 @@
package render
import (
"fmt"
"io"
"net/http"
"github.com/pkg/errors"
)
var plainContentType = []string{"text/plain; charset=utf-8"}
// String common string struct.
type String struct {
Format string
Data []interface{}
}
// Render (String) writes data with custom ContentType.
func (r String) Render(w http.ResponseWriter) error {
return writeString(w, r.Format, r.Data)
}
// WriteContentType writes string with text/plain ContentType.
func (r String) WriteContentType(w http.ResponseWriter) {
writeContentType(w, plainContentType)
}
func writeString(w http.ResponseWriter, format string, data []interface{}) (err error) {
writeContentType(w, plainContentType)
if len(data) > 0 {
_, err = fmt.Fprintf(w, format, data...)
} else {
_, err = io.WriteString(w, format)
}
if err != nil {
err = errors.WithStack(err)
}
return
}

@ -0,0 +1,31 @@
package render
import (
"encoding/xml"
"net/http"
"github.com/pkg/errors"
)
// XML common xml struct.
type XML struct {
Code int
Message string
Data interface{}
}
var xmlContentType = []string{"application/xml; charset=utf-8"}
// Render (XML) writes data with xml ContentType.
func (r XML) Render(w http.ResponseWriter) (err error) {
r.WriteContentType(w)
if err = xml.NewEncoder(w).Encode(r.Data); err != nil {
err = errors.WithStack(err)
}
return
}
// WriteContentType write xml ContentType.
func (r XML) WriteContentType(w http.ResponseWriter) {
writeContentType(w, xmlContentType)
}

@ -0,0 +1,166 @@
package blademaster
import (
"regexp"
)
// IRouter http router framework interface.
type IRouter interface {
IRoutes
Group(string, ...HandlerFunc) *RouterGroup
}
// IRoutes http router interface.
type IRoutes interface {
UseFunc(...HandlerFunc) IRoutes
Use(...Handler) IRoutes
Handle(string, string, ...HandlerFunc) IRoutes
HEAD(string, ...HandlerFunc) IRoutes
GET(string, ...HandlerFunc) IRoutes
POST(string, ...HandlerFunc) IRoutes
PUT(string, ...HandlerFunc) IRoutes
DELETE(string, ...HandlerFunc) IRoutes
}
// RouterGroup is used internally to configure router, a RouterGroup is associated with a prefix
// and an array of handlers (middleware).
type RouterGroup struct {
Handlers []HandlerFunc
basePath string
engine *Engine
root bool
baseConfig *MethodConfig
}
var _ IRouter = &RouterGroup{}
// Use adds middleware to the group, see example code in doc.
func (group *RouterGroup) Use(middleware ...Handler) IRoutes {
for _, m := range middleware {
group.Handlers = append(group.Handlers, m.ServeHTTP)
}
return group.returnObj()
}
// UseFunc adds middleware to the group, see example code in doc.
func (group *RouterGroup) UseFunc(middleware ...HandlerFunc) IRoutes {
group.Handlers = append(group.Handlers, middleware...)
return group.returnObj()
}
// Group creates a new router group. You should add all the routes that have common middlwares or the same path prefix.
// For example, all the routes that use a common middlware for authorization could be grouped.
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup {
return &RouterGroup{
Handlers: group.combineHandlers(handlers),
basePath: group.calculateAbsolutePath(relativePath),
engine: group.engine,
root: false,
}
}
// SetMethodConfig is used to set config on specified method
func (group *RouterGroup) SetMethodConfig(config *MethodConfig) *RouterGroup {
group.baseConfig = config
return group
}
// BasePath router group base path.
func (group *RouterGroup) BasePath() string {
return group.basePath
}
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes {
absolutePath := group.calculateAbsolutePath(relativePath)
injections := group.injections(relativePath)
handlers = group.combineHandlers(injections, handlers)
group.engine.addRoute(httpMethod, absolutePath, handlers...)
if group.baseConfig != nil {
group.engine.SetMethodConfig(absolutePath, group.baseConfig)
}
return group.returnObj()
}
// Handle registers a new request handle and middleware with the given path and method.
// The last handler should be the real handler, the other ones should be middleware that can and should be shared among different routes.
// See the example code in doc.
//
// For HEAD, GET, POST, PUT, and DELETE requests the respective shortcut
// functions can be used.
//
// This function is intended for bulk loading and to allow the usage of less
// frequently used, non-standardized or custom methods (e.g. for internal
// communication with a proxy).
func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes {
if matches, err := regexp.MatchString("^[A-Z]+$", httpMethod); !matches || err != nil {
panic("http method " + httpMethod + " is not valid")
}
return group.handle(httpMethod, relativePath, handlers...)
}
// HEAD is a shortcut for router.Handle("HEAD", path, handle).
func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle("HEAD", relativePath, handlers...)
}
// GET is a shortcut for router.Handle("GET", path, handle).
func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle("GET", relativePath, handlers...)
}
// POST is a shortcut for router.Handle("POST", path, handle).
func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle("POST", relativePath, handlers...)
}
// PUT is a shortcut for router.Handle("PUT", path, handle).
func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle("PUT", relativePath, handlers...)
}
// DELETE is a shortcut for router.Handle("DELETE", path, handle).
func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle("DELETE", relativePath, handlers...)
}
func (group *RouterGroup) combineHandlers(handlerGroups ...[]HandlerFunc) []HandlerFunc {
finalSize := len(group.Handlers)
for _, handlers := range handlerGroups {
finalSize += len(handlers)
}
if finalSize >= int(_abortIndex) {
panic("too many handlers")
}
mergedHandlers := make([]HandlerFunc, finalSize)
copy(mergedHandlers, group.Handlers)
position := len(group.Handlers)
for _, handlers := range handlerGroups {
copy(mergedHandlers[position:], handlers)
position += len(handlers)
}
return mergedHandlers
}
func (group *RouterGroup) calculateAbsolutePath(relativePath string) string {
return joinPaths(group.basePath, relativePath)
}
func (group *RouterGroup) returnObj() IRoutes {
if group.root {
return group.engine
}
return group
}
// injections is
func (group *RouterGroup) injections(relativePath string) []HandlerFunc {
absPath := group.calculateAbsolutePath(relativePath)
for _, injection := range group.engine.injections {
if !injection.pattern.MatchString(absPath) {
continue
}
return injection.handlers
}
return nil
}

@ -0,0 +1,405 @@
package blademaster
import (
"context"
"flag"
"fmt"
"net"
"net/http"
"os"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/bilibili/kratos/pkg/conf/dsn"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/ip"
"github.com/bilibili/kratos/pkg/net/metadata"
"github.com/bilibili/kratos/pkg/stat"
xtime "github.com/bilibili/kratos/pkg/time"
"github.com/pkg/errors"
)
const (
defaultMaxMemory = 32 << 20 // 32 MB
)
var (
_ IRouter = &Engine{}
stats = stat.HTTPServer
_httpDSN string
)
func init() {
addFlag(flag.CommandLine)
}
func addFlag(fs *flag.FlagSet) {
v := os.Getenv("HTTP")
if v == "" {
v = "tcp://0.0.0.0:8000/?timeout=1s"
}
fs.StringVar(&_httpDSN, "http", v, "listen http dsn, or use HTTP env variable.")
}
func parseDSN(rawdsn string) *ServerConfig {
conf := new(ServerConfig)
d, err := dsn.Parse(rawdsn)
if err != nil {
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn))
}
if _, err = d.Bind(conf); err != nil {
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn))
}
return conf
}
// Handler responds to an HTTP request.
type Handler interface {
ServeHTTP(c *Context)
}
// HandlerFunc http request handler function.
type HandlerFunc func(*Context)
// ServeHTTP calls f(ctx).
func (f HandlerFunc) ServeHTTP(c *Context) {
f(c)
}
// ServerConfig is the bm server config model
type ServerConfig struct {
Network string `dsn:"network"`
Addr string `dsn:"address"`
Timeout xtime.Duration `dsn:"query.timeout"`
ReadTimeout xtime.Duration `dsn:"query.readTimeout"`
WriteTimeout xtime.Duration `dsn:"query.writeTimeout"`
}
// MethodConfig is
type MethodConfig struct {
Timeout xtime.Duration
}
// Start listen and serve bm engine by given DSN.
func (engine *Engine) Start() error {
conf := engine.conf
l, err := net.Listen(conf.Network, conf.Addr)
if err != nil {
errors.Wrapf(err, "blademaster: listen tcp: %s", conf.Addr)
return err
}
log.Info("blademaster: start http listen addr: %s", conf.Addr)
server := &http.Server{
ReadTimeout: time.Duration(conf.ReadTimeout),
WriteTimeout: time.Duration(conf.WriteTimeout),
}
go func() {
if err := engine.RunServer(server, l); err != nil {
if errors.Cause(err) == http.ErrServerClosed {
log.Info("blademaster: server closed")
return
}
panic(errors.Wrapf(err, "blademaster: engine.ListenServer(%+v, %+v)", server, l))
}
}()
return nil
}
// Engine is the framework's instance, it contains the muxer, middleware and configuration settings.
// Create an instance of Engine, by using New() or Default()
type Engine struct {
RouterGroup
lock sync.RWMutex
conf *ServerConfig
address string
mux *http.ServeMux // http mux router
server atomic.Value // store *http.Server
metastore map[string]map[string]interface{} // metastore is the path as key and the metadata of this path as value, it export via /metadata
pcLock sync.RWMutex
methodConfigs map[string]*MethodConfig
injections []injection
}
type injection struct {
pattern *regexp.Regexp
handlers []HandlerFunc
}
// NewServer returns a new blank Engine instance without any middleware attached.
func NewServer(conf *ServerConfig) *Engine {
if conf == nil {
if !flag.Parsed() {
fmt.Fprint(os.Stderr, "[blademaster] please call flag.Parse() before Init warden server, some configure may not effect.\n")
}
conf = parseDSN(_httpDSN)
}
engine := &Engine{
RouterGroup: RouterGroup{
Handlers: nil,
basePath: "/",
root: true,
},
address: ip.InternalIP(),
mux: http.NewServeMux(),
metastore: make(map[string]map[string]interface{}),
methodConfigs: make(map[string]*MethodConfig),
}
if err := engine.SetConfig(conf); err != nil {
panic(err)
}
engine.RouterGroup.engine = engine
// NOTE add prometheus monitor location
engine.addRoute("GET", "/metrics", monitor())
engine.addRoute("GET", "/metadata", engine.metadata())
startPerf()
return engine
}
// SetMethodConfig is used to set config on specified path
func (engine *Engine) SetMethodConfig(path string, mc *MethodConfig) {
engine.pcLock.Lock()
engine.methodConfigs[path] = mc
engine.pcLock.Unlock()
}
// DefaultServer returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
func DefaultServer(conf *ServerConfig) *Engine {
engine := NewServer(conf)
engine.Use(Recovery(), Trace(), Logger())
return engine
}
func (engine *Engine) addRoute(method, path string, handlers ...HandlerFunc) {
if path[0] != '/' {
panic("blademaster: path must begin with '/'")
}
if method == "" {
panic("blademaster: HTTP method can not be empty")
}
if len(handlers) == 0 {
panic("blademaster: there must be at least one handler")
}
if _, ok := engine.metastore[path]; !ok {
engine.metastore[path] = make(map[string]interface{})
}
engine.metastore[path]["method"] = method
engine.mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) {
c := &Context{
Context: nil,
engine: engine,
index: -1,
handlers: nil,
Keys: nil,
method: "",
Error: nil,
}
c.Request = req
c.Writer = w
c.handlers = handlers
c.method = method
engine.handleContext(c)
})
}
// SetConfig is used to set the engine configuration.
// Only the valid config will be loaded.
func (engine *Engine) SetConfig(conf *ServerConfig) (err error) {
if conf.Timeout <= 0 {
return errors.New("blademaster: config timeout must greater than 0")
}
if conf.Network == "" {
conf.Network = "tcp"
}
engine.lock.Lock()
engine.conf = conf
engine.lock.Unlock()
return
}
func (engine *Engine) methodConfig(path string) *MethodConfig {
engine.pcLock.RLock()
mc := engine.methodConfigs[path]
engine.pcLock.RUnlock()
return mc
}
func (engine *Engine) handleContext(c *Context) {
var cancel func()
req := c.Request
ctype := req.Header.Get("Content-Type")
switch {
case strings.Contains(ctype, "multipart/form-data"):
req.ParseMultipartForm(defaultMaxMemory)
default:
req.ParseForm()
}
// get derived timeout from http request header,
// compare with the engine configured,
// and use the minimum one
engine.lock.RLock()
tm := time.Duration(engine.conf.Timeout)
engine.lock.RUnlock()
// the method config is preferred
if pc := engine.methodConfig(c.Request.URL.Path); pc != nil {
tm = time.Duration(pc.Timeout)
}
if ctm := timeout(req); ctm > 0 && tm > ctm {
tm = ctm
}
md := metadata.MD{
metadata.Color: color(req),
metadata.RemoteIP: remoteIP(req),
metadata.RemotePort: remotePort(req),
metadata.Caller: caller(req),
metadata.Mirror: mirror(req),
}
ctx := metadata.NewContext(context.Background(), md)
if tm > 0 {
c.Context, cancel = context.WithTimeout(ctx, tm)
} else {
c.Context, cancel = context.WithCancel(ctx)
}
defer cancel()
c.Next()
}
// Router return a http.Handler for using http.ListenAndServe() directly.
func (engine *Engine) Router() http.Handler {
return engine.mux
}
// Server is used to load stored http server.
func (engine *Engine) Server() *http.Server {
s, ok := engine.server.Load().(*http.Server)
if !ok {
return nil
}
return s
}
// Shutdown the http server without interrupting active connections.
func (engine *Engine) Shutdown(ctx context.Context) error {
server := engine.Server()
if server == nil {
return errors.New("blademaster: no server")
}
return errors.WithStack(server.Shutdown(ctx))
}
// UseFunc attachs a global middleware to the router. ie. the middleware attached though UseFunc() will be
// included in the handlers chain for every single request. Even 404, 405, static files...
// For example, this is the right place for a logger or error management middleware.
func (engine *Engine) UseFunc(middleware ...HandlerFunc) IRoutes {
engine.RouterGroup.UseFunc(middleware...)
return engine
}
// Use attachs a global middleware to the router. ie. the middleware attached though Use() will be
// included in the handlers chain for every single request. Even 404, 405, static files...
// For example, this is the right place for a logger or error management middleware.
func (engine *Engine) Use(middleware ...Handler) IRoutes {
engine.RouterGroup.Use(middleware...)
return engine
}
// Ping is used to set the general HTTP ping handler.
func (engine *Engine) Ping(handler HandlerFunc) {
engine.GET("/ping", handler)
}
// Register is used to export metadata to discovery.
func (engine *Engine) Register(handler HandlerFunc) {
engine.GET("/register", handler)
}
// Run attaches the router to a http.Server and starts listening and serving HTTP requests.
// It is a shortcut for http.ListenAndServe(addr, router)
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) Run(addr ...string) (err error) {
address := resolveAddress(addr)
server := &http.Server{
Addr: address,
Handler: engine.mux,
}
engine.server.Store(server)
if err = server.ListenAndServe(); err != nil {
err = errors.Wrapf(err, "addrs: %v", addr)
}
return
}
// RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests.
// It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router)
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) {
server := &http.Server{
Addr: addr,
Handler: engine.mux,
}
engine.server.Store(server)
if err = server.ListenAndServeTLS(certFile, keyFile); err != nil {
err = errors.Wrapf(err, "tls: %s/%s:%s", addr, certFile, keyFile)
}
return
}
// RunUnix attaches the router to a http.Server and starts listening and serving HTTP requests
// through the specified unix socket (ie. a file).
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) RunUnix(file string) (err error) {
os.Remove(file)
listener, err := net.Listen("unix", file)
if err != nil {
err = errors.Wrapf(err, "unix: %s", file)
return
}
defer listener.Close()
server := &http.Server{
Handler: engine.mux,
}
engine.server.Store(server)
if err = server.Serve(listener); err != nil {
err = errors.Wrapf(err, "unix: %s", file)
}
return
}
// RunServer will serve and start listening HTTP requests by given server and listener.
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) RunServer(server *http.Server, l net.Listener) (err error) {
server.Handler = engine.mux
engine.server.Store(server)
if err = server.Serve(l); err != nil {
err = errors.Wrapf(err, "listen server: %+v/%+v", server, l)
return
}
return
}
func (engine *Engine) metadata() HandlerFunc {
return func(c *Context) {
c.JSON(engine.metastore, nil)
}
}
// Inject is
func (engine *Engine) Inject(pattern string, handlers ...HandlerFunc) {
engine.injections = append(engine.injections, injection{
pattern: regexp.MustCompile(pattern),
handlers: handlers,
})
}

@ -4,12 +4,42 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
"strconv"
"github.com/bilibili/kratos/pkg/net/metadata"
"github.com/bilibili/kratos/pkg/net/trace" "github.com/bilibili/kratos/pkg/net/trace"
) )
const _defaultComponentName = "net/http" const _defaultComponentName = "net/http"
// Trace is trace middleware
func Trace() HandlerFunc {
return func(c *Context) {
// handle http request
// get derived trace from http request header
t, err := trace.Extract(trace.HTTPFormat, c.Request.Header)
if err != nil {
var opts []trace.Option
if ok, _ := strconv.ParseBool(trace.KratosTraceDebug); ok {
opts = append(opts, trace.EnableDebug())
}
t = trace.New(c.Request.URL.Path, opts...)
}
t.SetTitle(c.Request.URL.Path)
t.SetTag(trace.String(trace.TagComponent, _defaultComponentName))
t.SetTag(trace.String(trace.TagHTTPMethod, c.Request.Method))
t.SetTag(trace.String(trace.TagHTTPURL, c.Request.URL.String()))
t.SetTag(trace.String(trace.TagSpanKind, "server"))
// business tag
t.SetTag(trace.String("caller", metadata.String(c.Context, metadata.Caller)))
// export trace id to user.
c.Writer.Header().Set(trace.KratosTraceID, t.TraceID())
c.Context = trace.NewContext(c.Context, t)
c.Next()
t.Finish(&c.Error)
}
}
type closeTracker struct { type closeTracker struct {
io.ReadCloser io.ReadCloser
tr trace.Trace tr trace.Trace

@ -0,0 +1,40 @@
package blademaster
import (
"os"
"path"
)
func lastChar(str string) uint8 {
if str == "" {
panic("The length of the string can't be 0")
}
return str[len(str)-1]
}
func joinPaths(absolutePath, relativePath string) string {
if relativePath == "" {
return absolutePath
}
finalPath := path.Join(absolutePath, relativePath)
appendSlash := lastChar(relativePath) == '/' && lastChar(finalPath) != '/'
if appendSlash {
return finalPath + "/"
}
return finalPath
}
func resolveAddress(addr []string) string {
switch len(addr) {
case 0:
if port := os.Getenv("PORT"); port != "" {
return ":" + port
}
return ":8080"
case 1:
return addr[0]
default:
panic("too much parameters")
}
}

@ -0,0 +1,84 @@
package mocktrace
import (
"github.com/bilibili/kratos/pkg/net/trace"
)
// MockTrace .
type MockTrace struct {
Spans []*MockSpan
}
// New .
func (m *MockTrace) New(operationName string, opts ...trace.Option) trace.Trace {
span := &MockSpan{OperationName: operationName, MockTrace: m}
m.Spans = append(m.Spans, span)
return span
}
// Inject .
func (m *MockTrace) Inject(t trace.Trace, format interface{}, carrier interface{}) error {
return nil
}
// Extract .
func (m *MockTrace) Extract(format interface{}, carrier interface{}) (trace.Trace, error) {
return &MockSpan{}, nil
}
// MockSpan .
type MockSpan struct {
*MockTrace
OperationName string
FinishErr error
Finished bool
Tags []trace.Tag
Logs []trace.LogField
}
// Fork .
func (m *MockSpan) Fork(serviceName string, operationName string) trace.Trace {
span := &MockSpan{OperationName: operationName, MockTrace: m.MockTrace}
m.Spans = append(m.Spans, span)
return span
}
// Follow .
func (m *MockSpan) Follow(serviceName string, operationName string) trace.Trace {
span := &MockSpan{OperationName: operationName, MockTrace: m.MockTrace}
m.Spans = append(m.Spans, span)
return span
}
// Finish .
func (m *MockSpan) Finish(perr *error) {
if perr != nil {
m.FinishErr = *perr
}
m.Finished = true
}
// SetTag .
func (m *MockSpan) SetTag(tags ...trace.Tag) trace.Trace {
m.Tags = append(m.Tags, tags...)
return m
}
// SetLog .
func (m *MockSpan) SetLog(logs ...trace.LogField) trace.Trace {
m.Logs = append(m.Logs, logs...)
return m
}
// Visit .
func (m *MockSpan) Visit(fn func(k, v string)) {}
// SetTitle .
func (m *MockSpan) SetTitle(title string) {
m.OperationName = title
}
// TraceID .
func (m *MockSpan) TraceID() string {
return ""
}

@ -0,0 +1,24 @@
// Package mocktrace this ut just make ci happay.
package mocktrace
import (
"fmt"
"testing"
)
func TestMockTrace(t *testing.T) {
mocktrace := &MockTrace{}
mocktrace.Inject(nil, nil, nil)
mocktrace.Extract(nil, nil)
root := mocktrace.New("test")
root.Fork("", "")
root.Follow("", "")
root.Finish(nil)
err := fmt.Errorf("test")
root.Finish(&err)
root.SetTag()
root.SetLog()
root.Visit(func(k, v string) {})
root.SetTitle("")
}

@ -20,6 +20,8 @@ func (n nooptracer) Extract(format interface{}, carrier interface{}) (Trace, err
type noopspan struct{} type noopspan struct{}
func (n noopspan) TraceID() string { return "" }
func (n noopspan) Fork(string, string) Trace { func (n noopspan) Fork(string, string) Trace {
return noopspan{} return noopspan{}
} }

@ -10,7 +10,7 @@ const (
slotLength = 2048 slotLength = 2048
) )
var ignoreds = []string{"/metrics", "/monitor/ping"} var ignoreds = []string{"/metrics", "/ping"} // NOTE: add YOUR URL PATH that want ignore
func init() { func init() {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())

@ -26,6 +26,10 @@ type span struct {
childs int childs int
} }
func (s *span) TraceID() string {
return s.context.String()
}
func (s *span) Fork(serviceName, operationName string) Trace { func (s *span) Fork(serviceName, operationName string) Trace {
if s.childs > _maxChilds { if s.childs > _maxChilds {
// if child span more than max childs set return noopspan // if child span more than max childs set return noopspan

@ -55,6 +55,8 @@ func Close() error {
// Trace trace common interface. // Trace trace common interface.
type Trace interface { type Trace interface {
// return current trace id.
TraceID() string
// Fork fork a trace with client trace. // Fork fork a trace with client trace.
Fork(serviceName, operationName string) Trace Fork(serviceName, operationName string) Trace

@ -8,7 +8,6 @@ import (
"github.com/bilibili/kratos/pkg/conf/env" "github.com/bilibili/kratos/pkg/conf/env"
"github.com/bilibili/kratos/pkg/net/ip" "github.com/bilibili/kratos/pkg/net/ip"
"github.com/bilibili/kratos/pkg/net/metadata"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -49,10 +48,6 @@ var _ctxkey ctxKey = "kratos/pkg/net/trace.trace"
// FromContext returns the trace bound to the context, if any. // FromContext returns the trace bound to the context, if any.
func FromContext(ctx context.Context) (t Trace, ok bool) { func FromContext(ctx context.Context) (t Trace, ok bool) {
if v := metadata.Value(ctx, metadata.Trace); v != nil {
t, ok = v.(Trace)
return
}
t, ok = ctx.Value(_ctxkey).(Trace) t, ok = ctx.Value(_ctxkey).(Trace)
return return
} }
@ -60,9 +55,5 @@ func FromContext(ctx context.Context) (t Trace, ok bool) {
// NewContext new a trace context. // NewContext new a trace context.
// NOTE: This method is not thread safe. // NOTE: This method is not thread safe.
func NewContext(ctx context.Context, t Trace) context.Context { func NewContext(ctx context.Context, t Trace) context.Context {
if md, ok := metadata.FromContext(ctx); ok {
md[metadata.Trace] = t
return ctx
}
return context.WithValue(ctx, _ctxkey, t) return context.WithValue(ctx, _ctxkey, t)
} }

@ -1,11 +1,11 @@
# kratos # kratos
## 项目简介 ## 项目简介
Kratos 工具 kratos 工具
## 安装 ## 安装
`go get -u github.com/bilibili/Kratos/tool/kratos` `go get -u github.com/bilibili/kratos/tool/kratos`
## 使用说明 ## 使用说明

Loading…
Cancel
Save