parent
7c6e0ea7ba
commit
1efe0a084e
@ -0,0 +1,44 @@ |
||||
### database/hbase |
||||
|
||||
### 项目简介 |
||||
|
||||
Hbase Client,进行封装加入了链路追踪和统计。 |
||||
|
||||
### usage |
||||
```go |
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/database/hbase" |
||||
) |
||||
|
||||
func main() { |
||||
config := &hbase.Config{Zookeeper: &hbase.ZKConfig{Addrs: []string{"localhost"}}} |
||||
client := hbase.NewClient(config) |
||||
|
||||
values := map[string]map[string][]byte{"name": {"firstname": []byte("hello"), "lastname": []byte("world")}} |
||||
ctx := context.Background() |
||||
|
||||
_, err := client.PutStr(ctx, "user", "user1", values) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
|
||||
result, err := client.GetStr(ctx, "user", "user1") |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
fmt.Printf("%v", result) |
||||
} |
||||
``` |
||||
|
||||
##### 编译环境 |
||||
|
||||
> 请只用golang v1.8.x以上版本编译执行。 |
||||
|
||||
##### 依赖包 |
||||
|
||||
> 1.[gohbase](https://github.com/tsuna/gohbase) |
@ -0,0 +1,23 @@ |
||||
package hbase |
||||
|
||||
import ( |
||||
xtime "github.com/bilibili/Kratos/pkg/time" |
||||
) |
||||
|
||||
// ZKConfig Server&Client settings.
|
||||
type ZKConfig struct { |
||||
Root string |
||||
Addrs []string |
||||
Timeout xtime.Duration |
||||
} |
||||
|
||||
// Config hbase config
|
||||
type Config struct { |
||||
Zookeeper *ZKConfig |
||||
RPCQueueSize int |
||||
FlushInterval xtime.Duration |
||||
EffectiveUser string |
||||
RegionLookupTimeout xtime.Duration |
||||
RegionReadTimeout xtime.Duration |
||||
TestRowKey string |
||||
} |
@ -0,0 +1,297 @@ |
||||
package hbase |
||||
|
||||
import ( |
||||
"context" |
||||
"io" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/tsuna/gohbase" |
||||
"github.com/tsuna/gohbase/hrpc" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
) |
||||
|
||||
// HookFunc hook function call before every method and hook return function will call after finish.
|
||||
type HookFunc func(ctx context.Context, call hrpc.Call, customName string) func(err error) |
||||
|
||||
// Client hbase client.
|
||||
type Client struct { |
||||
hc gohbase.Client |
||||
addr string |
||||
config *Config |
||||
hooks []HookFunc |
||||
} |
||||
|
||||
// AddHook add hook function.
|
||||
func (c *Client) AddHook(hookFn HookFunc) { |
||||
c.hooks = append(c.hooks, hookFn) |
||||
} |
||||
|
||||
func (c *Client) invokeHook(ctx context.Context, call hrpc.Call, customName string) func(error) { |
||||
finishHooks := make([]func(error), 0, len(c.hooks)) |
||||
for _, fn := range c.hooks { |
||||
finishHooks = append(finishHooks, fn(ctx, call, customName)) |
||||
} |
||||
return func(err error) { |
||||
for _, fn := range finishHooks { |
||||
fn(err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// NewClient new a hbase client.
|
||||
func NewClient(config *Config, options ...gohbase.Option) *Client { |
||||
rawcli := NewRawClient(config, options...) |
||||
rawcli.AddHook(NewSlowLogHook(250 * time.Millisecond)) |
||||
rawcli.AddHook(MetricsHook(nil)) |
||||
rawcli.AddHook(TraceHook("database/hbase", strings.Join(config.Zookeeper.Addrs, ","))) |
||||
return rawcli |
||||
} |
||||
|
||||
// NewRawClient new a hbase client without prometheus metrics and dapper trace hook.
|
||||
func NewRawClient(config *Config, options ...gohbase.Option) *Client { |
||||
zkquorum := strings.Join(config.Zookeeper.Addrs, ",") |
||||
if config.Zookeeper.Root != "" { |
||||
options = append(options, gohbase.ZookeeperRoot(config.Zookeeper.Root)) |
||||
} |
||||
if config.Zookeeper.Timeout != 0 { |
||||
options = append(options, gohbase.ZookeeperTimeout(time.Duration(config.Zookeeper.Timeout))) |
||||
} |
||||
|
||||
if config.RPCQueueSize != 0 { |
||||
log.Warn("RPCQueueSize configuration be ignored") |
||||
} |
||||
// force RpcQueueSize = 1, don't change it !!! it has reason (゜-゜)つロ
|
||||
options = append(options, gohbase.RpcQueueSize(1)) |
||||
|
||||
if config.FlushInterval != 0 { |
||||
options = append(options, gohbase.FlushInterval(time.Duration(config.FlushInterval))) |
||||
} |
||||
if config.EffectiveUser != "" { |
||||
options = append(options, gohbase.EffectiveUser(config.EffectiveUser)) |
||||
} |
||||
if config.RegionLookupTimeout != 0 { |
||||
options = append(options, gohbase.RegionLookupTimeout(time.Duration(config.RegionLookupTimeout))) |
||||
} |
||||
if config.RegionReadTimeout != 0 { |
||||
options = append(options, gohbase.RegionReadTimeout(time.Duration(config.RegionReadTimeout))) |
||||
} |
||||
hc := gohbase.NewClient(zkquorum, options...) |
||||
return &Client{ |
||||
hc: hc, |
||||
addr: zkquorum, |
||||
config: config, |
||||
} |
||||
} |
||||
|
||||
// ScanAll do scan command and return all result
|
||||
// NOTE: if err != nil the results is safe for range operate even not result found
|
||||
func (c *Client) ScanAll(ctx context.Context, table []byte, options ...func(hrpc.Call) error) (results []*hrpc.Result, err error) { |
||||
cursor, err := c.Scan(ctx, table, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
for { |
||||
result, err := cursor.Next() |
||||
if err != nil { |
||||
if err == io.EOF { |
||||
break |
||||
} |
||||
return nil, err |
||||
} |
||||
results = append(results, result) |
||||
} |
||||
return results, nil |
||||
} |
||||
|
||||
type scanTrace struct { |
||||
hrpc.Scanner |
||||
finishHook func(error) |
||||
} |
||||
|
||||
func (s *scanTrace) Next() (*hrpc.Result, error) { |
||||
result, err := s.Scanner.Next() |
||||
if err != nil { |
||||
s.finishHook(err) |
||||
} |
||||
return result, err |
||||
} |
||||
|
||||
func (s *scanTrace) Close() error { |
||||
err := s.Scanner.Close() |
||||
s.finishHook(err) |
||||
return err |
||||
} |
||||
|
||||
// Scan do a scan command.
|
||||
func (c *Client) Scan(ctx context.Context, table []byte, options ...func(hrpc.Call) error) (scanner hrpc.Scanner, err error) { |
||||
var scan *hrpc.Scan |
||||
scan, err = hrpc.NewScan(ctx, table, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
st := &scanTrace{} |
||||
st.finishHook = c.invokeHook(ctx, scan, "Scan") |
||||
st.Scanner = c.hc.Scan(scan) |
||||
return st, nil |
||||
} |
||||
|
||||
// ScanStr scan string
|
||||
func (c *Client) ScanStr(ctx context.Context, table string, options ...func(hrpc.Call) error) (hrpc.Scanner, error) { |
||||
return c.Scan(ctx, []byte(table), options...) |
||||
} |
||||
|
||||
// ScanStrAll scan string
|
||||
// NOTE: if err != nil the results is safe for range operate even not result found
|
||||
func (c *Client) ScanStrAll(ctx context.Context, table string, options ...func(hrpc.Call) error) ([]*hrpc.Result, error) { |
||||
return c.ScanAll(ctx, []byte(table), options...) |
||||
} |
||||
|
||||
// ScanRange get a scanner for the given table and key range.
|
||||
// The range is half-open, i.e. [startRow; stopRow[ -- stopRow is not
|
||||
// included in the range.
|
||||
func (c *Client) ScanRange(ctx context.Context, table, startRow, stopRow []byte, options ...func(hrpc.Call) error) (scanner hrpc.Scanner, err error) { |
||||
var scan *hrpc.Scan |
||||
scan, err = hrpc.NewScanRange(ctx, table, startRow, stopRow, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
st := &scanTrace{} |
||||
st.finishHook = c.invokeHook(ctx, scan, "ScanRange") |
||||
st.Scanner = c.hc.Scan(scan) |
||||
return st, nil |
||||
} |
||||
|
||||
// ScanRangeStr get a scanner for the given table and key range.
|
||||
// The range is half-open, i.e. [startRow; stopRow[ -- stopRow is not
|
||||
// included in the range.
|
||||
func (c *Client) ScanRangeStr(ctx context.Context, table, startRow, stopRow string, options ...func(hrpc.Call) error) (hrpc.Scanner, error) { |
||||
return c.ScanRange(ctx, []byte(table), []byte(startRow), []byte(stopRow), options...) |
||||
} |
||||
|
||||
// Get get result for the given table and row key.
|
||||
// NOTE: if err != nil then result != nil, if result not exists result.Cells length is 0
|
||||
func (c *Client) Get(ctx context.Context, table, key []byte, options ...func(hrpc.Call) error) (result *hrpc.Result, err error) { |
||||
var get *hrpc.Get |
||||
get, err = hrpc.NewGet(ctx, table, key, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
finishHook := c.invokeHook(ctx, get, "GET") |
||||
result, err = c.hc.Get(get) |
||||
finishHook(err) |
||||
return |
||||
} |
||||
|
||||
// GetStr do a get command.
|
||||
// NOTE: if err != nil then result != nil, if result not exists result.Cells length is 0
|
||||
func (c *Client) GetStr(ctx context.Context, table, key string, options ...func(hrpc.Call) error) (result *hrpc.Result, err error) { |
||||
return c.Get(ctx, []byte(table), []byte(key), options...) |
||||
} |
||||
|
||||
// PutStr insert the given family-column-values in the given row key of the given table.
|
||||
func (c *Client) PutStr(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) { |
||||
put, err := hrpc.NewPutStr(ctx, table, key, values, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
finishHook := c.invokeHook(ctx, put, "PUT") |
||||
result, err := c.hc.Put(put) |
||||
finishHook(err) |
||||
return result, err |
||||
} |
||||
|
||||
// Delete is used to perform Delete operations on a single row.
|
||||
// To delete entire row, values should be nil.
|
||||
//
|
||||
// To delete specific families, qualifiers map should be nil:
|
||||
// map[string]map[string][]byte{
|
||||
// "cf1": nil,
|
||||
// "cf2": nil,
|
||||
// }
|
||||
//
|
||||
// To delete specific qualifiers:
|
||||
// map[string]map[string][]byte{
|
||||
// "cf": map[string][]byte{
|
||||
// "q1": nil,
|
||||
// "q2": nil,
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// To delete all versions before and at a timestamp, pass hrpc.Timestamp() option.
|
||||
// By default all versions will be removed.
|
||||
//
|
||||
// To delete only a specific version at a timestamp, pass hrpc.DeleteOneVersion() option
|
||||
// along with a timestamp. For delete specific qualifiers request, if timestamp is not
|
||||
// passed, only the latest version will be removed. For delete specific families request,
|
||||
// the timestamp should be passed or it will have no effect as it's an expensive
|
||||
// operation to perform.
|
||||
func (c *Client) Delete(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) { |
||||
del, err := hrpc.NewDelStr(ctx, table, key, values, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
finishHook := c.invokeHook(ctx, del, "Delete") |
||||
result, err := c.hc.Delete(del) |
||||
finishHook(err) |
||||
return result, err |
||||
} |
||||
|
||||
// Append do a append command.
|
||||
func (c *Client) Append(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (*hrpc.Result, error) { |
||||
appd, err := hrpc.NewAppStr(ctx, table, key, values, options...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
finishHook := c.invokeHook(ctx, appd, "Append") |
||||
result, err := c.hc.Append(appd) |
||||
finishHook(err) |
||||
return result, err |
||||
} |
||||
|
||||
// Increment the given values in HBase under the given table and key.
|
||||
func (c *Client) Increment(ctx context.Context, table string, key string, values map[string]map[string][]byte, options ...func(hrpc.Call) error) (int64, error) { |
||||
increment, err := hrpc.NewIncStr(ctx, table, key, values, options...) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
finishHook := c.invokeHook(ctx, increment, "Increment") |
||||
result, err := c.hc.Increment(increment) |
||||
finishHook(err) |
||||
return result, err |
||||
} |
||||
|
||||
// IncrementSingle increment the given value by amount in HBase under the given table, key, family and qualifier.
|
||||
func (c *Client) IncrementSingle(ctx context.Context, table string, key string, family string, qualifier string, amount int64, options ...func(hrpc.Call) error) (int64, error) { |
||||
increment, err := hrpc.NewIncStrSingle(ctx, table, key, family, qualifier, amount, options...) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
|
||||
finishHook := c.invokeHook(ctx, increment, "IncrementSingle") |
||||
result, err := c.hc.Increment(increment) |
||||
finishHook(err) |
||||
return result, err |
||||
} |
||||
|
||||
// Ping ping.
|
||||
func (c *Client) Ping(ctx context.Context) (err error) { |
||||
testRowKey := "test" |
||||
if c.config.TestRowKey != "" { |
||||
testRowKey = c.config.TestRowKey |
||||
} |
||||
values := map[string]map[string][]byte{"test": map[string][]byte{"test": []byte("test")}} |
||||
_, err = c.PutStr(ctx, "test", testRowKey, values) |
||||
return |
||||
} |
||||
|
||||
// Close close client.
|
||||
func (c *Client) Close() error { |
||||
c.hc.Close() |
||||
return nil |
||||
} |
@ -0,0 +1,48 @@ |
||||
package hbase |
||||
|
||||
import ( |
||||
"context" |
||||
"io" |
||||
"time" |
||||
|
||||
"github.com/tsuna/gohbase" |
||||
"github.com/tsuna/gohbase/hrpc" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/stat" |
||||
) |
||||
|
||||
func codeFromErr(err error) string { |
||||
code := "unknown_error" |
||||
switch err { |
||||
case gohbase.ErrClientClosed: |
||||
code = "client_closed" |
||||
case gohbase.ErrConnotFindRegion: |
||||
code = "connot_find_region" |
||||
case gohbase.TableNotFound: |
||||
code = "table_not_found" |
||||
case gohbase.ErrRegionUnavailable: |
||||
code = "region_unavailable" |
||||
} |
||||
return code |
||||
} |
||||
|
||||
// MetricsHook if stats is nil use stat.DB as default.
|
||||
func MetricsHook(stats stat.Stat) HookFunc { |
||||
if stats == nil { |
||||
stats = stat.DB |
||||
} |
||||
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) { |
||||
now := time.Now() |
||||
if customName == "" { |
||||
customName = call.Name() |
||||
} |
||||
method := "hbase:" + customName |
||||
return func(err error) { |
||||
durationMs := int64(time.Since(now) / time.Millisecond) |
||||
stats.Timing(method, durationMs) |
||||
if err != nil && err != io.EOF { |
||||
stats.Incr(method, codeFromErr(err)) |
||||
} |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,24 @@ |
||||
package hbase |
||||
|
||||
import ( |
||||
"context" |
||||
"time" |
||||
|
||||
"github.com/tsuna/gohbase/hrpc" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
) |
||||
|
||||
// NewSlowLogHook log slow operation.
|
||||
func NewSlowLogHook(threshold time.Duration) HookFunc { |
||||
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) { |
||||
start := time.Now() |
||||
return func(error) { |
||||
duration := time.Since(start) |
||||
if duration < threshold { |
||||
return |
||||
} |
||||
log.Warn("hbase slow log: %s %s %s time: %s", customName, call.Table(), call.Key(), duration) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,40 @@ |
||||
package hbase |
||||
|
||||
import ( |
||||
"context" |
||||
"io" |
||||
|
||||
"github.com/tsuna/gohbase/hrpc" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/net/trace" |
||||
) |
||||
|
||||
// TraceHook create new hbase trace hook.
|
||||
func TraceHook(component, instance string) HookFunc { |
||||
var internalTags []trace.Tag |
||||
internalTags = append(internalTags, trace.TagString(trace.TagComponent, component)) |
||||
internalTags = append(internalTags, trace.TagString(trace.TagDBInstance, instance)) |
||||
internalTags = append(internalTags, trace.TagString(trace.TagPeerService, "hbase")) |
||||
internalTags = append(internalTags, trace.TagString(trace.TagSpanKind, "client")) |
||||
return func(ctx context.Context, call hrpc.Call, customName string) func(err error) { |
||||
noop := func(error) {} |
||||
root, ok := trace.FromContext(ctx) |
||||
if !ok { |
||||
return noop |
||||
} |
||||
if customName == "" { |
||||
customName = call.Name() |
||||
} |
||||
span := root.Fork("", "Hbase:"+customName) |
||||
span.SetTag(internalTags...) |
||||
statement := string(call.Table()) + " " + string(call.Key()) |
||||
span.SetTag(trace.TagString(trace.TagDBStatement, statement)) |
||||
return func(err error) { |
||||
if err == io.EOF { |
||||
// reset error for trace.
|
||||
err = nil |
||||
} |
||||
span.Finish(&err) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,12 @@ |
||||
#### database/sql |
||||
|
||||
##### 项目简介 |
||||
MySQL数据库驱动,进行封装加入了链路追踪和统计。 |
||||
|
||||
如果需要SQL级别的超时管理 可以在业务代码里面使用context.WithDeadline实现 推荐超时配置放到application.toml里面 方便热加载 |
||||
|
||||
##### 编译环境 |
||||
> 请只用golang v1.8.x以上版本编译执行。 |
||||
|
||||
##### 依赖包 |
||||
> 1.[Go-MySQL-Driver](https://github.com/go-sql-driver/mysql) |
@ -0,0 +1,40 @@ |
||||
package sql |
||||
|
||||
import ( |
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
"github.com/bilibili/Kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/Kratos/pkg/stat" |
||||
"github.com/bilibili/Kratos/pkg/time" |
||||
|
||||
// database driver
|
||||
_ "github.com/go-sql-driver/mysql" |
||||
) |
||||
|
||||
var stats = stat.DB |
||||
|
||||
// Config mysql config.
|
||||
type Config struct { |
||||
Addr string // for trace
|
||||
DSN string // write data source name.
|
||||
ReadDSN []string // read data source name.
|
||||
Active int // pool
|
||||
Idle int // pool
|
||||
IdleTimeout time.Duration // connect max life time.
|
||||
QueryTimeout time.Duration // query sql timeout
|
||||
ExecTimeout time.Duration // execute sql timeout
|
||||
TranTimeout time.Duration // transaction sql timeout
|
||||
Breaker *breaker.Config // breaker
|
||||
} |
||||
|
||||
// NewMySQL new db and retry connection when has error.
|
||||
func NewMySQL(c *Config) (db *DB) { |
||||
if c.QueryTimeout == 0 || c.ExecTimeout == 0 || c.TranTimeout == 0 { |
||||
panic("mysql must be set query/execute/transction timeout") |
||||
} |
||||
db, err := Open(c) |
||||
if err != nil { |
||||
log.Error("open mysql error(%v)", err) |
||||
panic(err) |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,678 @@ |
||||
package sql |
||||
|
||||
import ( |
||||
"context" |
||||
"database/sql" |
||||
"fmt" |
||||
"strings" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/ecode" |
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
"github.com/bilibili/Kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/Kratos/pkg/net/trace" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
const ( |
||||
_family = "sql_client" |
||||
_slowLogDuration = time.Millisecond * 250 |
||||
) |
||||
|
||||
var ( |
||||
// ErrStmtNil prepared stmt error
|
||||
ErrStmtNil = errors.New("sql: prepare failed and stmt nil") |
||||
// ErrNoMaster is returned by Master when call master multiple times.
|
||||
ErrNoMaster = errors.New("sql: no master instance") |
||||
// ErrNoRows is returned by Scan when QueryRow doesn't return a row.
|
||||
// In such a case, QueryRow returns a placeholder *Row value that defers
|
||||
// this error until a Scan.
|
||||
ErrNoRows = sql.ErrNoRows |
||||
// ErrTxDone transaction done.
|
||||
ErrTxDone = sql.ErrTxDone |
||||
) |
||||
|
||||
// DB database.
|
||||
type DB struct { |
||||
write *conn |
||||
read []*conn |
||||
idx int64 |
||||
master *DB |
||||
} |
||||
|
||||
// conn database connection
|
||||
type conn struct { |
||||
*sql.DB |
||||
breaker breaker.Breaker |
||||
conf *Config |
||||
} |
||||
|
||||
// Tx transaction.
|
||||
type Tx struct { |
||||
db *conn |
||||
tx *sql.Tx |
||||
t trace.Trace |
||||
c context.Context |
||||
cancel func() |
||||
} |
||||
|
||||
// Row row.
|
||||
type Row struct { |
||||
err error |
||||
*sql.Row |
||||
db *conn |
||||
query string |
||||
args []interface{} |
||||
t trace.Trace |
||||
cancel func() |
||||
} |
||||
|
||||
// Scan copies the columns from the matched row into the values pointed at by dest.
|
||||
func (r *Row) Scan(dest ...interface{}) (err error) { |
||||
defer slowLog(fmt.Sprintf("Scan query(%s) args(%+v)", r.query, r.args), time.Now()) |
||||
if r.t != nil { |
||||
defer r.t.Finish(&err) |
||||
} |
||||
if r.err != nil { |
||||
err = r.err |
||||
} else if r.Row == nil { |
||||
err = ErrStmtNil |
||||
} |
||||
if err != nil { |
||||
return |
||||
} |
||||
err = r.Row.Scan(dest...) |
||||
if r.cancel != nil { |
||||
r.cancel() |
||||
} |
||||
r.db.onBreaker(&err) |
||||
if err != ErrNoRows { |
||||
err = errors.Wrapf(err, "query %s args %+v", r.query, r.args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Rows rows.
|
||||
type Rows struct { |
||||
*sql.Rows |
||||
cancel func() |
||||
} |
||||
|
||||
// Close closes the Rows, preventing further enumeration. If Next is called
|
||||
// and returns false and there are no further result sets,
|
||||
// the Rows are closed automatically and it will suffice to check the
|
||||
// result of Err. Close is idempotent and does not affect the result of Err.
|
||||
func (rs *Rows) Close() (err error) { |
||||
err = errors.WithStack(rs.Rows.Close()) |
||||
if rs.cancel != nil { |
||||
rs.cancel() |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Stmt prepared stmt.
|
||||
type Stmt struct { |
||||
db *conn |
||||
tx bool |
||||
query string |
||||
stmt atomic.Value |
||||
t trace.Trace |
||||
} |
||||
|
||||
// Open opens a database specified by its database driver name and a
|
||||
// driver-specific data source name, usually consisting of at least a database
|
||||
// name and connection information.
|
||||
func Open(c *Config) (*DB, error) { |
||||
db := new(DB) |
||||
d, err := connect(c, c.DSN) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
brkGroup := breaker.NewGroup(c.Breaker) |
||||
brk := brkGroup.Get(c.Addr) |
||||
w := &conn{DB: d, breaker: brk, conf: c} |
||||
rs := make([]*conn, 0, len(c.ReadDSN)) |
||||
for _, rd := range c.ReadDSN { |
||||
d, err := connect(c, rd) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
brk := brkGroup.Get(parseDSNAddr(rd)) |
||||
r := &conn{DB: d, breaker: brk, conf: c} |
||||
rs = append(rs, r) |
||||
} |
||||
db.write = w |
||||
db.read = rs |
||||
db.master = &DB{write: db.write} |
||||
return db, nil |
||||
} |
||||
|
||||
func connect(c *Config, dataSourceName string) (*sql.DB, error) { |
||||
d, err := sql.Open("mysql", dataSourceName) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
return nil, err |
||||
} |
||||
d.SetMaxOpenConns(c.Active) |
||||
d.SetMaxIdleConns(c.Idle) |
||||
d.SetConnMaxLifetime(time.Duration(c.IdleTimeout)) |
||||
return d, nil |
||||
} |
||||
|
||||
// Begin starts a transaction. The isolation level is dependent on the driver.
|
||||
func (db *DB) Begin(c context.Context) (tx *Tx, err error) { |
||||
return db.write.begin(c) |
||||
} |
||||
|
||||
// Exec executes a query without returning any rows.
|
||||
// The args are for any placeholder parameters in the query.
|
||||
func (db *DB) Exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) { |
||||
return db.write.exec(c, query, args...) |
||||
} |
||||
|
||||
// Prepare creates a prepared statement for later queries or executions.
|
||||
// Multiple queries or executions may be run concurrently from the returned
|
||||
// statement. The caller must call the statement's Close method when the
|
||||
// statement is no longer needed.
|
||||
func (db *DB) Prepare(query string) (*Stmt, error) { |
||||
return db.write.prepare(query) |
||||
} |
||||
|
||||
// Prepared creates a prepared statement for later queries or executions.
|
||||
// Multiple queries or executions may be run concurrently from the returned
|
||||
// statement. The caller must call the statement's Close method when the
|
||||
// statement is no longer needed.
|
||||
func (db *DB) Prepared(query string) (stmt *Stmt) { |
||||
return db.write.prepared(query) |
||||
} |
||||
|
||||
// Query executes a query that returns rows, typically a SELECT. The args are
|
||||
// for any placeholder parameters in the query.
|
||||
func (db *DB) Query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) { |
||||
idx := db.readIndex() |
||||
for i := range db.read { |
||||
if rows, err = db.read[(idx+i)%len(db.read)].query(c, query, args...); !ecode.ServiceUnavailable.Equal(err) { |
||||
return |
||||
} |
||||
} |
||||
return db.write.query(c, query, args...) |
||||
} |
||||
|
||||
// QueryRow executes a query that is expected to return at most one row.
|
||||
// QueryRow always returns a non-nil value. Errors are deferred until Row's
|
||||
// Scan method is called.
|
||||
func (db *DB) QueryRow(c context.Context, query string, args ...interface{}) *Row { |
||||
idx := db.readIndex() |
||||
for i := range db.read { |
||||
if row := db.read[(idx+i)%len(db.read)].queryRow(c, query, args...); !ecode.ServiceUnavailable.Equal(row.err) { |
||||
return row |
||||
} |
||||
} |
||||
return db.write.queryRow(c, query, args...) |
||||
} |
||||
|
||||
func (db *DB) readIndex() int { |
||||
if len(db.read) == 0 { |
||||
return 0 |
||||
} |
||||
v := atomic.AddInt64(&db.idx, 1) |
||||
return int(v) % len(db.read) |
||||
} |
||||
|
||||
// Close closes the write and read database, releasing any open resources.
|
||||
func (db *DB) Close() (err error) { |
||||
if e := db.write.Close(); e != nil { |
||||
err = errors.WithStack(e) |
||||
} |
||||
for _, rd := range db.read { |
||||
if e := rd.Close(); e != nil { |
||||
err = errors.WithStack(e) |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Ping verifies a connection to the database is still alive, establishing a
|
||||
// connection if necessary.
|
||||
func (db *DB) Ping(c context.Context) (err error) { |
||||
if err = db.write.ping(c); err != nil { |
||||
return |
||||
} |
||||
for _, rd := range db.read { |
||||
if err = rd.ping(c); err != nil { |
||||
return |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Master return *DB instance direct use master conn
|
||||
// use this *DB instance only when you have some reason need to get result without any delay.
|
||||
func (db *DB) Master() *DB { |
||||
if db.master == nil { |
||||
panic(ErrNoMaster) |
||||
} |
||||
return db.master |
||||
} |
||||
|
||||
func (db *conn) onBreaker(err *error) { |
||||
if err != nil && *err != nil && *err != sql.ErrNoRows && *err != sql.ErrTxDone { |
||||
db.breaker.MarkFailed() |
||||
} else { |
||||
db.breaker.MarkSuccess() |
||||
} |
||||
} |
||||
|
||||
func (db *conn) begin(c context.Context) (tx *Tx, err error) { |
||||
now := time.Now() |
||||
defer slowLog("Begin", now) |
||||
t, ok := trace.FromContext(c) |
||||
if ok { |
||||
t = t.Fork(_family, "begin") |
||||
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, "")) |
||||
defer func() { |
||||
if err != nil { |
||||
t.Finish(&err) |
||||
} |
||||
}() |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:begin", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.TranTimeout.Shrink(c) |
||||
rtx, err := db.BeginTx(c, nil) |
||||
stats.Timing("mysql:begin", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
cancel() |
||||
return |
||||
} |
||||
tx = &Tx{tx: rtx, t: t, db: db, c: c, cancel: cancel} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", query, args), now) |
||||
if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "exec") |
||||
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:exec", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.ExecTimeout.Shrink(c) |
||||
res, err = db.ExecContext(c, query, args...) |
||||
cancel() |
||||
db.onBreaker(&err) |
||||
stats.Timing("mysql:exec", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "exec:%s, args:%+v", query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) ping(c context.Context) (err error) { |
||||
now := time.Now() |
||||
defer slowLog("Ping", now) |
||||
if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "ping") |
||||
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, "")) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:ping", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.ExecTimeout.Shrink(c) |
||||
err = db.PingContext(c) |
||||
cancel() |
||||
db.onBreaker(&err) |
||||
stats.Timing("mysql:ping", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) prepare(query string) (*Stmt, error) { |
||||
defer slowLog(fmt.Sprintf("Prepare query(%s)", query), time.Now()) |
||||
stmt, err := db.Prepare(query) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "prepare %s", query) |
||||
return nil, err |
||||
} |
||||
st := &Stmt{query: query, db: db} |
||||
st.stmt.Store(stmt) |
||||
return st, nil |
||||
} |
||||
|
||||
func (db *conn) prepared(query string) (stmt *Stmt) { |
||||
defer slowLog(fmt.Sprintf("Prepared query(%s)", query), time.Now()) |
||||
stmt = &Stmt{query: query, db: db} |
||||
s, err := db.Prepare(query) |
||||
if err == nil { |
||||
stmt.stmt.Store(s) |
||||
return |
||||
} |
||||
go func() { |
||||
for { |
||||
s, err := db.Prepare(query) |
||||
if err != nil { |
||||
time.Sleep(time.Second) |
||||
continue |
||||
} |
||||
stmt.stmt.Store(s) |
||||
return |
||||
} |
||||
}() |
||||
return |
||||
} |
||||
|
||||
func (db *conn) query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", query, args), now) |
||||
if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "query") |
||||
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:query", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.QueryTimeout.Shrink(c) |
||||
rs, err := db.DB.QueryContext(c, query, args...) |
||||
db.onBreaker(&err) |
||||
stats.Timing("mysql:query", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "query:%s, args:%+v", query, args) |
||||
cancel() |
||||
return |
||||
} |
||||
rows = &Rows{Rows: rs, cancel: cancel} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) queryRow(c context.Context, query string, args ...interface{}) *Row { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", query, args), now) |
||||
t, ok := trace.FromContext(c) |
||||
if ok { |
||||
t = t.Fork(_family, "queryrow") |
||||
t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query)) |
||||
} |
||||
if err := db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:queryrow", "breaker") |
||||
return &Row{db: db, t: t, err: err} |
||||
} |
||||
_, c, cancel := db.conf.QueryTimeout.Shrink(c) |
||||
r := db.DB.QueryRowContext(c, query, args...) |
||||
stats.Timing("mysql:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||
return &Row{db: db, Row: r, query: query, args: args, t: t, cancel: cancel} |
||||
} |
||||
|
||||
// Close closes the statement.
|
||||
func (s *Stmt) Close() (err error) { |
||||
if s == nil { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||
if ok { |
||||
err = errors.WithStack(stmt.Close()) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Exec executes a prepared statement with the given arguments and returns a
|
||||
// Result summarizing the effect of the statement.
|
||||
func (s *Stmt) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) { |
||||
if s == nil { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", s.query, args), now) |
||||
if s.tx { |
||||
if s.t != nil { |
||||
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||
} |
||||
} else if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "exec") |
||||
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = s.db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:stmt:exec", "breaker") |
||||
return |
||||
} |
||||
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||
if !ok { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
_, c, cancel := s.db.conf.ExecTimeout.Shrink(c) |
||||
res, err = stmt.ExecContext(c, args...) |
||||
cancel() |
||||
s.db.onBreaker(&err) |
||||
stats.Timing("mysql:stmt:exec", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "exec:%s, args:%+v", s.query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Query executes a prepared query statement with the given arguments and
|
||||
// returns the query results as a *Rows.
|
||||
func (s *Stmt) Query(c context.Context, args ...interface{}) (rows *Rows, err error) { |
||||
if s == nil { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", s.query, args), now) |
||||
if s.tx { |
||||
if s.t != nil { |
||||
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||
} |
||||
} else if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "query") |
||||
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = s.db.breaker.Allow(); err != nil { |
||||
stats.Incr("mysql:stmt:query", "breaker") |
||||
return |
||||
} |
||||
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||
if !ok { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c) |
||||
rs, err := stmt.QueryContext(c, args...) |
||||
s.db.onBreaker(&err) |
||||
stats.Timing("mysql:stmt:query", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "query:%s, args:%+v", s.query, args) |
||||
cancel() |
||||
return |
||||
} |
||||
rows = &Rows{Rows: rs, cancel: cancel} |
||||
return |
||||
} |
||||
|
||||
// QueryRow executes a prepared query statement with the given arguments.
|
||||
// If an error occurs during the execution of the statement, that error will
|
||||
// be returned by a call to Scan on the returned *Row, which is always non-nil.
|
||||
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
|
||||
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
|
||||
func (s *Stmt) QueryRow(c context.Context, args ...interface{}) (row *Row) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", s.query, args), now) |
||||
row = &Row{db: s.db, query: s.query, args: args} |
||||
if s == nil { |
||||
row.err = ErrStmtNil |
||||
return |
||||
} |
||||
if s.tx { |
||||
if s.t != nil { |
||||
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||
} |
||||
} else if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "queryrow") |
||||
t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query)) |
||||
row.t = t |
||||
} |
||||
if row.err = s.db.breaker.Allow(); row.err != nil { |
||||
stats.Incr("mysql:stmt:queryrow", "breaker") |
||||
return |
||||
} |
||||
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||
if !ok { |
||||
return |
||||
} |
||||
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c) |
||||
row.Row = stmt.QueryRowContext(c, args...) |
||||
row.cancel = cancel |
||||
stats.Timing("mysql:stmt:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||
return |
||||
} |
||||
|
||||
// Commit commits the transaction.
|
||||
func (tx *Tx) Commit() (err error) { |
||||
err = tx.tx.Commit() |
||||
tx.cancel() |
||||
tx.db.onBreaker(&err) |
||||
if tx.t != nil { |
||||
tx.t.Finish(&err) |
||||
} |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Rollback aborts the transaction.
|
||||
func (tx *Tx) Rollback() (err error) { |
||||
err = tx.tx.Rollback() |
||||
tx.cancel() |
||||
tx.db.onBreaker(&err) |
||||
if tx.t != nil { |
||||
tx.t.Finish(&err) |
||||
} |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Exec executes a query that doesn't return rows. For example: an INSERT and
|
||||
// UPDATE.
|
||||
func (tx *Tx) Exec(query string, args ...interface{}) (res sql.Result, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Exec query(%s) args(%+v)", query, args), now) |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("exec %s", query))) |
||||
} |
||||
res, err = tx.tx.ExecContext(tx.c, query, args...) |
||||
stats.Timing("mysql:tx:exec", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "exec:%s, args:%+v", query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Query executes a query that returns rows, typically a SELECT.
|
||||
func (tx *Tx) Query(query string, args ...interface{}) (rows *Rows, err error) { |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("query %s", query))) |
||||
} |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Query query(%s) args(%+v)", query, args), now) |
||||
defer func() { |
||||
stats.Timing("mysql:tx:query", int64(time.Since(now)/time.Millisecond)) |
||||
}() |
||||
rs, err := tx.tx.QueryContext(tx.c, query, args...) |
||||
if err == nil { |
||||
rows = &Rows{Rows: rs} |
||||
} else { |
||||
err = errors.Wrapf(err, "query:%s, args:%+v", query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// QueryRow executes a query that is expected to return at most one row.
|
||||
// QueryRow always returns a non-nil value. Errors are deferred until Row's
|
||||
// Scan method is called.
|
||||
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("queryrow %s", query))) |
||||
} |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("QueryRow query(%s) args(%+v)", query, args), now) |
||||
defer func() { |
||||
stats.Timing("mysql:tx:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||
}() |
||||
r := tx.tx.QueryRowContext(tx.c, query, args...) |
||||
return &Row{Row: r, db: tx.db, query: query, args: args} |
||||
} |
||||
|
||||
// Stmt returns a transaction-specific prepared statement from an existing statement.
|
||||
func (tx *Tx) Stmt(stmt *Stmt) *Stmt { |
||||
as, ok := stmt.stmt.Load().(*sql.Stmt) |
||||
if !ok { |
||||
return nil |
||||
} |
||||
ts := tx.tx.StmtContext(tx.c, as) |
||||
st := &Stmt{query: stmt.query, tx: true, t: tx.t, db: tx.db} |
||||
st.stmt.Store(ts) |
||||
return st |
||||
} |
||||
|
||||
// Prepare creates a prepared statement for use within a transaction.
|
||||
// The returned statement operates within the transaction and can no longer be
|
||||
// used once the transaction has been committed or rolled back.
|
||||
// To use an existing prepared statement on this transaction, see Tx.Stmt.
|
||||
func (tx *Tx) Prepare(query string) (*Stmt, error) { |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("prepare %s", query))) |
||||
} |
||||
defer slowLog(fmt.Sprintf("Prepare query(%s)", query), time.Now()) |
||||
stmt, err := tx.tx.Prepare(query) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "prepare %s", query) |
||||
return nil, err |
||||
} |
||||
st := &Stmt{query: query, tx: true, t: tx.t, db: tx.db} |
||||
st.stmt.Store(stmt) |
||||
return st, nil |
||||
} |
||||
|
||||
// parseDSNAddr parse dsn name and return addr.
|
||||
func parseDSNAddr(dsn string) (addr string) { |
||||
if dsn == "" { |
||||
return |
||||
} |
||||
part0 := strings.Split(dsn, "@") |
||||
if len(part0) > 1 { |
||||
part1 := strings.Split(part0[1], "?") |
||||
if len(part1) > 0 { |
||||
addr = part1[0] |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
func slowLog(statement string, now time.Time) { |
||||
du := time.Since(now) |
||||
if du > _slowLogDuration { |
||||
log.Warn("%s slow log statement: %s time: %v", _family, statement, du) |
||||
} |
||||
} |
@ -0,0 +1,17 @@ |
||||
#### database/tidb |
||||
|
||||
##### 项目简介 |
||||
TiDB数据库驱动 对mysql驱动进行封装 |
||||
|
||||
##### 功能 |
||||
1. 支持discovery服务发现 多节点直连 |
||||
2. 支持通过lvs单一地址连接 |
||||
3. 支持prepare绑定多个节点 |
||||
4. 支持动态增减节点负载均衡 |
||||
5. 日志区分运行节点 |
||||
|
||||
##### 编译环境 |
||||
> 请只用golang v1.8.x以上版本编译执行。 |
||||
|
||||
##### 依赖包 |
||||
> 1.[Go-MySQL-Driver](https://github.com/go-sql-driver/mysql) |
@ -0,0 +1,58 @@ |
||||
package tidb |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/conf/env" |
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
"github.com/bilibili/Kratos/pkg/naming" |
||||
"github.com/bilibili/Kratos/pkg/naming/discovery" |
||||
) |
||||
|
||||
var _schema = "tidb://" |
||||
|
||||
func (db *DB) nodeList() (nodes []string) { |
||||
var ( |
||||
insInfo *naming.InstancesInfo |
||||
insMap map[string][]*naming.Instance |
||||
ins []*naming.Instance |
||||
ok bool |
||||
) |
||||
if insInfo, ok = db.dis.Fetch(context.Background()); !ok { |
||||
return |
||||
} |
||||
insMap = insInfo.Instances |
||||
if ins, ok = insMap[env.Zone]; !ok || len(ins) == 0 { |
||||
return |
||||
} |
||||
for _, in := range ins { |
||||
for _, addr := range in.Addrs { |
||||
if strings.HasPrefix(addr, _schema) { |
||||
addr = strings.Replace(addr, _schema, "", -1) |
||||
nodes = append(nodes, addr) |
||||
} |
||||
} |
||||
} |
||||
log.Info("tidb get %s instances(%v)", db.appid, nodes) |
||||
return |
||||
} |
||||
|
||||
func (db *DB) disc() (nodes []string) { |
||||
db.dis = discovery.Build(db.appid) |
||||
e := db.dis.Watch() |
||||
select { |
||||
case <-e: |
||||
nodes = db.nodeList() |
||||
case <-time.After(10 * time.Second): |
||||
panic("tidb init discovery err") |
||||
} |
||||
if len(nodes) == 0 { |
||||
panic(fmt.Sprintf("tidb %s no instance", db.appid)) |
||||
} |
||||
go db.nodeproc(e) |
||||
log.Info("init tidb discvoery info successfully") |
||||
return |
||||
} |
@ -0,0 +1,82 @@ |
||||
package tidb |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
) |
||||
|
||||
func (db *DB) nodeproc(e <-chan struct{}) { |
||||
if db.dis == nil { |
||||
return |
||||
} |
||||
for { |
||||
<-e |
||||
nodes := db.nodeList() |
||||
if len(nodes) == 0 { |
||||
continue |
||||
} |
||||
cm := make(map[string]*conn) |
||||
var conns []*conn |
||||
for _, conn := range db.conns { |
||||
cm[conn.addr] = conn |
||||
} |
||||
for _, node := range nodes { |
||||
if cm[node] != nil { |
||||
conns = append(conns, cm[node]) |
||||
continue |
||||
} |
||||
c, err := db.connectDSN(genDSN(db.conf.DSN, node)) |
||||
if err == nil { |
||||
conns = append(conns, c) |
||||
} else { |
||||
log.Error("tidb: connect addr: %s err: %+v", node, err) |
||||
} |
||||
} |
||||
if len(conns) == 0 { |
||||
log.Error("tidb: no nodes ignore event") |
||||
continue |
||||
} |
||||
oldConns := db.conns |
||||
db.mutex.Lock() |
||||
db.conns = conns |
||||
db.mutex.Unlock() |
||||
log.Info("tidb: new nodes: %v", nodes) |
||||
var removedConn []*conn |
||||
for _, conn := range oldConns { |
||||
var exist bool |
||||
for _, c := range conns { |
||||
if c.addr == conn.addr { |
||||
exist = true |
||||
break |
||||
} |
||||
} |
||||
if !exist { |
||||
removedConn = append(removedConn, conn) |
||||
} |
||||
} |
||||
go db.closeConns(removedConn) |
||||
} |
||||
} |
||||
|
||||
func (db *DB) closeConns(conns []*conn) { |
||||
if len(conns) == 0 { |
||||
return |
||||
} |
||||
du := db.conf.QueryTimeout |
||||
if db.conf.ExecTimeout > du { |
||||
du = db.conf.ExecTimeout |
||||
} |
||||
if db.conf.TranTimeout > du { |
||||
du = db.conf.TranTimeout |
||||
} |
||||
time.Sleep(time.Duration(du)) |
||||
for _, conn := range conns { |
||||
err := conn.Close() |
||||
if err != nil { |
||||
log.Error("tidb: close removed conn: %s err: %v", conn.addr, err) |
||||
} else { |
||||
log.Info("tidb: close removed conn: %s", conn.addr) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,739 @@ |
||||
package tidb |
||||
|
||||
import ( |
||||
"context" |
||||
"database/sql" |
||||
"fmt" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
"github.com/bilibili/Kratos/pkg/naming" |
||||
"github.com/bilibili/Kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/Kratos/pkg/net/trace" |
||||
|
||||
"github.com/go-sql-driver/mysql" |
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
const ( |
||||
_family = "tidb_client" |
||||
_slowLogDuration = time.Millisecond * 250 |
||||
) |
||||
|
||||
var ( |
||||
// ErrStmtNil prepared stmt error
|
||||
ErrStmtNil = errors.New("sql: prepare failed and stmt nil") |
||||
// ErrNoRows is returned by Scan when QueryRow doesn't return a row.
|
||||
// In such a case, QueryRow returns a placeholder *Row value that defers
|
||||
// this error until a Scan.
|
||||
ErrNoRows = sql.ErrNoRows |
||||
// ErrTxDone transaction done.
|
||||
ErrTxDone = sql.ErrTxDone |
||||
) |
||||
|
||||
// DB database.
|
||||
type DB struct { |
||||
conf *Config |
||||
conns []*conn |
||||
idx int64 |
||||
dis naming.Resolver |
||||
appid string |
||||
mutex sync.RWMutex |
||||
breakerGroup *breaker.Group |
||||
} |
||||
|
||||
// conn database connection
|
||||
type conn struct { |
||||
*sql.DB |
||||
breaker breaker.Breaker |
||||
conf *Config |
||||
addr string |
||||
} |
||||
|
||||
// Tx transaction.
|
||||
type Tx struct { |
||||
db *conn |
||||
tx *sql.Tx |
||||
t trace.Trace |
||||
c context.Context |
||||
cancel func() |
||||
} |
||||
|
||||
// Row row.
|
||||
type Row struct { |
||||
err error |
||||
*sql.Row |
||||
db *conn |
||||
query string |
||||
args []interface{} |
||||
t trace.Trace |
||||
cancel func() |
||||
} |
||||
|
||||
// Scan copies the columns from the matched row into the values pointed at by dest.
|
||||
func (r *Row) Scan(dest ...interface{}) (err error) { |
||||
defer slowLog(fmt.Sprintf("Scan addr: %s query(%s) args(%+v)", r.db.addr, r.query, r.args), time.Now()) |
||||
if r.t != nil { |
||||
defer r.t.Finish(&err) |
||||
} |
||||
if r.err != nil { |
||||
err = r.err |
||||
} else if r.Row == nil { |
||||
err = ErrStmtNil |
||||
} |
||||
if err != nil { |
||||
return |
||||
} |
||||
err = r.Row.Scan(dest...) |
||||
if r.cancel != nil { |
||||
r.cancel() |
||||
} |
||||
r.db.onBreaker(&err) |
||||
if err != ErrNoRows { |
||||
err = errors.Wrapf(err, "addr: %s, query %s args %+v", r.db.addr, r.query, r.args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Rows rows.
|
||||
type Rows struct { |
||||
*sql.Rows |
||||
cancel func() |
||||
} |
||||
|
||||
// Close closes the Rows, preventing further enumeration. If Next is called
|
||||
// and returns false and there are no further result sets,
|
||||
// the Rows are closed automatically and it will suffice to check the
|
||||
// result of Err. Close is idempotent and does not affect the result of Err.
|
||||
func (rs *Rows) Close() (err error) { |
||||
err = errors.WithStack(rs.Rows.Close()) |
||||
if rs.cancel != nil { |
||||
rs.cancel() |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Stmt prepared stmt.
|
||||
type Stmt struct { |
||||
db *conn |
||||
tx bool |
||||
query string |
||||
stmt atomic.Value |
||||
t trace.Trace |
||||
} |
||||
|
||||
// Stmts random prepared stmt.
|
||||
type Stmts struct { |
||||
query string |
||||
sts map[string]*Stmt |
||||
mu sync.RWMutex |
||||
db *DB |
||||
} |
||||
|
||||
// Open opens a database specified by its database driver name and a
|
||||
// driver-specific data source name, usually consisting of at least a database
|
||||
// name and connection information.
|
||||
func Open(c *Config) (db *DB, err error) { |
||||
db = &DB{conf: c, breakerGroup: breaker.NewGroup(c.Breaker)} |
||||
cfg, err := mysql.ParseDSN(c.DSN) |
||||
if err != nil { |
||||
return |
||||
} |
||||
var dsns []string |
||||
if cfg.Net == "discovery" { |
||||
db.appid = cfg.Addr |
||||
for _, addr := range db.disc() { |
||||
dsns = append(dsns, genDSN(c.DSN, addr)) |
||||
} |
||||
} else { |
||||
dsns = append(dsns, c.DSN) |
||||
} |
||||
|
||||
cs := make([]*conn, 0, len(dsns)) |
||||
for _, dsn := range dsns { |
||||
r, err := db.connectDSN(dsn) |
||||
if err != nil { |
||||
return db, err |
||||
} |
||||
cs = append(cs, r) |
||||
} |
||||
db.conns = cs |
||||
return |
||||
} |
||||
|
||||
func (db *DB) connectDSN(dsn string) (c *conn, err error) { |
||||
d, err := connect(db.conf, dsn) |
||||
if err != nil { |
||||
return |
||||
} |
||||
addr := parseDSNAddr(dsn) |
||||
brk := db.breakerGroup.Get(addr) |
||||
c = &conn{DB: d, breaker: brk, conf: db.conf, addr: addr} |
||||
return |
||||
} |
||||
|
||||
func connect(c *Config, dataSourceName string) (*sql.DB, error) { |
||||
d, err := sql.Open("mysql", dataSourceName) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
return nil, err |
||||
} |
||||
d.SetMaxOpenConns(c.Active) |
||||
d.SetMaxIdleConns(c.Idle) |
||||
d.SetConnMaxLifetime(time.Duration(c.IdleTimeout)) |
||||
return d, nil |
||||
} |
||||
|
||||
func (db *DB) conn() (c *conn) { |
||||
db.mutex.RLock() |
||||
c = db.conns[db.index()] |
||||
db.mutex.RUnlock() |
||||
return |
||||
} |
||||
|
||||
// Begin starts a transaction. The isolation level is dependent on the driver.
|
||||
func (db *DB) Begin(c context.Context) (tx *Tx, err error) { |
||||
return db.conn().begin(c) |
||||
} |
||||
|
||||
// Exec executes a query without returning any rows.
|
||||
// The args are for any placeholder parameters in the query.
|
||||
func (db *DB) Exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) { |
||||
return db.conn().exec(c, query, args...) |
||||
} |
||||
|
||||
// Prepare creates a prepared statement for later queries or executions.
|
||||
// Multiple queries or executions may be run concurrently from the returned
|
||||
// statement. The caller must call the statement's Close method when the
|
||||
// statement is no longer needed.
|
||||
func (db *DB) Prepare(query string) (*Stmt, error) { |
||||
return db.conn().prepare(query) |
||||
} |
||||
|
||||
// Prepared creates a prepared statement for later queries or executions.
|
||||
// Multiple queries or executions may be run concurrently from the returned
|
||||
// statement. The caller must call the statement's Close method when the
|
||||
// statement is no longer needed.
|
||||
func (db *DB) Prepared(query string) (s *Stmts) { |
||||
s = &Stmts{query: query, sts: make(map[string]*Stmt), db: db} |
||||
for _, c := range db.conns { |
||||
st := c.prepared(query) |
||||
s.mu.Lock() |
||||
s.sts[c.addr] = st |
||||
s.mu.Unlock() |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Query executes a query that returns rows, typically a SELECT. The args are
|
||||
// for any placeholder parameters in the query.
|
||||
func (db *DB) Query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) { |
||||
return db.conn().query(c, query, args...) |
||||
} |
||||
|
||||
// QueryRow executes a query that is expected to return at most one row.
|
||||
// QueryRow always returns a non-nil value. Errors are deferred until Row's
|
||||
// Scan method is called.
|
||||
func (db *DB) QueryRow(c context.Context, query string, args ...interface{}) *Row { |
||||
return db.conn().queryRow(c, query, args...) |
||||
} |
||||
|
||||
func (db *DB) index() int { |
||||
if len(db.conns) == 1 { |
||||
return 0 |
||||
} |
||||
v := atomic.AddInt64(&db.idx, 1) |
||||
return int(v) % len(db.conns) |
||||
} |
||||
|
||||
// Close closes the databases, releasing any open resources.
|
||||
func (db *DB) Close() (err error) { |
||||
db.mutex.RLock() |
||||
defer db.mutex.RUnlock() |
||||
for _, d := range db.conns { |
||||
if e := d.Close(); e != nil { |
||||
err = errors.WithStack(e) |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Ping verifies a connection to the database is still alive, establishing a
|
||||
// connection if necessary.
|
||||
func (db *DB) Ping(c context.Context) (err error) { |
||||
if err = db.conn().ping(c); err != nil { |
||||
return |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) onBreaker(err *error) { |
||||
if err != nil && *err != nil && *err != sql.ErrNoRows && *err != sql.ErrTxDone { |
||||
db.breaker.MarkFailed() |
||||
} else { |
||||
db.breaker.MarkSuccess() |
||||
} |
||||
} |
||||
|
||||
func (db *conn) begin(c context.Context) (tx *Tx, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Begin addr: %s", db.addr), now) |
||||
t, ok := trace.FromContext(c) |
||||
if ok { |
||||
t = t.Fork(_family, "begin") |
||||
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, "")) |
||||
defer func() { |
||||
if err != nil { |
||||
t.Finish(&err) |
||||
} |
||||
}() |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:begin", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.TranTimeout.Shrink(c) |
||||
rtx, err := db.BeginTx(c, nil) |
||||
stats.Timing("tidb:begin", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
cancel() |
||||
return |
||||
} |
||||
tx = &Tx{tx: rtx, t: t, db: db, c: c, cancel: cancel} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", db.addr, query, args), now) |
||||
if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "exec") |
||||
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:exec", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.ExecTimeout.Shrink(c) |
||||
res, err = db.ExecContext(c, query, args...) |
||||
cancel() |
||||
db.onBreaker(&err) |
||||
stats.Timing("tidb:exec", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", db.addr, query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) ping(c context.Context) (err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Ping addr: %s", db.addr), now) |
||||
if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "ping") |
||||
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, "")) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:ping", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.ExecTimeout.Shrink(c) |
||||
err = db.PingContext(c) |
||||
cancel() |
||||
db.onBreaker(&err) |
||||
stats.Timing("tidb:ping", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) prepare(query string) (*Stmt, error) { |
||||
defer slowLog(fmt.Sprintf("Prepare addr: %s query(%s)", db.addr, query), time.Now()) |
||||
stmt, err := db.Prepare(query) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s prepare %s", db.addr, query) |
||||
return nil, err |
||||
} |
||||
st := &Stmt{query: query, db: db} |
||||
st.stmt.Store(stmt) |
||||
return st, nil |
||||
} |
||||
|
||||
func (db *conn) prepared(query string) (stmt *Stmt) { |
||||
defer slowLog(fmt.Sprintf("Prepared addr: %s query(%s)", db.addr, query), time.Now()) |
||||
stmt = &Stmt{query: query, db: db} |
||||
s, err := db.Prepare(query) |
||||
if err == nil { |
||||
stmt.stmt.Store(s) |
||||
return |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", db.addr, query, args), now) |
||||
if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "query") |
||||
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:query", "breaker") |
||||
return |
||||
} |
||||
_, c, cancel := db.conf.QueryTimeout.Shrink(c) |
||||
rs, err := db.DB.QueryContext(c, query, args...) |
||||
db.onBreaker(&err) |
||||
stats.Timing("tidb:query", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", db.addr, query, args) |
||||
cancel() |
||||
return |
||||
} |
||||
rows = &Rows{Rows: rs, cancel: cancel} |
||||
return |
||||
} |
||||
|
||||
func (db *conn) queryRow(c context.Context, query string, args ...interface{}) *Row { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", db.addr, query, args), now) |
||||
t, ok := trace.FromContext(c) |
||||
if ok { |
||||
t = t.Fork(_family, "queryrow") |
||||
t.SetTag(trace.String(trace.TagAddress, db.addr), trace.String(trace.TagComment, query)) |
||||
} |
||||
if err := db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:queryrow", "breaker") |
||||
return &Row{db: db, t: t, err: err} |
||||
} |
||||
_, c, cancel := db.conf.QueryTimeout.Shrink(c) |
||||
r := db.DB.QueryRowContext(c, query, args...) |
||||
stats.Timing("tidb:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||
return &Row{db: db, Row: r, query: query, args: args, t: t, cancel: cancel} |
||||
} |
||||
|
||||
// Close closes the statement.
|
||||
func (s *Stmt) Close() (err error) { |
||||
stmt, ok := s.stmt.Load().(*sql.Stmt) |
||||
if ok { |
||||
err = errors.WithStack(stmt.Close()) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (s *Stmt) prepare() (st *sql.Stmt) { |
||||
var ok bool |
||||
if st, ok = s.stmt.Load().(*sql.Stmt); ok { |
||||
return |
||||
} |
||||
var err error |
||||
if st, err = s.db.Prepare(s.query); err == nil { |
||||
s.stmt.Store(st) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Exec executes a prepared statement with the given arguments and returns a
|
||||
// Result summarizing the effect of the statement.
|
||||
func (s *Stmt) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now) |
||||
if s.tx { |
||||
if s.t != nil { |
||||
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||
} |
||||
} else if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "exec") |
||||
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = s.db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:stmt:exec", "breaker") |
||||
return |
||||
} |
||||
stmt := s.prepare() |
||||
if stmt == nil { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
_, c, cancel := s.db.conf.ExecTimeout.Shrink(c) |
||||
res, err = stmt.ExecContext(c, args...) |
||||
cancel() |
||||
s.db.onBreaker(&err) |
||||
stats.Timing("tidb:stmt:exec", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", s.db.addr, s.query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Query executes a prepared query statement with the given arguments and
|
||||
// returns the query results as a *Rows.
|
||||
func (s *Stmt) Query(c context.Context, args ...interface{}) (rows *Rows, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now) |
||||
if s.tx { |
||||
if s.t != nil { |
||||
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||
} |
||||
} else if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "query") |
||||
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query)) |
||||
defer t.Finish(&err) |
||||
} |
||||
if err = s.db.breaker.Allow(); err != nil { |
||||
stats.Incr("tidb:stmt:query", "breaker") |
||||
return |
||||
} |
||||
stmt := s.prepare() |
||||
if stmt == nil { |
||||
err = ErrStmtNil |
||||
return |
||||
} |
||||
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c) |
||||
rs, err := stmt.QueryContext(c, args...) |
||||
s.db.onBreaker(&err) |
||||
stats.Timing("tidb:stmt:query", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", s.db.addr, s.query, args) |
||||
cancel() |
||||
return |
||||
} |
||||
rows = &Rows{Rows: rs, cancel: cancel} |
||||
return |
||||
} |
||||
|
||||
// QueryRow executes a prepared query statement with the given arguments.
|
||||
// If an error occurs during the execution of the statement, that error will
|
||||
// be returned by a call to Scan on the returned *Row, which is always non-nil.
|
||||
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
|
||||
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
|
||||
func (s *Stmt) QueryRow(c context.Context, args ...interface{}) (row *Row) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", s.db.addr, s.query, args), now) |
||||
row = &Row{db: s.db, query: s.query, args: args} |
||||
if s.tx { |
||||
if s.t != nil { |
||||
s.t.SetTag(trace.String(trace.TagAnnotation, s.query)) |
||||
} |
||||
} else if t, ok := trace.FromContext(c); ok { |
||||
t = t.Fork(_family, "queryrow") |
||||
t.SetTag(trace.String(trace.TagAddress, s.db.addr), trace.String(trace.TagComment, s.query)) |
||||
row.t = t |
||||
} |
||||
if row.err = s.db.breaker.Allow(); row.err != nil { |
||||
stats.Incr("tidb:stmt:queryrow", "breaker") |
||||
return |
||||
} |
||||
stmt := s.prepare() |
||||
if stmt == nil { |
||||
return |
||||
} |
||||
_, c, cancel := s.db.conf.QueryTimeout.Shrink(c) |
||||
row.Row = stmt.QueryRowContext(c, args...) |
||||
row.cancel = cancel |
||||
stats.Timing("tidb:stmt:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||
return |
||||
} |
||||
|
||||
func (s *Stmts) prepare(conn *conn) (st *Stmt) { |
||||
if conn == nil { |
||||
conn = s.db.conn() |
||||
} |
||||
s.mu.RLock() |
||||
st = s.sts[conn.addr] |
||||
s.mu.RUnlock() |
||||
if st == nil { |
||||
st = conn.prepared(s.query) |
||||
s.mu.Lock() |
||||
s.sts[conn.addr] = st |
||||
s.mu.Unlock() |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Exec executes a prepared statement with the given arguments and returns a
|
||||
// Result summarizing the effect of the statement.
|
||||
func (s *Stmts) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) { |
||||
return s.prepare(nil).Exec(c, args...) |
||||
} |
||||
|
||||
// Query executes a prepared query statement with the given arguments and
|
||||
// returns the query results as a *Rows.
|
||||
func (s *Stmts) Query(c context.Context, args ...interface{}) (rows *Rows, err error) { |
||||
return s.prepare(nil).Query(c, args...) |
||||
} |
||||
|
||||
// QueryRow executes a prepared query statement with the given arguments.
|
||||
// If an error occurs during the execution of the statement, that error will
|
||||
// be returned by a call to Scan on the returned *Row, which is always non-nil.
|
||||
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
|
||||
// Otherwise, the *Row's Scan scans the first selected row and discards the rest.
|
||||
func (s *Stmts) QueryRow(c context.Context, args ...interface{}) (row *Row) { |
||||
return s.prepare(nil).QueryRow(c, args...) |
||||
} |
||||
|
||||
// Close closes the statement.
|
||||
func (s *Stmts) Close() (err error) { |
||||
for _, st := range s.sts { |
||||
if err = errors.WithStack(st.Close()); err != nil { |
||||
return |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Commit commits the transaction.
|
||||
func (tx *Tx) Commit() (err error) { |
||||
err = tx.tx.Commit() |
||||
tx.cancel() |
||||
tx.db.onBreaker(&err) |
||||
if tx.t != nil { |
||||
tx.t.Finish(&err) |
||||
} |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Rollback aborts the transaction.
|
||||
func (tx *Tx) Rollback() (err error) { |
||||
err = tx.tx.Rollback() |
||||
tx.cancel() |
||||
tx.db.onBreaker(&err) |
||||
if tx.t != nil { |
||||
tx.t.Finish(&err) |
||||
} |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Exec executes a query that doesn't return rows. For example: an INSERT and
|
||||
// UPDATE.
|
||||
func (tx *Tx) Exec(query string, args ...interface{}) (res sql.Result, err error) { |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Exec addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now) |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("exec %s", query))) |
||||
} |
||||
res, err = tx.tx.ExecContext(tx.c, query, args...) |
||||
stats.Timing("tidb:tx:exec", int64(time.Since(now)/time.Millisecond)) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s exec:%s, args:%+v", tx.db.addr, query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Query executes a query that returns rows, typically a SELECT.
|
||||
func (tx *Tx) Query(query string, args ...interface{}) (rows *Rows, err error) { |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("query %s", query))) |
||||
} |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("Query addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now) |
||||
defer func() { |
||||
stats.Timing("tidb:tx:query", int64(time.Since(now)/time.Millisecond)) |
||||
}() |
||||
rs, err := tx.tx.QueryContext(tx.c, query, args...) |
||||
if err == nil { |
||||
rows = &Rows{Rows: rs} |
||||
} else { |
||||
err = errors.Wrapf(err, "addr: %s, query:%s, args:%+v", tx.db.addr, query, args) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// QueryRow executes a query that is expected to return at most one row.
|
||||
// QueryRow always returns a non-nil value. Errors are deferred until Row's
|
||||
// Scan method is called.
|
||||
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("queryrow %s", query))) |
||||
} |
||||
now := time.Now() |
||||
defer slowLog(fmt.Sprintf("QueryRow addr: %s query(%s) args(%+v)", tx.db.addr, query, args), now) |
||||
defer func() { |
||||
stats.Timing("tidb:tx:queryrow", int64(time.Since(now)/time.Millisecond)) |
||||
}() |
||||
r := tx.tx.QueryRowContext(tx.c, query, args...) |
||||
return &Row{Row: r, db: tx.db, query: query, args: args} |
||||
} |
||||
|
||||
// Stmt returns a transaction-specific prepared statement from an existing statement.
|
||||
func (tx *Tx) Stmt(stmt *Stmt) *Stmt { |
||||
if stmt == nil { |
||||
return nil |
||||
} |
||||
as, ok := stmt.stmt.Load().(*sql.Stmt) |
||||
if !ok { |
||||
return nil |
||||
} |
||||
ts := tx.tx.StmtContext(tx.c, as) |
||||
st := &Stmt{query: stmt.query, tx: true, t: tx.t, db: tx.db} |
||||
st.stmt.Store(ts) |
||||
return st |
||||
} |
||||
|
||||
// Stmts returns a transaction-specific prepared statement from an existing statement.
|
||||
func (tx *Tx) Stmts(stmt *Stmts) *Stmt { |
||||
return tx.Stmt(stmt.prepare(tx.db)) |
||||
} |
||||
|
||||
// Prepare creates a prepared statement for use within a transaction.
|
||||
// The returned statement operates within the transaction and can no longer be
|
||||
// used once the transaction has been committed or rolled back.
|
||||
// To use an existing prepared statement on this transaction, see Tx.Stmt.
|
||||
func (tx *Tx) Prepare(query string) (*Stmt, error) { |
||||
if tx.t != nil { |
||||
tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("prepare %s", query))) |
||||
} |
||||
defer slowLog(fmt.Sprintf("Prepare addr: %s query(%s)", tx.db.addr, query), time.Now()) |
||||
stmt, err := tx.tx.Prepare(query) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "addr: %s prepare %s", tx.db.addr, query) |
||||
return nil, err |
||||
} |
||||
st := &Stmt{query: query, tx: true, t: tx.t, db: tx.db} |
||||
st.stmt.Store(stmt) |
||||
return st, nil |
||||
} |
||||
|
||||
// parseDSNAddr parse dsn name and return addr.
|
||||
func parseDSNAddr(dsn string) (addr string) { |
||||
if dsn == "" { |
||||
return |
||||
} |
||||
cfg, err := mysql.ParseDSN(dsn) |
||||
if err != nil { |
||||
return |
||||
} |
||||
addr = cfg.Addr |
||||
return |
||||
} |
||||
|
||||
func genDSN(dsn, addr string) (res string) { |
||||
cfg, err := mysql.ParseDSN(dsn) |
||||
if err != nil { |
||||
return |
||||
} |
||||
cfg.Addr = addr |
||||
cfg.Net = "tcp" |
||||
res = cfg.FormatDSN() |
||||
return |
||||
} |
||||
|
||||
func slowLog(statement string, now time.Time) { |
||||
du := time.Since(now) |
||||
if du > _slowLogDuration { |
||||
log.Warn("%s slow log statement: %s time: %v", _family, statement, du) |
||||
} |
||||
} |
@ -0,0 +1,38 @@ |
||||
package tidb |
||||
|
||||
import ( |
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
"github.com/bilibili/Kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/Kratos/pkg/stat" |
||||
"github.com/bilibili/Kratos/pkg/time" |
||||
|
||||
// database driver
|
||||
_ "github.com/go-sql-driver/mysql" |
||||
) |
||||
|
||||
var stats = stat.DB |
||||
|
||||
// Config mysql config.
|
||||
type Config struct { |
||||
DSN string // dsn
|
||||
Active int // pool
|
||||
Idle int // pool
|
||||
IdleTimeout time.Duration // connect max life time.
|
||||
QueryTimeout time.Duration // query sql timeout
|
||||
ExecTimeout time.Duration // execute sql timeout
|
||||
TranTimeout time.Duration // transaction sql timeout
|
||||
Breaker *breaker.Config // breaker
|
||||
} |
||||
|
||||
// NewTiDB new db and retry connection when has error.
|
||||
func NewTiDB(c *Config) (db *DB) { |
||||
if c.QueryTimeout == 0 || c.ExecTimeout == 0 || c.TranTimeout == 0 { |
||||
panic("tidb must be set query/execute/transction timeout") |
||||
} |
||||
db, err := Open(c) |
||||
if err != nil { |
||||
log.Error("open tidb error(%v)", err) |
||||
panic(err) |
||||
} |
||||
return |
||||
} |
Loading…
Reference in new issue