add database pkg (#8)

* add database pkg
pull/13/head
Sam 6 years ago committed by Felix Hao
parent 7c6e0ea7ba
commit 1efe0a084e
  1. 10
      go.mod
  2. 44
      pkg/database/hbase/README.md
  3. 23
      pkg/database/hbase/config.go
  4. 297
      pkg/database/hbase/hbase.go
  5. 48
      pkg/database/hbase/metrics.go
  6. 24
      pkg/database/hbase/slowlog.go
  7. 40
      pkg/database/hbase/trace.go
  8. 12
      pkg/database/sql/README.md
  9. 40
      pkg/database/sql/mysql.go
  10. 678
      pkg/database/sql/sql.go
  11. 17
      pkg/database/tidb/README.md
  12. 58
      pkg/database/tidb/discovery.go
  13. 82
      pkg/database/tidb/node_proc.go
  14. 739
      pkg/database/tidb/sql.go
  15. 38
      pkg/database/tidb/tidb.go

@ -2,18 +2,28 @@ module github.com/bilibili/Kratos
require (
github.com/BurntSushi/toml v0.3.1
github.com/aristanetworks/goarista v0.0.0-20190325233358-a123909ec740 // indirect
github.com/cznic/b v0.0.0-20181122101859-a26611c4d92d // indirect
github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 // indirect
github.com/cznic/strutil v0.0.0-20181122101858-275e90344537 // indirect
github.com/fatih/color v1.7.0
github.com/fsnotify/fsnotify v1.4.7
github.com/go-playground/locales v0.12.1 // indirect
github.com/go-playground/universal-translator v0.16.0 // indirect
github.com/go-sql-driver/mysql v1.4.1
github.com/gogo/protobuf v1.2.0
github.com/golang/protobuf v1.2.0
github.com/kr/pty v1.1.4
github.com/leodido/go-urn v1.1.0 // indirect
github.com/pkg/errors v0.8.1
github.com/prometheus/client_golang v0.9.2
github.com/remyoudompheng/bigfft v0.0.0-20190321074620-2f0d2b0e0001 // indirect
github.com/samuel/go-zookeeper v0.0.0-20180130194729-c4fab1ac1bec // indirect
github.com/sirupsen/logrus v1.4.1 // indirect
github.com/stretchr/testify v1.3.0
github.com/tsuna/gohbase v0.0.0-20190201102810-d3184c1526df
github.com/urfave/cli v1.20.0
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 // indirect
google.golang.org/grpc v1.18.0
gopkg.in/AlecAivazis/survey.v1 v1.8.2
gopkg.in/go-playground/assert.v1 v1.2.1 // indirect

@ -0,0 +1,44 @@
### database/hbase
### 项目简介
Hbase Client,进行封装加入了链路追踪和统计。
### usage
```go
package main
import (
"context"
"fmt"
"github.com/bilibili/Kratos/pkg/database/hbase"
)
func main() {
config := &hbase.Config{Zookeeper: &hbase.ZKConfig{Addrs: []string{"localhost"}}}
client := hbase.NewClient(config)
values := map[string]map[string][]byte{"name": {"firstname": []byte("hello"), "lastname": []byte("world")}}
ctx := context.Background()
_, err := client.PutStr(ctx, "user", "user1", values)
if err != nil {
panic(err)
}
result, err := client.GetStr(ctx, "user", "user1")
if err != nil {
panic(err)
}
fmt.Printf("%v", result)
}
```
##### 编译环境
> 请只用golang v1.8.x以上版本编译执行。
##### 依赖包
> 1.[gohbase](https://github.com/tsuna/gohbase)

@ -0,0 +1,23 @@
package hbase
import (
xtime "github.com/bilibili/Kratos/pkg/time"
)
// ZKConfig Server&Client settings.
type ZKConfig struct {
Root string
Addrs []string
Timeout xtime.Duration
}
// Config hbase config
type Config struct {
Zookeeper *ZKConfig
RPCQueueSize int
FlushInterval xtime.Duration
EffectiveUser string
RegionLookupTimeout xtime.Duration
RegionReadTimeout xtime.Duration
TestRowKey string
}

@ -0,0 +1,297 @@
package hbase
import (
"context"
"io"
"strings"
"time"
"github.com/tsuna/gohbase"
"github.com/tsuna/gohbase/hrpc"
"github.com/bilibili/Kratos/pkg/log"
)
// HookFunc hook function call before every method and hook return function will call after finish.
type HookFunc func(ctx context.Context, call hrpc.Call, customName string) func(err error)
// Client hbase client.
type Client struct {
hc gohbase.Client
addr string
config *Config
hooks []HookFunc
}
// AddHook add hook function.
func (c *Client) AddHook(hookFn HookFunc) {
c.hooks = append(c.hooks, hookFn)
}
func (c *Client) invokeHook(ctx context.Context, call hrpc.Call, customName string) func(error) {
finishHooks := make([]func(error), 0, len(c.hooks))
for _, fn := range c.hooks {
finishHooks = append(finishHooks, fn(ctx, call, customName))
}
return func(err error) {
for _, fn := range finishHooks {
fn(err)
}
}
}
// NewClient new a hbase client.
func NewClient(config *Config, options ...gohbase.Option) *Client {
rawcli := NewRawClient(config, options...)
rawcli.AddHook(NewSlowLogHook(250 * time.Millisecond))
rawcli.AddHook(MetricsHook(nil))
rawcli.AddHook(TraceHook("database/hbase", strings.Join(config.Zookeeper.Addrs, ",")))
return rawcli
}
// NewRawClient new a hbase client without prometheus metrics and dapper trace hook.
func NewRawClient(config *Config, options ...gohbase.Option) *Client {
zkquorum := strings.Join(config.Zookeeper.Addrs, ",")
if config.Zookeeper.Root != "" {
options = append(options, gohbase.ZookeeperRoot(config.Zookeeper.Root))
}
if config.Zookeeper.Timeout != 0 {
options = append(options, gohbase.ZookeeperTimeout(time.Duration(config.Zookeeper.Timeout)))
}
if config.RPCQueueSize != 0 {
log.Warn("RPCQueueSize configuration be ignored")
}
// force RpcQueueSize = 1, don't change it !!! it has reason (゜-゜)つロ
options = append(options, gohbase.RpcQueueSize(1))
if config.FlushInterval != 0 {
options = append(options, gohbase.FlushInterval(time.Duration(config.FlushInterval)))
}
if config.EffectiveUser != "" {
options = append(options, gohbase.EffectiveUser(config.EffectiveUser))
}
if config.RegionLookupTimeout != 0 {
options = append(options, gohbase.RegionLookupTimeout(time.Duration(config.RegionLookupTimeout)))
}
if config.RegionReadTimeout != 0 {
options = append(options, gohbase.RegionReadTimeout(time.Duration(config.RegionReadTimeout)))
}
hc := gohbase.NewClient(zkquorum, options...)
return &Client{
hc: hc,
addr: zkquorum,
config: config,
}
}
// ScanAll do scan command and return all result
// NOTE: if err != nil the results is safe for range operate even not result found
func (c *Client) ScanAll(ctx context.Context, table []byte, options ...func(hrpc.Call) error) (results []*hrpc.Result, err error) {
cursor, err := c.Scan(ctx, table, options...)
if err != nil {
return nil, err
}
for {
result, err := cursor.Next()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
results = append(results, result)
}
return results, nil
}
type scanTrace struct {
hrpc.Scanner
finishHook func(error)
}
func (s *scanTrace) Next() (*hrpc.Result, error) {
result, err := s.Scanner.Next()
if err != nil {
s.finishHook(err)
}
return result, err
}
func (s *scanTrace) Close() error {
err := s.Scanner.Close()
s.finishHook(err)
return err
}
// Scan do a scan command.
func (c *Client) Scan(ctx context.Context, table []byte, options ...func(hrpc.Call) error) (scanner hrpc.Scanner, err error) {
var scan *hrpc.Scan
scan, err = hrpc.NewScan(ctx, table, options...)
if err != nil {
return nil, err
}
st := &scanTrace{}
st.finishHook = c.invokeHook(ctx, scan, "Scan")
st.Scanner = c.hc.Scan(scan)
return st, nil
}
// ScanStr scan string
func (c *Client) ScanStr(ctx context.Context, table string, options ...func(hrpc.Call) error) (hrpc.Scanner, error) {
return c.Scan(ctx, []byte(table), options...)
}
// ScanStrAll scan string
// NOTE: if err != nil the results is safe for range operate even not result found
func (c *Client) ScanStrAll(ctx context.Context, table string, options ...func(hrpc.Call) error) ([]*hrpc.Result, error) {
return c.ScanAll(ctx, []byte(table), options...)
}
// ScanRange get a scanner for the given table and key range.
// The range is half-open, i.e. [startRow; stopRow[ -- stopRow is not
// included in the range.
func (c *Client) ScanRange(ctx context.Context, table, startRow, stopRow []byte, options ...func(hrpc.Call) error) (scanner hrpc.Scanner, err error) {
var scan *hrpc.Scan
scan, err = hrpc.NewScanRange(ctx, table, startRow, stopRow, options...)
if err != nil {
return nil, err
}
st := &scanTrace{}
st.finishHook = c.invokeHook(ctx, scan, "ScanRange")
st.Scanner = c.hc.Scan(scan)
return st, nil
}
// ScanRangeStr get a scanner for the given table and key range.
// The range is half-open, i.e. [startRow; stopRow[ -- stopRow is not
// included in the range.
func (c *Client) ScanRangeStr(ctx context.Context, table, startRow, stopRow string, options ...func(hrpc.Call) error) (hrpc.Scanner, error) {
return c.ScanRange(ctx, []byte(table), []byte(startRow), []byte(stopRow), options...)
}
// Get get result for the given table and row key.
// NOTE: if err != nil then result != nil, if result not exists result.Cells length is 0
func (c *Client) Get(ctx context.Context, table, key []byte, options ...func(hrpc.Call) error) (result *hrpc.Result, err error) {
var get *hrpc.Get
get, err = hrpc.NewGet(ctx, table, key, options...)
if err != nil {
return nil, err
}
finishHook := c.invokeHook(ctx, get, "GET")
result, err = c.hc.Get(get)
finishHook(err)
return
}
// GetStr do a get command.
// NOTE: if err != nil then result != nil, if result not exists result.Cells length is 0
func (c *Client) GetStr(ctx context.Context, table, key string, options ...func(hrpc.Call) error) (result *hrpc.Result, err error) {
return c.Get(ctx, []byte(table), []byte(key), options...)
}
// PutStr insert the given family-column-values in the given row key of the given table.
func (c *Client) PutStr(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) {
put, err := hrpc.NewPutStr(ctx, table, key, values, options...)
if err != nil {
return nil, err
}
finishHook := c.invokeHook(ctx, put, "PUT")
result, err := c.hc.Put(put)
finishHook(err)
return result, err
}
// Delete is used to perform Delete operations on a single row.
// To delete entire row, values should be nil.
//
// To delete specific families, qualifiers map should be nil:
// map[string]map[string][]byte{
// "cf1": nil,
// "cf2": nil,
// }
//
// To delete specific qualifiers:
// map[string]map[string][]byte{
// "cf": map[string][]byte{
// "q1": nil,
// "q2": nil,
// },
// }
//
// To delete all versions before and at a timestamp, pass hrpc.Timestamp() option.
// By default all versions will be removed.
//
// To delete only a specific version at a timestamp, pass hrpc.DeleteOneVersion() option
// along with a timestamp. For delete specific qualifiers request, if timestamp is not
// passed, only the latest version will be removed. For delete specific families request,
// the timestamp should be passed or it will have no effect as it's an expensive
// operation to perform.
func (c *Client) Delete(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) {
del, err := hrpc.NewDelStr(ctx, table, key, values, options...)
if err != nil {
return nil, err
}
finishHook := c.invokeHook(ctx, del, "Delete")
result, err := c.hc.Delete(del)
finishHook(err)
return result, err
}
// Append do a append command.
func (c *Client) Append(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) {
appd, err := hrpc.NewAppStr(ctx, table, key, values, options...)
if err != nil {
return nil, err
}
finishHook := c.invokeHook(ctx, appd, "Append")
result, err := c.hc.Append(appd)
finishHook(err)
return result, err
}
// Increment the given values in HBase under the given table and key.
func (c *Client) Increment(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (int64, error) {
increment, err := hrpc.NewIncStr(ctx, table, key, values, options...)
if err != nil {
return 0, err
}
finishHook := c.invokeHook(ctx, increment, "Increment")
result, err := c.hc.Increment(increment)
finishHook(err)
return result, err
}
// IncrementSingle increment the given value by amount in HBase under the given table, key, family and qualifier.
func (c *Client) IncrementSingle(ctx context.Context, table string, key string, family string, qualifier string, amount int64, options ...func(hrpc.Call) error) (int64, error) {
increment, err := hrpc.NewIncStrSingle(ctx, table, key, family, qualifier, amount, options...)
if err != nil {
return 0, err
}
finishHook := c.invokeHook(ctx, increment, "IncrementSingle")
result, err := c.hc.Increment(increment)
finishHook(err)
return result, err
}
// Ping ping.
func (c *Client) Ping(ctx context.Context) (err error) {
testRowKey := "test"
if c.config.TestRowKey != "" {
testRowKey = c.config.TestRowKey
}
values := map[string]map[string][]byte{"test": map[string][]byte{"test": []byte("test")}}
_, err = c.PutStr(ctx, "test", testRowKey, values)
return
}
// Close close client.
func (c *Client) Close() error {
c.hc.Close()
return nil
}

@ -0,0 +1,48 @@
package hbase
import (
"context"
"io"
"time"
"github.com/tsuna/gohbase"
"github.com/tsuna/gohbase/hrpc"
"github.com/bilibili/Kratos/pkg/stat"
)
func codeFromErr(err error) string {
code := "unknown_error"
switch err {
case gohbase.ErrClientClosed:
code = "client_closed"
case gohbase.ErrConnotFindRegion:
code = "connot_find_region"
case gohbase.TableNotFound:
code = "table_not_found"
case gohbase.ErrRegionUnavailable:
code = "region_unavailable"
}
return code
}
// MetricsHook if stats is nil use stat.DB as default.
func MetricsHook(stats stat.Stat) HookFunc {
if stats == nil {
stats = stat.DB
}
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) {
now := time.Now()
if customName == "" {
customName = call.Name()
}
method := "hbase:" + customName
return func(err error) {
durationMs := int64(time.Since(now) / time.Millisecond)
stats.Timing(method, durationMs)
if err != nil && err != io.EOF {
stats.Incr(method, codeFromErr(err))
}
}
}
}

@ -0,0 +1,24 @@
package hbase
import (
"context"
"time"
"github.com/tsuna/gohbase/hrpc"
"github.com/bilibili/Kratos/pkg/log"
)
// NewSlowLogHook log slow operation.
func NewSlowLogHook(threshold time.Duration) HookFunc {
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) {
start := time.Now()
return func(error) {
duration := time.Since(start)
if duration < threshold {
return
}
log.Warn("hbase slow log: %s %s %s time: %s", customName, call.Table(), call.Key(), duration)
}
}
}

@ -0,0 +1,40 @@
package hbase
import (
"context"
"io"
"github.com/tsuna/gohbase/hrpc"
"github.com/bilibili/Kratos/pkg/net/trace"
)
// TraceHook create new hbase trace hook.
func TraceHook(component, instance string) HookFunc {
var internalTags []trace.Tag
internalTags = append(internalTags, trace.TagString(trace.TagComponent, component))
internalTags = append(internalTags, trace.TagString(trace.TagDBInstance, instance))
internalTags = append(internalTags, trace.TagString(trace.TagPeerService, "hbase"))
internalTags = append(internalTags, trace.TagString(trace.TagSpanKind, "client"))
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) {
noop := func(error) {}
root, ok := trace.FromContext(ctx)
if !ok {
return noop
}
if customName == "" {
customName = call.Name()
}
span := root.Fork("", "Hbase:"+customName)
span.SetTag(internalTags...)
statement := string(call.Table()) + " " + string(call.Key())
span.SetTag(trace.TagString(trace.TagDBStatement, statement))
return func(err error) {
if err == io.EOF {
// reset error for trace.
err = nil
}
span.Finish(&err)
}
}
}

@ -0,0 +1,12 @@
#### database/sql
##### 项目简介
MySQL数据库驱动,进行封装加入了链路追踪和统计。
如果需要SQL级别的超时管理 可以在业务代码里面使用context.WithDeadline实现 推荐超时配置放到application.toml里面 方便热加载
##### 编译环境
> 请只用golang v1.8.x以上版本编译执行。
##### 依赖包
> 1.[Go-MySQL-Driver](https://github.com/go-sql-driver/mysql)

@ -0,0 +1,40 @@
package sql
import (
"github.com/bilibili/Kratos/pkg/log"
"github.com/bilibili/Kratos/pkg/net/netutil/breaker"
"github.com/bilibili/Kratos/pkg/stat"
"github.com/bilibili/Kratos/pkg/time"
// database driver
_ "github.com/go-sql-driver/mysql"
)
var stats = stat.DB
// Config mysql config.
type Config struct {
Addr string // for trace
DSN string // write data source name.
ReadDSN []string // read data source name.
Active int // pool
Idle int // pool
IdleTimeout time.Duration // connect max life time.
QueryTimeout time.Duration // query sql timeout
ExecTimeout time.Duration // execute sql timeout
TranTimeout time.Duration // transaction sql timeout
Breaker *breaker.Config // breaker
}
// NewMySQL new db and retry connection when has error.
func NewMySQL(c *Config) (db *DB) {
if c.QueryTimeout == 0 || c.ExecTimeout == 0 || c.TranTimeout == 0 {
panic("mysql must be set query/execute/transction timeout")
}
db, err := Open(c)
if err != nil {
log.Error("open mysql error(%v)", err)
panic(err)
}
return
}

@ -0,0 +1,678 @@
package sql
import (
"context"
"database/sql"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/bilibili/Kratos/pkg/ecode"
"github.com/bilibili/Kratos/pkg/log"
"github.com/bilibili/Kratos/pkg/net/netutil/breaker"
"github.com/bilibili/Kratos/pkg/net/trace"
"github.com/pkg/errors"
)
const (
_family = "sql_client"
_slowLogDuration = time.Millisecond * 250
)
var (
// ErrStmtNil prepared stmt error
ErrStmtNil = errors.New("sql: prepare failed and stmt nil")
// ErrNoMaster is returned by Master when call master multiple times.
ErrNoMaster = errors.New("sql: no master instance")
// ErrNoRows is returned by Scan when QueryRow doesn't return a row.
// In such a case, QueryRow returns a placeholder *Row value that defers
// this error until a Scan.
ErrNoRows = sql.ErrNoRows
// ErrTxDone transaction done.
ErrTxDone = sql.ErrTxDone
)
// DB database.
type DB struct {
write *conn
read []*conn
idx int64
master *DB
}
// conn database connection
type conn struct {
*sql.DB
breaker breaker.Breaker
conf *Config
}
// Tx transaction.
type Tx struct {
db *conn
tx *sql.Tx
t trace.Trace
c context.Context
cancel func()
}
// Row row.
type Row struct {
err error
*sql.Row
db *conn
query string
args []interface{}
t trace.Trace
cancel func()
}
// Scan copies the columns from the matched row into the values pointed at by dest.
func (r *Row) Scan(dest ...interface{}) (err error) {
defer slowLog(fmt.Sprintf("Scan query(%s) args(%+v)", r.query, r.args), time.Now())
if r.t != nil {
defer r.t.Finish(&err)
}
if r.err != nil {
err = r.err
} else if r.Row == nil {
err = ErrStmtNil
}
if err != nil {
return
}
err = r.Row.Scan(dest...)
if r.cancel != nil {
r.cancel()
}
r.db.onBreaker(&err)
if err != ErrNoRows {
err = errors.Wrapf(err, "query %s args %+v", r.query, r.args)
}
return
}
// Rows rows.
type Rows struct {
*sql.Rows
cancel func()
}
// Close closes the Rows, preventing further enumeration. If Next is called
// and returns false and there are no further result sets,
// the Rows are closed automatically and it will suffice to check the
// result of Err. Close is idempotent and does not affect the result of Err.
func (rs *Rows) Close() (err error) {
err = errors.WithStack(rs.Rows.Close())
if rs.cancel != nil {
rs.cancel()
}
return
}
// Stmt prepared stmt.
type Stmt struct {
db *conn
tx bool
query string
stmt atomic.Value
t trace.Trace
}
// Open opens a database specified by its database driver name and a
// driver-specific data source name, usually consisting of at least a database
// name and connection information.
func Open(c *Config) (*DB, error) {
db := new(DB)
d, err := connect(c, c.DSN)
if err != nil {
return nil, err
}
brkGroup := breaker.NewGroup(c.Breaker)
brk := brkGroup.Get(c.Addr)
w := &conn{DB: d, breaker: brk, conf: c}
rs := make([]*conn, 0, len(c.ReadDSN))
for _, rd := range c.ReadDSN {
d, err := connect(c, rd)
if err != nil {
return nil, err
}
brk := brkGroup.Get(parseDSNAddr(rd))
r := &conn{DB: d, breaker: brk, conf: c}
rs = append(rs, r)
}
db.write = w
db.read = rs
db.master = &DB{write: db.write}
return db, nil
}
func connect(c *Config, dataSourceName string) (*sql.DB, error) {
d, err := sql.Open("mysql", dataSourceName)
if err != nil {
err = errors.WithStack(err)
return nil, err
}
d.SetMaxOpenConns(c.Active)
d.SetMaxIdleConns(c.Idle)
d.SetConnMaxLifetime(time.Duration(c.IdleTimeout))
return d, nil
}
// Begin starts a transaction. The isolation level is dependent on the driver.
func (db *DB) Begin(c context.Context) (tx *Tx, err error) {
return db.write.begin(c)
}
// Exec executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
func (db *DB) Exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) {
return db.write.exec(c, query, args...)
}
// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the returned
// statement. The caller must call the statement's Close method when the
// statement is no longer needed.
func (db *DB) Prepare(query string) (*Stmt, error) {
return db.write.prepare(query)
}
// Prepared creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the returned
// statement. The caller must call the statement's Close method when the
// statement is no longer needed.
func (db *DB) Prepared(query string) (stmt *Stmt) {
return db.write.prepared(query)
}
// Query executes a query that returns rows, typically a SELECT. The args are
// for any placeholder parameters in the query.
func (db *DB) Query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) {
idx := db.readIndex()
for i := range db.read {
if rows, err = db.read[(idx+i)%len(db.read)].query(c, query, args...); !ecode.ServiceUnavailable.Equal(err) {
return
}
}
return db.write.query(c, query, args...)
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until Row's
// Scan method is called.
func (db *DB) QueryRow(c context.Context, query string, args ...interface{}) *Row {
idx := db.readIndex()
for i := range db.read {
if row := db.read[(idx+i)%len(db.read)].queryRow(c, query, args...); !ecode.ServiceUnavailable.Equal(row.err) {
return row
}
}
return db.write.queryRow(c, query, args...)
}
func (db *DB) readIndex() int {
if len(db.read) == 0 {
return 0
}
v := atomic.AddInt64(&db.idx, 1)
return int(v) % len(db.read)
}
// Close closes the write and read database, releasing any open resources.
func (db *DB) Close() (err error) {
if e := db.write.Close(); e != nil {
err = errors.WithStack(e)
}
for _, rd := range db.read {
if e := rd.Close(); e != nil {
err = errors.WithStack(e)
}
}
return
}
// Ping verifies a connection to the database is still alive, establishing a
// connection if necessary.
func (db *DB) Ping(c context.Context) (err error) {
if err = db.write.ping(c); err != nil {
return
}
for _, rd := range db.read {
if err = rd.ping(c); err != nil {
return
}
}
return
}
// Master return *DB instance direct use master conn
// use this *DB instance only when you have some reason need to get result without any delay.
func (db *DB) Master() *DB {
if db.master == nil {
panic(ErrNoMaster)
}
return db.master
}
func (db *conn) onBreaker(err *error) {
if err != nil && *err != nil && *err != sql.ErrNoRows && *err != sql.ErrTxDone {
db.breaker.MarkFailed()
} else {
db.breaker.MarkSuccess()
}
}
func (db *conn) begin(c context.Context) (tx *Tx, err error) {
now := time.Now()
defer slowLog("Begin", now)
t, ok := trace.FromContext(c)
if ok {
t = t.Fork(_family, "begin")
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, ""))
defer func() {
if err != nil {
t.Finish(&err)
}
}()
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("mysql:begin", "breaker")
return
}
_, c, cancel := db.conf.TranTimeout.Shrink(c)
rtx, err := db.BeginTx(c, nil)
stats.Timing("mysql:begin", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.WithStack(err)
cancel()
return
}
tx = &Tx{tx: rtx, t: t, db: db, c: c, cancel: cancel}
return
}
func (db *conn) exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", query, args), now)
if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "exec")
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query))
defer t.Finish(&err)
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("mysql:exec", "breaker")
return
}
_, c, cancel := db.conf.ExecTimeout.Shrink(c)
res, err = db.ExecContext(c, query, args...)
cancel()
db.onBreaker(&err)
stats.Timing("mysql:exec", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "exec:%s, args:%+v", query, args)
}
return
}
func (db *conn) ping(c context.Context) (err error) {
now := time.Now()
defer slowLog("Ping", now)
if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "ping")
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, ""))
defer t.Finish(&err)
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("mysql:ping", "breaker")
return
}
_, c, cancel := db.conf.ExecTimeout.Shrink(c)
err = db.PingContext(c)
cancel()
db.onBreaker(&err)
stats.Timing("mysql:ping", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.WithStack(err)
}
return
}
func (db *conn) prepare(query string) (*Stmt, error) {
defer slowLog(fmt.Sprintf("Prepare query(%s)", query), time.Now())
stmt, err := db.Prepare(query)
if err != nil {
err = errors.Wrapf(err, "prepare %s", query)
return nil, err
}
st := &Stmt{query: query, db: db}
st.stmt.Store(stmt)
return st, nil
}
func (db *conn) prepared(query string) (stmt *Stmt) {
defer slowLog(fmt.Sprintf("Prepared query(%s)", query), time.Now())
stmt = &Stmt{query: query, db: db}
s, err := db.Prepare(query)
if err == nil {
stmt.stmt.Store(s)
return
}
go func() {
for {
s, err := db.Prepare(query)
if err != nil {
time.Sleep(time.Second)
continue
}
stmt.stmt.Store(s)
return
}
}()
return
}
func (db *conn) query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", query, args), now)
if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "query")
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query))
defer t.Finish(&err)
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("mysql:query", "breaker")
return
}
_, c, cancel := db.conf.QueryTimeout.Shrink(c)
rs, err := db.DB.QueryContext(c, query, args...)
db.onBreaker(&err)
stats.Timing("mysql:query", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "query:%s, args:%+v", query, args)
cancel()
return
}
rows = &Rows{Rows: rs, cancel: cancel}
return
}
func (db *conn) queryRow(c context.Context, query string, args ...interface{}) *Row {
now := time.Now()
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", query, args), now)
t, ok := trace.FromContext(c)
if ok {
t = t.Fork(_family, "queryrow")
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query))
}
if err := db.breaker.Allow(); err != nil {
stats.Incr("mysql:queryrow", "breaker")
return &Row{db: db, t: t, err: err}
}
_, c, cancel := db.conf.QueryTimeout.Shrink(c)
r := db.DB.QueryRowContext(c, query, args...)
stats.Timing("mysql:queryrow", int64(time.Since(now)/time.Millisecond))
return &Row{db: db, Row: r, query: query, args: args, t: t, cancel: cancel}
}
// Close closes the statement.
func (s *Stmt) Close() (err error) {
if s == nil {
err = ErrStmtNil
return
}
stmt, ok := s.stmt.Load().(*sql.Stmt)
if ok {
err = errors.WithStack(stmt.Close())
}
return
}
// Exec executes a prepared statement with the given arguments and returns a
// Result summarizing the effect of the statement.
func (s *Stmt) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) {
if s == nil {
err = ErrStmtNil
return
}
now := time.Now()
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", s.query, args), now)
if s.tx {
if s.t != nil {
s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
}
} else if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "exec")
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query))
defer t.Finish(&err)
}
if err = s.db.breaker.Allow(); err != nil {
stats.Incr("mysql:stmt:exec", "breaker")
return
}
stmt, ok := s.stmt.Load().(*sql.Stmt)
if !ok {
err = ErrStmtNil
return
}
_, c, cancel := s.db.conf.ExecTimeout.Shrink(c)
res, err = stmt.ExecContext(c, args...)
cancel()
s.db.onBreaker(&err)
stats.Timing("mysql:stmt:exec", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "exec:%s, args:%+v", s.query, args)
}
return
}
// Query executes a prepared query statement with the given arguments and
// returns the query results as a *Rows.
func (s *Stmt) Query(c context.Context, args ...interface{}) (rows *Rows, err error) {
if s == nil {
err = ErrStmtNil
return
}
now := time.Now()
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", s.query, args), now)
if s.tx {
if s.t != nil {
s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
}
} else if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "query")
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query))
defer t.Finish(&err)
}
if err = s.db.breaker.Allow(); err != nil {
stats.Incr("mysql:stmt:query", "breaker")
return
}
stmt, ok := s.stmt.Load().(*sql.Stmt)
if !ok {
err = ErrStmtNil
return
}
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c)
rs, err := stmt.QueryContext(c, args...)
s.db.onBreaker(&err)
stats.Timing("mysql:stmt:query", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "query:%s, args:%+v", s.query, args)
cancel()
return
}
rows = &Rows{Rows: rs, cancel: cancel}
return
}
// QueryRow executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
func (s *Stmt) QueryRow(c context.Context, args ...interface{}) (row *Row) {
now := time.Now()
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", s.query, args), now)
row = &Row{db: s.db, query: s.query, args: args}
if s == nil {
row.err = ErrStmtNil
return
}
if s.tx {
if s.t != nil {
s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
}
} else if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "queryrow")
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query))
row.t = t
}
if row.err = s.db.breaker.Allow(); row.err != nil {
stats.Incr("mysql:stmt:queryrow", "breaker")
return
}
stmt, ok := s.stmt.Load().(*sql.Stmt)
if !ok {
return
}
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c)
row.Row = stmt.QueryRowContext(c, args...)
row.cancel = cancel
stats.Timing("mysql:stmt:queryrow", int64(time.Since(now)/time.Millisecond))
return
}
// Commit commits the transaction.
func (tx *Tx) Commit() (err error) {
err = tx.tx.Commit()
tx.cancel()
tx.db.onBreaker(&err)
if tx.t != nil {
tx.t.Finish(&err)
}
if err != nil {
err = errors.WithStack(err)
}
return
}
// Rollback aborts the transaction.
func (tx *Tx) Rollback() (err error) {
err = tx.tx.Rollback()
tx.cancel()
tx.db.onBreaker(&err)
if tx.t != nil {
tx.t.Finish(&err)
}
if err != nil {
err = errors.WithStack(err)
}
return
}
// Exec executes a query that doesn't return rows. For example: an INSERT and
// UPDATE.
func (tx *Tx) Exec(query string, args ...interface{}) (res sql.Result, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", query, args), now)
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("exec %s", query)))
}
res, err = tx.tx.ExecContext(tx.c, query, args...)
stats.Timing("mysql:tx:exec", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "exec:%s, args:%+v", query, args)
}
return
}
// Query executes a query that returns rows, typically a SELECT.
func (tx *Tx) Query(query string, args ...interface{}) (rows *Rows, err error) {
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("query %s", query)))
}
now := time.Now()
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", query, args), now)
defer func() {
stats.Timing("mysql:tx:query", int64(time.Since(now)/time.Millisecond))
}()
rs, err := tx.tx.QueryContext(tx.c, query, args...)
if err == nil {
rows = &Rows{Rows: rs}
} else {
err = errors.Wrapf(err, "query:%s, args:%+v", query, args)
}
return
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until Row's
// Scan method is called.
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("queryrow %s", query)))
}
now := time.Now()
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", query, args), now)
defer func() {
stats.Timing("mysql:tx:queryrow", int64(time.Since(now)/time.Millisecond))
}()
r := tx.tx.QueryRowContext(tx.c, query, args...)
return &Row{Row: r, db: tx.db, query: query, args: args}
}
// Stmt returns a transaction-specific prepared statement from an existing statement.
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
as, ok := stmt.stmt.Load().(*sql.Stmt)
if !ok {
return nil
}
ts := tx.tx.StmtContext(tx.c, as)
st := &Stmt{query: stmt.query, tx: true, t: tx.t, db: tx.db}
st.stmt.Store(ts)
return st
}
// Prepare creates a prepared statement for use within a transaction.
// The returned statement operates within the transaction and can no longer be
// used once the transaction has been committed or rolled back.
// To use an existing prepared statement on this transaction, see Tx.Stmt.
func (tx *Tx) Prepare(query string) (*Stmt, error) {
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("prepare %s", query)))
}
defer slowLog(fmt.Sprintf("Prepare query(%s)", query), time.Now())
stmt, err := tx.tx.Prepare(query)
if err != nil {
err = errors.Wrapf(err, "prepare %s", query)
return nil, err
}
st := &Stmt{query: query, tx: true, t: tx.t, db: tx.db}
st.stmt.Store(stmt)
return st, nil
}
// parseDSNAddr parse dsn name and return addr.
func parseDSNAddr(dsn string) (addr string) {
if dsn == "" {
return
}
part0 := strings.Split(dsn, "@")
if len(part0) > 1 {
part1 := strings.Split(part0[1], "?")
if len(part1) > 0 {
addr = part1[0]
}
}
return
}
func slowLog(statement string, now time.Time) {
du := time.Since(now)
if du > _slowLogDuration {
log.Warn("%s slow log statement: %s time: %v", _family, statement, du)
}
}

@ -0,0 +1,17 @@
#### database/tidb
##### 项目简介
TiDB数据库驱动 对mysql驱动进行封装
##### 功能
1. 支持discovery服务发现 多节点直连
2. 支持通过lvs单一地址连接
3. 支持prepare绑定多个节点
4. 支持动态增减节点负载均衡
5. 日志区分运行节点
##### 编译环境
> 请只用golang v1.8.x以上版本编译执行。
##### 依赖包
> 1.[Go-MySQL-Driver](https://github.com/go-sql-driver/mysql)

@ -0,0 +1,58 @@
package tidb
import (
"context"
"fmt"
"strings"
"time"
"github.com/bilibili/Kratos/pkg/conf/env"
"github.com/bilibili/Kratos/pkg/log"
"github.com/bilibili/Kratos/pkg/naming"
"github.com/bilibili/Kratos/pkg/naming/discovery"
)
var _schema = "tidb://"
func (db *DB) nodeList() (nodes []string) {
var (
insInfo *naming.InstancesInfo
insMap map[string][]*naming.Instance
ins []*naming.Instance
ok bool
)
if insInfo, ok = db.dis.Fetch(context.Background()); !ok {
return
}
insMap = insInfo.Instances
if ins, ok = insMap[env.Zone]; !ok || len(ins) == 0 {
return
}
for _, in := range ins {
for _, addr := range in.Addrs {
if strings.HasPrefix(addr, _schema) {
addr = strings.Replace(addr, _schema, "", -1)
nodes = append(nodes, addr)
}
}
}
log.Info("tidb get %s instances(%v)", db.appid, nodes)
return
}
func (db *DB) disc() (nodes []string) {
db.dis = discovery.Build(db.appid)
e := db.dis.Watch()
select {
case <-e:
nodes = db.nodeList()
case <-time.After(10 * time.Second):
panic("tidb init discovery err")
}
if len(nodes) == 0 {
panic(fmt.Sprintf("tidb %s no instance", db.appid))
}
go db.nodeproc(e)
log.Info("init tidb discvoery info successfully")
return
}

@ -0,0 +1,82 @@
package tidb
import (
"time"
"github.com/bilibili/Kratos/pkg/log"
)
func (db *DB) nodeproc(e <-chan struct{}) {
if db.dis == nil {
return
}
for {
<-e
nodes := db.nodeList()
if len(nodes) == 0 {
continue
}
cm := make(map[string]*conn)
var conns []*conn
for _, conn := range db.conns {
cm[conn.addr] = conn
}
for _, node := range nodes {
if cm[node] != nil {
conns = append(conns, cm[node])
continue
}
c, err := db.connectDSN(genDSN(db.conf.DSN, node))
if err == nil {
conns = append(conns, c)
} else {
log.Error("tidb: connect addr: %s err: %+v", node, err)
}
}
if len(conns) == 0 {
log.Error("tidb: no nodes ignore event")
continue
}
oldConns := db.conns
db.mutex.Lock()
db.conns = conns
db.mutex.Unlock()
log.Info("tidb: new nodes: %v", nodes)
var removedConn []*conn
for _, conn := range oldConns {
var exist bool
for _, c := range conns {
if c.addr == conn.addr {
exist = true
break
}
}
if !exist {
removedConn = append(removedConn, conn)
}
}
go db.closeConns(removedConn)
}
}
func (db *DB) closeConns(conns []*conn) {
if len(conns) == 0 {
return
}
du := db.conf.QueryTimeout
if db.conf.ExecTimeout > du {
du = db.conf.ExecTimeout
}
if db.conf.TranTimeout > du {
du = db.conf.TranTimeout
}
time.Sleep(time.Duration(du))
for _, conn := range conns {
err := conn.Close()
if err != nil {
log.Error("tidb: close removed conn: %s err: %v", conn.addr, err)
} else {
log.Info("tidb: close removed conn: %s", conn.addr)
}
}
}

@ -0,0 +1,739 @@
package tidb
import (
"context"
"database/sql"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/bilibili/Kratos/pkg/log"
"github.com/bilibili/Kratos/pkg/naming"
"github.com/bilibili/Kratos/pkg/net/netutil/breaker"
"github.com/bilibili/Kratos/pkg/net/trace"
"github.com/go-sql-driver/mysql"
"github.com/pkg/errors"
)
const (
_family = "tidb_client"
_slowLogDuration = time.Millisecond * 250
)
var (
// ErrStmtNil prepared stmt error
ErrStmtNil = errors.New("sql: prepare failed and stmt nil")
// ErrNoRows is returned by Scan when QueryRow doesn't return a row.
// In such a case, QueryRow returns a placeholder *Row value that defers
// this error until a Scan.
ErrNoRows = sql.ErrNoRows
// ErrTxDone transaction done.
ErrTxDone = sql.ErrTxDone
)
// DB database.
type DB struct {
conf *Config
conns []*conn
idx int64
dis naming.Resolver
appid string
mutex sync.RWMutex
breakerGroup *breaker.Group
}
// conn database connection
type conn struct {
*sql.DB
breaker breaker.Breaker
conf *Config
addr string
}
// Tx transaction.
type Tx struct {
db *conn
tx *sql.Tx
t trace.Trace
c context.Context
cancel func()
}
// Row row.
type Row struct {
err error
*sql.Row
db *conn
query string
args []interface{}
t trace.Trace
cancel func()
}
// Scan copies the columns from the matched row into the values pointed at by dest.
func (r *Row) Scan(dest ...interface{}) (err error) {
defer slowLog(fmt.Sprintf("Scan addr: %s query(%s) args(%+v)", r.db.addr, r.query, r.args), time.Now())
if r.t != nil {
defer r.t.Finish(&err)
}
if r.err != nil {
err = r.err
} else if r.Row == nil {
err = ErrStmtNil
}
if err != nil {
return
}
err = r.Row.Scan(dest...)
if r.cancel != nil {
r.cancel()
}
r.db.onBreaker(&err)
if err != ErrNoRows {
err = errors.Wrapf(err, "addr: %s, query %s args %+v", r.db.addr, r.query, r.args)
}
return
}
// Rows rows.
type Rows struct {
*sql.Rows
cancel func()
}
// Close closes the Rows, preventing further enumeration. If Next is called
// and returns false and there are no further result sets,
// the Rows are closed automatically and it will suffice to check the
// result of Err. Close is idempotent and does not affect the result of Err.
func (rs *Rows) Close() (err error) {
err = errors.WithStack(rs.Rows.Close())
if rs.cancel != nil {
rs.cancel()
}
return
}
// Stmt prepared stmt.
type Stmt struct {
db *conn
tx bool
query string
stmt atomic.Value
t trace.Trace
}
// Stmts random prepared stmt.
type Stmts struct {
query string
sts map[string]*Stmt
mu sync.RWMutex
db *DB
}
// Open opens a database specified by its database driver name and a
// driver-specific data source name, usually consisting of at least a database
// name and connection information.
func Open(c *Config) (db *DB, err error) {
db = &DB{conf: c, breakerGroup: breaker.NewGroup(c.Breaker)}
cfg, err := mysql.ParseDSN(c.DSN)
if err != nil {
return
}
var dsns []string
if cfg.Net == "discovery" {
db.appid = cfg.Addr
for _, addr := range db.disc() {
dsns = append(dsns, genDSN(c.DSN, addr))
}
} else {
dsns = append(dsns, c.DSN)
}
cs := make([]*conn, 0, len(dsns))
for _, dsn := range dsns {
r, err := db.connectDSN(dsn)
if err != nil {
return db, err
}
cs = append(cs, r)
}
db.conns = cs
return
}
func (db *DB) connectDSN(dsn string) (c *conn, err error) {
d, err := connect(db.conf, dsn)
if err != nil {
return
}
addr := parseDSNAddr(dsn)
brk := db.breakerGroup.Get(addr)
c = &conn{DB: d, breaker: brk, conf: db.conf, addr: addr}
return
}
func connect(c *Config, dataSourceName string) (*sql.DB, error) {
d, err := sql.Open("mysql", dataSourceName)
if err != nil {
err = errors.WithStack(err)
return nil, err
}
d.SetMaxOpenConns(c.Active)
d.SetMaxIdleConns(c.Idle)
d.SetConnMaxLifetime(time.Duration(c.IdleTimeout))
return d, nil
}
func (db *DB) conn() (c *conn) {
db.mutex.RLock()
c = db.conns[db.index()]
db.mutex.RUnlock()
return
}
// Begin starts a transaction. The isolation level is dependent on the driver.
func (db *DB) Begin(c context.Context) (tx *Tx, err error) {
return db.conn().begin(c)
}
// Exec executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
func (db *DB) Exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) {
return db.conn().exec(c, query, args...)
}
// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the returned
// statement. The caller must call the statement's Close method when the
// statement is no longer needed.
func (db *DB) Prepare(query string) (*Stmt, error) {
return db.conn().prepare(query)
}
// Prepared creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the returned
// statement. The caller must call the statement's Close method when the
// statement is no longer needed.
func (db *DB) Prepared(query string) (s *Stmts) {
s = &Stmts{query: query, sts: make(map[string]*Stmt), db: db}
for _, c := range db.conns {
st := c.prepared(query)
s.mu.Lock()
s.sts[c.addr] = st
s.mu.Unlock()
}
return
}
// Query executes a query that returns rows, typically a SELECT. The args are
// for any placeholder parameters in the query.
func (db *DB) Query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) {
return db.conn().query(c, query, args...)
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until Row's
// Scan method is called.
func (db *DB) QueryRow(c context.Context, query string, args ...interface{}) *Row {
return db.conn().queryRow(c, query, args...)
}
func (db *DB) index() int {
if len(db.conns) == 1 {
return 0
}
v := atomic.AddInt64(&db.idx, 1)
return int(v) % len(db.conns)
}
// Close closes the databases, releasing any open resources.
func (db *DB) Close() (err error) {
db.mutex.RLock()
defer db.mutex.RUnlock()
for _, d := range db.conns {
if e := d.Close(); e != nil {
err = errors.WithStack(e)
}
}
return
}
// Ping verifies a connection to the database is still alive, establishing a
// connection if necessary.
func (db *DB) Ping(c context.Context) (err error) {
if err = db.conn().ping(c); err != nil {
return
}
return
}
func (db *conn) onBreaker(err *error) {
if err != nil && *err != nil && *err != sql.ErrNoRows && *err != sql.ErrTxDone {
db.breaker.MarkFailed()
} else {
db.breaker.MarkSuccess()
}
}
func (db *conn) begin(c context.Context) (tx *Tx, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Begin addr: %s", db.addr), now)
t, ok := trace.FromContext(c)
if ok {
t = t.Fork(_family, "begin")
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, ""))
defer func() {
if err != nil {
t.Finish(&err)
}
}()
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("tidb:begin", "breaker")
return
}
_, c, cancel := db.conf.TranTimeout.Shrink(c)
rtx, err := db.BeginTx(c, nil)
stats.Timing("tidb:begin", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.WithStack(err)
cancel()
return
}
tx = &Tx{tx: rtx, t: t, db: db, c: c, cancel: cancel}
return
}
func (db *conn) exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", db.addr, query, args), now)
if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "exec")
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query))
defer t.Finish(&err)
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("tidb:exec", "breaker")
return
}
_, c, cancel := db.conf.ExecTimeout.Shrink(c)
res, err = db.ExecContext(c, query, args...)
cancel()
db.onBreaker(&err)
stats.Timing("tidb:exec", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", db.addr, query, args)
}
return
}
func (db *conn) ping(c context.Context) (err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Ping addr: %s", db.addr), now)
if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "ping")
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, ""))
defer t.Finish(&err)
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("tidb:ping", "breaker")
return
}
_, c, cancel := db.conf.ExecTimeout.Shrink(c)
err = db.PingContext(c)
cancel()
db.onBreaker(&err)
stats.Timing("tidb:ping", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.WithStack(err)
}
return
}
func (db *conn) prepare(query string) (*Stmt, error) {
defer slowLog(fmt.Sprintf("Prepare addr: %s query(%s)", db.addr, query), time.Now())
stmt, err := db.Prepare(query)
if err != nil {
err = errors.Wrapf(err, "addr: %s prepare %s", db.addr, query)
return nil, err
}
st := &Stmt{query: query, db: db}
st.stmt.Store(stmt)
return st, nil
}
func (db *conn) prepared(query string) (stmt *Stmt) {
defer slowLog(fmt.Sprintf("Prepared addr: %s query(%s)", db.addr, query), time.Now())
stmt = &Stmt{query: query, db: db}
s, err := db.Prepare(query)
if err == nil {
stmt.stmt.Store(s)
return
}
return
}
func (db *conn) query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", db.addr, query, args), now)
if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "query")
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query))
defer t.Finish(&err)
}
if err = db.breaker.Allow(); err != nil {
stats.Incr("tidb:query", "breaker")
return
}
_, c, cancel := db.conf.QueryTimeout.Shrink(c)
rs, err := db.DB.QueryContext(c, query, args...)
db.onBreaker(&err)
stats.Timing("tidb:query", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", db.addr, query, args)
cancel()
return
}
rows = &Rows{Rows: rs, cancel: cancel}
return
}
func (db *conn) queryRow(c context.Context, query string, args ...interface{}) *Row {
now := time.Now()
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", db.addr, query, args), now)
t, ok := trace.FromContext(c)
if ok {
t = t.Fork(_family, "queryrow")
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query))
}
if err := db.breaker.Allow(); err != nil {
stats.Incr("tidb:queryrow", "breaker")
return &Row{db: db, t: t, err: err}
}
_, c, cancel := db.conf.QueryTimeout.Shrink(c)
r := db.DB.QueryRowContext(c, query, args...)
stats.Timing("tidb:queryrow", int64(time.Since(now)/time.Millisecond))
return &Row{db: db, Row: r, query: query, args: args, t: t, cancel: cancel}
}
// Close closes the statement.
func (s *Stmt) Close() (err error) {
stmt, ok := s.stmt.Load().(*sql.Stmt)
if ok {
err = errors.WithStack(stmt.Close())
}
return
}
func (s *Stmt) prepare() (st *sql.Stmt) {
var ok bool
if st, ok = s.stmt.Load().(*sql.Stmt); ok {
return
}
var err error
if st, err = s.db.Prepare(s.query); err == nil {
s.stmt.Store(st)
}
return
}
// Exec executes a prepared statement with the given arguments and returns a
// Result summarizing the effect of the statement.
func (s *Stmt) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now)
if s.tx {
if s.t != nil {
s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
}
} else if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "exec")
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query))
defer t.Finish(&err)
}
if err = s.db.breaker.Allow(); err != nil {
stats.Incr("tidb:stmt:exec", "breaker")
return
}
stmt := s.prepare()
if stmt == nil {
err = ErrStmtNil
return
}
_, c, cancel := s.db.conf.ExecTimeout.Shrink(c)
res, err = stmt.ExecContext(c, args...)
cancel()
s.db.onBreaker(&err)
stats.Timing("tidb:stmt:exec", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", s.db.addr, s.query, args)
}
return
}
// Query executes a prepared query statement with the given arguments and
// returns the query results as a *Rows.
func (s *Stmt) Query(c context.Context, args ...interface{}) (rows *Rows, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now)
if s.tx {
if s.t != nil {
s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
}
} else if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "query")
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query))
defer t.Finish(&err)
}
if err = s.db.breaker.Allow(); err != nil {
stats.Incr("tidb:stmt:query", "breaker")
return
}
stmt := s.prepare()
if stmt == nil {
err = ErrStmtNil
return
}
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c)
rs, err := stmt.QueryContext(c, args...)
s.db.onBreaker(&err)
stats.Timing("tidb:stmt:query", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", s.db.addr, s.query, args)
cancel()
return
}
rows = &Rows{Rows: rs, cancel: cancel}
return
}
// QueryRow executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
func (s *Stmt) QueryRow(c context.Context, args ...interface{}) (row *Row) {
now := time.Now()
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now)
row = &Row{db: s.db, query: s.query, args: args}
if s.tx {
if s.t != nil {
s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
}
} else if t, ok := trace.FromContext(c); ok {
t = t.Fork(_family, "queryrow")
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query))
row.t = t
}
if row.err = s.db.breaker.Allow(); row.err != nil {
stats.Incr("tidb:stmt:queryrow", "breaker")
return
}
stmt := s.prepare()
if stmt == nil {
return
}
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c)
row.Row = stmt.QueryRowContext(c, args...)
row.cancel = cancel
stats.Timing("tidb:stmt:queryrow", int64(time.Since(now)/time.Millisecond))
return
}
func (s *Stmts) prepare(conn *conn) (st *Stmt) {
if conn == nil {
conn = s.db.conn()
}
s.mu.RLock()
st = s.sts[conn.addr]
s.mu.RUnlock()
if st == nil {
st = conn.prepared(s.query)
s.mu.Lock()
s.sts[conn.addr] = st
s.mu.Unlock()
}
return
}
// Exec executes a prepared statement with the given arguments and returns a
// Result summarizing the effect of the statement.
func (s *Stmts) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) {
return s.prepare(nil).Exec(c, args...)
}
// Query executes a prepared query statement with the given arguments and
// returns the query results as a *Rows.
func (s *Stmts) Query(c context.Context, args ...interface{}) (rows *Rows, err error) {
return s.prepare(nil).Query(c, args...)
}
// QueryRow executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
func (s *Stmts) QueryRow(c context.Context, args ...interface{}) (row *Row) {
return s.prepare(nil).QueryRow(c, args...)
}
// Close closes the statement.
func (s *Stmts) Close() (err error) {
for _, st := range s.sts {
if err = errors.WithStack(st.Close()); err != nil {
return
}
}
return
}
// Commit commits the transaction.
func (tx *Tx) Commit() (err error) {
err = tx.tx.Commit()
tx.cancel()
tx.db.onBreaker(&err)
if tx.t != nil {
tx.t.Finish(&err)
}
if err != nil {
err = errors.WithStack(err)
}
return
}
// Rollback aborts the transaction.
func (tx *Tx) Rollback() (err error) {
err = tx.tx.Rollback()
tx.cancel()
tx.db.onBreaker(&err)
if tx.t != nil {
tx.t.Finish(&err)
}
if err != nil {
err = errors.WithStack(err)
}
return
}
// Exec executes a query that doesn't return rows. For example: an INSERT and
// UPDATE.
func (tx *Tx) Exec(query string, args ...interface{}) (res sql.Result, err error) {
now := time.Now()
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now)
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("exec %s", query)))
}
res, err = tx.tx.ExecContext(tx.c, query, args...)
stats.Timing("tidb:tx:exec", int64(time.Since(now)/time.Millisecond))
if err != nil {
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", tx.db.addr, query, args)
}
return
}
// Query executes a query that returns rows, typically a SELECT.
func (tx *Tx) Query(query string, args ...interface{}) (rows *Rows, err error) {
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("query %s", query)))
}
now := time.Now()
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now)
defer func() {
stats.Timing("tidb:tx:query", int64(time.Since(now)/time.Millisecond))
}()
rs, err := tx.tx.QueryContext(tx.c, query, args...)
if err == nil {
rows = &Rows{Rows: rs}
} else {
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", tx.db.addr, query, args)
}
return
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until Row's
// Scan method is called.
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("queryrow %s", query)))
}
now := time.Now()
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now)
defer func() {
stats.Timing("tidb:tx:queryrow", int64(time.Since(now)/time.Millisecond))
}()
r := tx.tx.QueryRowContext(tx.c, query, args...)
return &Row{Row: r, db: tx.db, query: query, args: args}
}
// Stmt returns a transaction-specific prepared statement from an existing statement.
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
if stmt == nil {
return nil
}
as, ok := stmt.stmt.Load().(*sql.Stmt)
if !ok {
return nil
}
ts := tx.tx.StmtContext(tx.c, as)
st := &Stmt{query: stmt.query, tx: true, t: tx.t, db: tx.db}
st.stmt.Store(ts)
return st
}
// Stmts returns a transaction-specific prepared statement from an existing statement.
func (tx *Tx) Stmts(stmt *Stmts) *Stmt {
return tx.Stmt(stmt.prepare(tx.db))
}
// Prepare creates a prepared statement for use within a transaction.
// The returned statement operates within the transaction and can no longer be
// used once the transaction has been committed or rolled back.
// To use an existing prepared statement on this transaction, see Tx.Stmt.
func (tx *Tx) Prepare(query string) (*Stmt, error) {
if tx.t != nil {
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("prepare %s", query)))
}
defer slowLog(fmt.Sprintf("Prepare addr: %s query(%s)", tx.db.addr, query), time.Now())
stmt, err := tx.tx.Prepare(query)
if err != nil {
err = errors.Wrapf(err, "addr: %s prepare %s", tx.db.addr, query)
return nil, err
}
st := &Stmt{query: query, tx: true, t: tx.t, db: tx.db}
st.stmt.Store(stmt)
return st, nil
}
// parseDSNAddr parse dsn name and return addr.
func parseDSNAddr(dsn string) (addr string) {
if dsn == "" {
return
}
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return
}
addr = cfg.Addr
return
}
func genDSN(dsn, addr string) (res string) {
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return
}
cfg.Addr = addr
cfg.Net = "tcp"
res = cfg.FormatDSN()
return
}
func slowLog(statement string, now time.Time) {
du := time.Since(now)
if du > _slowLogDuration {
log.Warn("%s slow log statement: %s time: %v", _family, statement, du)
}
}

@ -0,0 +1,38 @@
package tidb
import (
"github.com/bilibili/Kratos/pkg/log"
"github.com/bilibili/Kratos/pkg/net/netutil/breaker"
"github.com/bilibili/Kratos/pkg/stat"
"github.com/bilibili/Kratos/pkg/time"
// database driver
_ "github.com/go-sql-driver/mysql"
)
var stats = stat.DB
// Config mysql config.
type Config struct {
DSN string // dsn
Active int // pool
Idle int // pool
IdleTimeout time.Duration // connect max life time.
QueryTimeout time.Duration // query sql timeout
ExecTimeout time.Duration // execute sql timeout
TranTimeout time.Duration // transaction sql timeout
Breaker *breaker.Config // breaker
}
// NewTiDB new db and retry connection when has error.
func NewTiDB(c *Config) (db *DB) {
if c.QueryTimeout == 0 || c.ExecTimeout == 0 || c.TranTimeout == 0 {
panic("tidb must be set query/execute/transction timeout")
}
db, err := Open(c)
if err != nil {
log.Error("open tidb error(%v)", err)
panic(err)
}
return
}
Loading…
Cancel
Save