From 76da31effb5ece597cff22e970816a5ddd7a7659 Mon Sep 17 00:00:00 2001 From: felixhao Date: Fri, 1 Feb 2019 15:57:43 +0800 Subject: [PATCH] add kratos pkg --- .travis.yml | 29 ++ README.md | 23 +- go.mod | 14 + go.sum | 53 +++ pkg/conf/dsn/README.md | 5 + pkg/conf/dsn/doc.go | 63 +++ pkg/conf/dsn/dsn.go | 106 +++++ pkg/conf/dsn/dsn_test.go | 79 ++++ pkg/conf/dsn/example_test.go | 31 ++ pkg/conf/dsn/query.go | 422 +++++++++++++++++++ pkg/conf/dsn/query_test.go | 128 ++++++ pkg/conf/env/README.md | 5 + pkg/conf/env/env.go | 68 ++++ pkg/conf/env/env_test.go | 104 +++++ pkg/container/pool/README.md | 5 + pkg/container/pool/list.go | 227 +++++++++++ pkg/container/pool/list_test.go | 322 +++++++++++++++ pkg/container/pool/pool.go | 62 +++ pkg/container/pool/slice.go | 423 +++++++++++++++++++ pkg/container/pool/slice_test.go | 350 ++++++++++++++++ pkg/container/queue/aqm/README.md | 5 + pkg/container/queue/aqm/codel.go | 185 +++++++++ pkg/container/queue/aqm/codel_test.go | 97 +++++ pkg/ecode/common_ecode.go | 19 + pkg/ecode/ecode.go | 120 ++++++ pkg/naming/README.md | 5 + pkg/naming/naming.go | 58 +++ pkg/net/ip/ip.go | 66 +++ pkg/net/metadata/README.md | 5 + pkg/net/metadata/key.go | 39 ++ pkg/net/metadata/metadata.go | 134 +++++++ pkg/net/metadata/metadata_test.go | 96 +++++ pkg/net/trace/README.md | 20 + pkg/net/trace/config.go | 75 ++++ pkg/net/trace/config_test.go | 33 ++ pkg/net/trace/const.go | 7 + pkg/net/trace/context.go | 110 +++++ pkg/net/trace/context_test.go | 26 ++ pkg/net/trace/dapper.go | 189 +++++++++ pkg/net/trace/dapper_test.go | 136 +++++++ pkg/net/trace/marshal.go | 106 +++++ pkg/net/trace/marshal_test.go | 18 + pkg/net/trace/noop.go | 45 +++ pkg/net/trace/option.go | 17 + pkg/net/trace/propagation.go | 177 ++++++++ pkg/net/trace/proto/span.pb.go | 557 ++++++++++++++++++++++++++ pkg/net/trace/proto/span.proto | 77 ++++ pkg/net/trace/report.go | 138 +++++++ pkg/net/trace/report_test.go | 88 ++++ pkg/net/trace/sample.go | 67 ++++ 
pkg/net/trace/sample_test.go | 35 ++ pkg/net/trace/span.go | 108 +++++ pkg/net/trace/span_test.go | 108 +++++ pkg/net/trace/tag.go | 182 +++++++++ pkg/net/trace/tag_test.go | 1 + pkg/net/trace/tracer.go | 92 +++++ pkg/net/trace/util.go | 68 ++++ pkg/net/trace/util_test.go | 21 + pkg/stat/README.md | 5 + pkg/stat/prom/README.md | 5 + pkg/stat/prom/prometheus.go | 163 ++++++++ pkg/stat/stat.go | 25 ++ pkg/stat/summary/README.md | 5 + pkg/stat/summary/summary.go | 129 ++++++ pkg/stat/summary/summary_test.go | 69 ++++ pkg/stat/sys/cpu/README.md | 7 + pkg/stat/sys/cpu/cgroup.go | 125 ++++++ pkg/stat/sys/cpu/cgroup_test.go | 11 + pkg/stat/sys/cpu/cpu.go | 107 +++++ pkg/stat/sys/cpu/cpu_darwin.go | 20 + pkg/stat/sys/cpu/cpu_linux.go | 147 +++++++ pkg/stat/sys/cpu/cpu_other.go | 11 + pkg/stat/sys/cpu/stat_test.go | 20 + pkg/stat/sys/cpu/sysconfig_notcgo.go | 14 + pkg/stat/sys/cpu/util.go | 121 ++++++ pkg/str/str.go | 55 +++ pkg/str/str_test.go | 60 +++ pkg/sync/errgroup/README.md | 3 + pkg/sync/errgroup/doc.go | 47 +++ pkg/sync/errgroup/errgroup.go | 119 ++++++ pkg/sync/errgroup/errgroup_test.go | 266 ++++++++++++ pkg/sync/errgroup/example_test.go | 63 +++ pkg/time/README.md | 5 + pkg/time/time.go | 59 +++ pkg/time/time_test.go | 60 +++ 85 files changed, 7569 insertions(+), 1 deletion(-) create mode 100644 .travis.yml create mode 100644 go.mod create mode 100644 go.sum create mode 100644 pkg/conf/dsn/README.md create mode 100644 pkg/conf/dsn/doc.go create mode 100644 pkg/conf/dsn/dsn.go create mode 100644 pkg/conf/dsn/dsn_test.go create mode 100644 pkg/conf/dsn/example_test.go create mode 100644 pkg/conf/dsn/query.go create mode 100644 pkg/conf/dsn/query_test.go create mode 100644 pkg/conf/env/README.md create mode 100644 pkg/conf/env/env.go create mode 100644 pkg/conf/env/env_test.go create mode 100644 pkg/container/pool/README.md create mode 100644 pkg/container/pool/list.go create mode 100644 pkg/container/pool/list_test.go create mode 100644 pkg/container/pool/pool.go 
create mode 100644 pkg/container/pool/slice.go create mode 100644 pkg/container/pool/slice_test.go create mode 100644 pkg/container/queue/aqm/README.md create mode 100644 pkg/container/queue/aqm/codel.go create mode 100644 pkg/container/queue/aqm/codel_test.go create mode 100644 pkg/ecode/common_ecode.go create mode 100644 pkg/ecode/ecode.go create mode 100644 pkg/naming/README.md create mode 100644 pkg/naming/naming.go create mode 100644 pkg/net/ip/ip.go create mode 100644 pkg/net/metadata/README.md create mode 100644 pkg/net/metadata/key.go create mode 100644 pkg/net/metadata/metadata.go create mode 100644 pkg/net/metadata/metadata_test.go create mode 100644 pkg/net/trace/README.md create mode 100644 pkg/net/trace/config.go create mode 100644 pkg/net/trace/config_test.go create mode 100644 pkg/net/trace/const.go create mode 100644 pkg/net/trace/context.go create mode 100644 pkg/net/trace/context_test.go create mode 100644 pkg/net/trace/dapper.go create mode 100644 pkg/net/trace/dapper_test.go create mode 100644 pkg/net/trace/marshal.go create mode 100644 pkg/net/trace/marshal_test.go create mode 100644 pkg/net/trace/noop.go create mode 100644 pkg/net/trace/option.go create mode 100644 pkg/net/trace/propagation.go create mode 100644 pkg/net/trace/proto/span.pb.go create mode 100644 pkg/net/trace/proto/span.proto create mode 100644 pkg/net/trace/report.go create mode 100644 pkg/net/trace/report_test.go create mode 100644 pkg/net/trace/sample.go create mode 100644 pkg/net/trace/sample_test.go create mode 100644 pkg/net/trace/span.go create mode 100644 pkg/net/trace/span_test.go create mode 100644 pkg/net/trace/tag.go create mode 100644 pkg/net/trace/tag_test.go create mode 100644 pkg/net/trace/tracer.go create mode 100644 pkg/net/trace/util.go create mode 100644 pkg/net/trace/util_test.go create mode 100644 pkg/stat/README.md create mode 100644 pkg/stat/prom/README.md create mode 100644 pkg/stat/prom/prometheus.go create mode 100644 pkg/stat/stat.go create mode 
100644 pkg/stat/summary/README.md create mode 100644 pkg/stat/summary/summary.go create mode 100644 pkg/stat/summary/summary_test.go create mode 100644 pkg/stat/sys/cpu/README.md create mode 100644 pkg/stat/sys/cpu/cgroup.go create mode 100644 pkg/stat/sys/cpu/cgroup_test.go create mode 100644 pkg/stat/sys/cpu/cpu.go create mode 100644 pkg/stat/sys/cpu/cpu_darwin.go create mode 100644 pkg/stat/sys/cpu/cpu_linux.go create mode 100644 pkg/stat/sys/cpu/cpu_other.go create mode 100644 pkg/stat/sys/cpu/stat_test.go create mode 100644 pkg/stat/sys/cpu/sysconfig_notcgo.go create mode 100644 pkg/stat/sys/cpu/util.go create mode 100644 pkg/str/str.go create mode 100644 pkg/str/str_test.go create mode 100644 pkg/sync/errgroup/README.md create mode 100644 pkg/sync/errgroup/doc.go create mode 100644 pkg/sync/errgroup/errgroup.go create mode 100644 pkg/sync/errgroup/errgroup_test.go create mode 100644 pkg/sync/errgroup/example_test.go create mode 100644 pkg/time/README.md create mode 100644 pkg/time/time.go create mode 100644 pkg/time/time_test.go diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 000000000..eb83d1d9d --- /dev/null +++ b/.travis.yml @@ -0,0 +1,29 @@ +language: go + +go: + - 1.11.x + +# Only clone the most recent commit. +git: + depth: 1 + +# Force-enable Go modules. This will be unnecessary when Go 1.12 lands. +env: + - GO111MODULE=on + +# Skip the install step. Don't `go get` dependencies. Only build with the code +# in vendor/ +install: true + +# Anything in before_script that returns a nonzero exit code will flunk the +# build and immediately stop. It's sorta like having set -e enabled in bash. +# Make sure golangci-lint is vendored. +before_script: + - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $GOPATH/bin + +script: + - go test ./... + - go build ./... 
+ +after_success: + - golangci-lint run # run a bunch of code checkers/linters in parallel diff --git a/README.md b/README.md index a9df5240d..405ba3c7b 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,23 @@ # Kratos -Hello "war"ld! + +Kratos是[bilibili](https://www.bilibili.com)开源的一套Go微服务框架,基于“大仓(monorepo)”理念,包含大量微服务相关框架及工具,如:discovery(服务注册发现)、blademaster(HTTP框架)、warden(gRPC封装)、log、breaker、dapper(trace)、cache&db sdk、kratos(代码生成等工具)等等。我们致力于提供完整的微服务研发体验,大仓整合相关框架及工具后可对研发者无感,整体研发过程可聚焦于业务交付。对开发者而言,整套Kratos框架也是不错的学习仓库,可以学到[bilibili](https://www.bilibili.com)在微服务方面的技术积累和经验。 + + +# TODOs + +- [ ] log&log-agent @围城 +- [ ] config @志辉 +- [ ] bm @佳辉 +- [ ] warden @龙虾 +- [ ] naming discovery @堂辉 +- [ ] cache&database @小旭 +- [ ] kratos tool @普余 + +# issues + +***有权限后加到issue里*** + +1. 需要考虑配置中心开源方式:类discovery单独 或 集成在大仓库 +2. log-agent和dapper需要完整的解决方案,包含ES集群、dapperUI +3. databus是否需要开源 +4. proto文件相关生成工具正和到kratos工具内 diff --git a/go.mod b/go.mod new file mode 100644 index 000000000..1d34460e4 --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module github.com/bilibili/Kratos + +require ( + github.com/go-playground/locales v0.12.1 // indirect + github.com/go-playground/universal-translator v0.16.0 // indirect + github.com/golang/protobuf v1.2.0 + github.com/leodido/go-urn v1.1.0 // indirect + github.com/pkg/errors v0.8.1 + github.com/prometheus/client_golang v0.9.2 + github.com/stretchr/testify v1.3.0 + google.golang.org/grpc v1.18.0 + gopkg.in/go-playground/assert.v1 v1.2.1 // indirect + gopkg.in/go-playground/validator.v9 v9.26.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 000000000..5c95ccf7b --- /dev/null +++ b/go.sum @@ -0,0 +1,53 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/client9/misspell 
v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-playground/locales v0.12.1 h1:2FITxuFt/xuCNP1Acdhv62OzaCiviiE4kotfhkmOqEc= +github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM= +github.com/go-playground/universal-translator v0.16.0 h1:X++omBR/4cE2MNg91AoC3rmGrCjJ8eAeUP/K/EKx4DM= +github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/leodido/go-urn v1.1.0 h1:Sm1gr51B1kKyfD2BlRcLSiEkffoG96g6TPv6eRoEiB8= +github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= +github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.9.2 h1:awm861/B8OKDd2I/6o1dy3ra4BamzKhYOiGItCeZ740= +github.com/prometheus/client_golang v0.9.2/go.mod 
h1:OsXs2jCmiKlQ1lTBmv21f2mNfw4xf/QclQDMrYNZzcM= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 h1:idejC8f05m9MGOsuEi1ATq9shN03HrxNkD/luQvxCv8= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/common v0.0.0-20181126121408-4724e9255275 h1:PnBWHBf+6L0jOqq0gIVUe6Yk0/QMZ640k6NvkxcBf+8= +github.com/prometheus/common v0.0.0-20181126121408-4724e9255275/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a h1:9a8MnZMP0X2nLJdBg+pBmGgkJlSaKC2KaQmTCk1XDtE= +github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools 
v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/grpc v1.18.0 h1:IZl7mfBGfbhYx2p2rKRtYgDFw6SBz+kclmxYrCksPPA= +google.golang.org/grpc v1.18.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM= +gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= +gopkg.in/go-playground/validator.v9 v9.26.0 h1:2NPPsBpD0ZoxshmLWewQru8rWmbT5JqSzz9D1ZrAjYQ= +gopkg.in/go-playground/validator.v9 v9.26.0/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/pkg/conf/dsn/README.md b/pkg/conf/dsn/README.md new file mode 100644 index 000000000..57b23b585 --- /dev/null +++ b/pkg/conf/dsn/README.md @@ -0,0 +1,5 @@ +# dsn + +## 项目简介 + +通用数据源地址解析 diff --git a/pkg/conf/dsn/doc.go b/pkg/conf/dsn/doc.go new file mode 100644 index 000000000..d93013d87 --- /dev/null +++ b/pkg/conf/dsn/doc.go @@ -0,0 +1,63 @@ +// Package dsn implements dsn parse with struct bind +/* +DSN 格式类似 URI, DSN 结构如下图 + + network:[//[username[:password]@]address[:port][,address[:port]]][/path][?query][#fragment] + +与 URI 的主要区别在于 scheme 被替换为 network, host 被替换为 address 并且支持多个 address. 
+network 与 net 包中 network 意义相同, tcp、udp、unix 等, address 支持多个使用 ',' 分割, 如果 +network 为 unix 等本地 sock 协议则使用 Path, 有且只有一个 + +dsn 包主要提供了 Parse, Bind 和 validate 功能 + +Parse 解析 dsn 字符串成 DSN struct, DSN struct 与 url.URL 几乎完全一样 + +Bind 提供将 DSN 数据绑定到一个 struct 的功能, 通过 tag dsn:"key,[default]" 指定绑定的字段, 目前支持两种类型的数据绑定 + +内置变量 key: + network string tcp, udp, unix 等, 参考 net 包中的 network + username string + password string + address string or []string address 可以绑定到 string 或者 []string, 如果为 string 则取 address 第一个 + +Query: 通过 query.name 可以取到 query 上的数据 + + 数组可以通过传递多个获得 + + array=1&array=2&array3 -> []int `tag:"query.array"` + + struct 支持嵌套 + + foo.sub.name=hello&foo.tm=hello + + struct Foo { + Tm string `dsn:"query.tm"` + Sub struct { + Name string `dsn:"query.name"` + } `dsn:"query.sub"` + } + +默认值: 通过 dsn:"key,[default]" 默认值暂时不支持数组 + +忽略 Bind: 通过 dsn:"-" 忽略 Bind + +自定义 Bind: 可以同时实现 encoding.TextUnmarshaler 自定义 Bind 实现 + +Validate: 参考 https://github.com/go-playground/validator + +使用参考: example_test.go + +DSN 命名规范: + +没有历史遗留的情况下,尽量使用 Address, Network, Username, Password 等命名,代替之前的 Proto 和 Addr 等命名 + +Query 命名参考, 使用驼峰小写开头: + + timeout 通用超时 + dialTimeout 连接建立超时 + readTimeout 读操作超时 + writeTimeout 写操作超时 + readsTimeout 批量读超时 + writesTimeout 批量写超时 +*/ +package dsn diff --git a/pkg/conf/dsn/dsn.go b/pkg/conf/dsn/dsn.go new file mode 100644 index 000000000..7cb209e2b --- /dev/null +++ b/pkg/conf/dsn/dsn.go @@ -0,0 +1,106 @@ +package dsn + +import ( + "net/url" + "reflect" + "strings" + + validator "gopkg.in/go-playground/validator.v9" +) + +var _validator *validator.Validate + +func init() { + _validator = validator.New() +} + +// DSN a DSN represents a parsed DSN as same as url.URL. +type DSN struct { + *url.URL +} + +// Bind dsn to specify struct and validate use use go-playground/validator format +// +// The bind of each struct field can be customized by the format string +// stored under the 'dsn' key in the struct field's tag. 
The format string +// gives the name of the field, possibly followed by a comma-separated +// list of options. The name may be empty in order to specify options +// without overriding the default field name. +// +// A two type data you can bind to struct +// built-in values, use below keys to bind built-in value +// username +// password +// address +// network +// the value in query string, use query.{name} to bind value in query string +// +// As a special case, if the field tag is "-", the field is always omitted. +// NOTE: that a field with name "-" can still be generated using the tag "-,". +// +// Examples of struct field tags and their meanings: +// // Field bind username +// Field string `dsn:"username"` +// // Field is ignored by this package. +// Field string `dsn:"-"` +// // Field bind value from query +// Field string `dsn:"query.name"` +// +func (d *DSN) Bind(v interface{}) (url.Values, error) { + assignFuncs := make(map[string]assignFunc) + if d.User != nil { + username := d.User.Username() + password, ok := d.User.Password() + if ok { + assignFuncs["password"] = stringsAssignFunc(password) + } + assignFuncs["username"] = stringsAssignFunc(username) + } + assignFuncs["address"] = addressesAssignFunc(d.Addresses()) + assignFuncs["network"] = stringsAssignFunc(d.Scheme) + query, err := bindQuery(d.Query(), v, assignFuncs) + if err != nil { + return nil, err + } + return query, _validator.Struct(v) +} + +func addressesAssignFunc(addresses []string) assignFunc { + return func(v reflect.Value, to tagOpt) error { + if v.Kind() == reflect.String { + if addresses[0] == "" && to.Default != "" { + v.SetString(to.Default) + } else { + v.SetString(addresses[0]) + } + return nil + } + if !(v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.String) { + return &BindTypeError{Value: strings.Join(addresses, ","), Type: v.Type()} + } + vals := reflect.MakeSlice(v.Type(), len(addresses), len(addresses)) + for i, address := range addresses { + 
vals.Index(i).SetString(address) + } + if v.CanSet() { + v.Set(vals) + } + return nil + } +} + +// Addresses parse host split by ',' +// For Unix networks, return ['path'] +func (d *DSN) Addresses() []string { + switch d.Scheme { + case "unix", "unixgram", "unixpacket": + return []string{d.Path} + } + return strings.Split(d.Host, ",") +} + +// Parse parses rawdsn into a URL structure. +func Parse(rawdsn string) (*DSN, error) { + u, err := url.Parse(rawdsn) + return &DSN{URL: u}, err +} diff --git a/pkg/conf/dsn/dsn_test.go b/pkg/conf/dsn/dsn_test.go new file mode 100644 index 000000000..815ead818 --- /dev/null +++ b/pkg/conf/dsn/dsn_test.go @@ -0,0 +1,79 @@ +package dsn + +import ( + "net/url" + "reflect" + "testing" + "time" + + xtime "github.com/bilibili/Kratos/pkg/time" +) + +type config struct { + Network string `dsn:"network"` + Addresses []string `dsn:"address"` + Username string `dsn:"username"` + Password string `dsn:"password"` + Timeout xtime.Duration `dsn:"query.timeout"` + Sub Sub `dsn:"query.sub"` + Def string `dsn:"query.def,hello"` +} + +type Sub struct { + Foo int `dsn:"query.foo"` +} + +func TestBind(t *testing.T) { + var cfg config + rawdsn := "tcp://root:toor@172.12.23.34,178.23.34.45?timeout=1s&sub.foo=1&hello=world" + dsn, err := Parse(rawdsn) + if err != nil { + t.Fatal(err) + } + values, err := dsn.Bind(&cfg) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(values, url.Values{"hello": {"world"}}) { + t.Errorf("unexpect values get %v", values) + } + cfg2 := config{ + Network: "tcp", + Addresses: []string{"172.12.23.34", "178.23.34.45"}, + Password: "toor", + Username: "root", + Sub: Sub{Foo: 1}, + Timeout: xtime.Duration(time.Second), + Def: "hello", + } + if !reflect.DeepEqual(cfg, cfg2) { + t.Errorf("unexpect config get %v, expect %v", cfg, cfg2) + } +} + +type config2 struct { + Network string `dsn:"network"` + Address string `dsn:"address"` + Timeout xtime.Duration `dsn:"query.timeout"` +} + +func TestUnix(t *testing.T) { + var 
cfg config2 + rawdsn := "unix:///run/xxx.sock?timeout=1s&sub.foo=1&hello=world" + dsn, err := Parse(rawdsn) + if err != nil { + t.Fatal(err) + } + _, err = dsn.Bind(&cfg) + if err != nil { + t.Error(err) + } + cfg2 := config2{ + Network: "unix", + Address: "/run/xxx.sock", + Timeout: xtime.Duration(time.Second), + } + if !reflect.DeepEqual(cfg, cfg2) { + t.Errorf("unexpect config2 get %v, expect %v", cfg, cfg2) + } +} diff --git a/pkg/conf/dsn/example_test.go b/pkg/conf/dsn/example_test.go new file mode 100644 index 000000000..ed81e4204 --- /dev/null +++ b/pkg/conf/dsn/example_test.go @@ -0,0 +1,31 @@ +package dsn_test + +import ( + "log" + + "github.com/bilibili/Kratos/pkg/conf/dsn" + xtime "github.com/bilibili/Kratos/pkg/time" +) + +// Config struct +type Config struct { + Network string `dsn:"network" validate:"required"` + Host string `dsn:"host" validate:"required"` + Username string `dsn:"username" validate:"required"` + Password string `dsn:"password" validate:"required"` + Timeout xtime.Duration `dsn:"query.timeout,1s"` + Offset int `dsn:"query.offset" validate:"gte=0"` +} + +func ExampleParse() { + cfg := &Config{} + d, err := dsn.Parse("tcp://root:toor@172.12.12.23:2233?timeout=10s") + if err != nil { + log.Fatal(err) + } + _, err = d.Bind(cfg) + if err != nil { + log.Fatal(err) + } + log.Printf("%v", cfg) +} diff --git a/pkg/conf/dsn/query.go b/pkg/conf/dsn/query.go new file mode 100644 index 000000000..f622d0b07 --- /dev/null +++ b/pkg/conf/dsn/query.go @@ -0,0 +1,422 @@ +package dsn + +import ( + "encoding" + "net/url" + "reflect" + "runtime" + "strconv" + "strings" +) + +const ( + _tagID = "dsn" + _queryPrefix = "query." +) + +// InvalidBindError describes an invalid argument passed to DecodeQuery. +// (The argument to DecodeQuery must be a non-nil pointer.) 
+type InvalidBindError struct { + Type reflect.Type +} + +func (e *InvalidBindError) Error() string { + if e.Type == nil { + return "Bind(nil)" + } + + if e.Type.Kind() != reflect.Ptr { + return "Bind(non-pointer " + e.Type.String() + ")" + } + return "Bind(nil " + e.Type.String() + ")" +} + +// BindTypeError describes a query value that was +// not appropriate for a value of a specific Go type. +type BindTypeError struct { + Value string + Type reflect.Type +} + +func (e *BindTypeError) Error() string { + return "cannot decode " + e.Value + " into Go value of type " + e.Type.String() +} + +type assignFunc func(v reflect.Value, to tagOpt) error + +func stringsAssignFunc(val string) assignFunc { + return func(v reflect.Value, to tagOpt) error { + if v.Kind() != reflect.String || !v.CanSet() { + return &BindTypeError{Value: "string", Type: v.Type()} + } + if val == "" { + v.SetString(to.Default) + } else { + v.SetString(val) + } + return nil + } +} + +// bindQuery parses url.Values and stores the result in the value pointed to by v. 
+// if v is nil or not a pointer, bindQuery returns an InvalidBindError +func bindQuery(query url.Values, v interface{}, assignFuncs map[string]assignFunc) (url.Values, error) { + if assignFuncs == nil { + assignFuncs = make(map[string]assignFunc) + } + d := decodeState{ + data: query, + used: make(map[string]bool), + assignFuncs: assignFuncs, + } + err := d.decode(v) + ret := d.unused() + return ret, err +} + +type tagOpt struct { + Name string + Default string +} + +func parseTag(tag string) tagOpt { + vs := strings.SplitN(tag, ",", 2) + if len(vs) == 2 { + return tagOpt{Name: vs[0], Default: vs[1]} + } + return tagOpt{Name: vs[0]} +} + +type decodeState struct { + data url.Values + used map[string]bool + assignFuncs map[string]assignFunc +} + +func (d *decodeState) unused() url.Values { + ret := make(url.Values) + for k, v := range d.data { + if !d.used[k] { + ret[k] = v + } + } + return ret +} + +func (d *decodeState) decode(v interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(error) + } + }() + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return &InvalidBindError{reflect.TypeOf(v)} + } + return d.root(rv) +} + +func (d *decodeState) root(v reflect.Value) error { + var tu encoding.TextUnmarshaler + tu, v = d.indirect(v) + if tu != nil { + return tu.UnmarshalText([]byte(d.data.Encode())) + } + // TODO support map, slice as root + if v.Kind() != reflect.Struct { + return &BindTypeError{Value: d.data.Encode(), Type: v.Type()} + } + tv := v.Type() + for i := 0; i < tv.NumField(); i++ { + fv := v.Field(i) + field := tv.Field(i) + to := parseTag(field.Tag.Get(_tagID)) + if to.Name == "-" { + continue + } + if af, ok := d.assignFuncs[to.Name]; ok { + if err := af(fv, tagOpt{}); err != nil { + return err + } + continue + } + if !strings.HasPrefix(to.Name, _queryPrefix) { + continue + } + to.Name = to.Name[len(_queryPrefix):] + if err := d.value(fv, "", 
to); err != nil { + return err + } + } + return nil +} + +func combinekey(prefix string, to tagOpt) string { + key := to.Name + if prefix != "" { + key = prefix + "." + key + } + return key +} + +func (d *decodeState) value(v reflect.Value, prefix string, to tagOpt) (err error) { + key := combinekey(prefix, to) + d.used[key] = true + var tu encoding.TextUnmarshaler + tu, v = d.indirect(v) + if tu != nil { + if val, ok := d.data[key]; ok { + return tu.UnmarshalText([]byte(val[0])) + } + if to.Default != "" { + return tu.UnmarshalText([]byte(to.Default)) + } + return + } + switch v.Kind() { + case reflect.Bool: + err = d.valueBool(v, prefix, to) + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: + err = d.valueInt64(v, prefix, to) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + err = d.valueUint64(v, prefix, to) + case reflect.Float32, reflect.Float64: + err = d.valueFloat64(v, prefix, to) + case reflect.String: + err = d.valueString(v, prefix, to) + case reflect.Slice: + err = d.valueSlice(v, prefix, to) + case reflect.Struct: + err = d.valueStruct(v, prefix, to) + case reflect.Ptr: + if !d.hasKey(combinekey(prefix, to)) { + break + } + if !v.CanSet() { + break + } + nv := reflect.New(v.Type().Elem()) + v.Set(nv) + err = d.value(nv, prefix, to) + } + return +} + +func (d *decodeState) hasKey(key string) bool { + for k := range d.data { + if strings.HasPrefix(k, key+".") || k == key { + return true + } + } + return false +} + +func (d *decodeState) valueBool(v reflect.Value, prefix string, to tagOpt) error { + key := combinekey(prefix, to) + val := d.data.Get(key) + if val == "" { + if to.Default == "" { + return nil + } + val = to.Default + } + return d.setBool(v, val) +} + +func (d *decodeState) setBool(v reflect.Value, val string) error { + bval, err := strconv.ParseBool(val) + if err != nil { + return &BindTypeError{Value: val, Type: v.Type()} + } + v.SetBool(bval) + return nil +} + +func (d 
*decodeState) valueInt64(v reflect.Value, prefix string, to tagOpt) error { + key := combinekey(prefix, to) + val := d.data.Get(key) + if val == "" { + if to.Default == "" { + return nil + } + val = to.Default + } + return d.setInt64(v, val) +} + +func (d *decodeState) setInt64(v reflect.Value, val string) error { + ival, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return &BindTypeError{Value: val, Type: v.Type()} + } + v.SetInt(ival) + return nil +} + +func (d *decodeState) valueUint64(v reflect.Value, prefix string, to tagOpt) error { + key := combinekey(prefix, to) + val := d.data.Get(key) + if val == "" { + if to.Default == "" { + return nil + } + val = to.Default + } + return d.setUint64(v, val) +} + +func (d *decodeState) setUint64(v reflect.Value, val string) error { + uival, err := strconv.ParseUint(val, 10, 64) + if err != nil { + return &BindTypeError{Value: val, Type: v.Type()} + } + v.SetUint(uival) + return nil +} + +func (d *decodeState) valueFloat64(v reflect.Value, prefix string, to tagOpt) error { + key := combinekey(prefix, to) + val := d.data.Get(key) + if val == "" { + if to.Default == "" { + return nil + } + val = to.Default + } + return d.setFloat64(v, val) +} + +func (d *decodeState) setFloat64(v reflect.Value, val string) error { + fval, err := strconv.ParseFloat(val, 64) + if err != nil { + return &BindTypeError{Value: val, Type: v.Type()} + } + v.SetFloat(fval) + return nil +} + +func (d *decodeState) valueString(v reflect.Value, prefix string, to tagOpt) error { + key := combinekey(prefix, to) + val := d.data.Get(key) + if val == "" { + if to.Default == "" { + return nil + } + val = to.Default + } + return d.setString(v, val) +} + +func (d *decodeState) setString(v reflect.Value, val string) error { + v.SetString(val) + return nil +} + +func (d *decodeState) valueSlice(v reflect.Value, prefix string, to tagOpt) error { + key := combinekey(prefix, to) + strs, ok := d.data[key] + if !ok { + strs = strings.Split(to.Default, ",") 
+ } + if len(strs) == 0 { + return nil + } + et := v.Type().Elem() + var setFunc func(reflect.Value, string) error + switch et.Kind() { + case reflect.Bool: + setFunc = d.setBool + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: + setFunc = d.setInt64 + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + setFunc = d.setUint64 + case reflect.Float32, reflect.Float64: + setFunc = d.setFloat64 + case reflect.String: + setFunc = d.setString + default: + return &BindTypeError{Type: et, Value: strs[0]} + } + vals := reflect.MakeSlice(v.Type(), len(strs), len(strs)) + for i, str := range strs { + if err := setFunc(vals.Index(i), str); err != nil { + return err + } + } + if v.CanSet() { + v.Set(vals) + } + return nil +} + +func (d *decodeState) valueStruct(v reflect.Value, prefix string, to tagOpt) error { + tv := v.Type() + for i := 0; i < tv.NumField(); i++ { + fv := v.Field(i) + field := tv.Field(i) + fto := parseTag(field.Tag.Get(_tagID)) + if fto.Name == "-" { + continue + } + if af, ok := d.assignFuncs[fto.Name]; ok { + if err := af(fv, tagOpt{}); err != nil { + return err + } + continue + } + if !strings.HasPrefix(fto.Name, _queryPrefix) { + continue + } + fto.Name = fto.Name[len(_queryPrefix):] + if err := d.value(fv, to.Name, fto); err != nil { + return err + } + } + return nil +} + +func (d *decodeState) indirect(v reflect.Value) (encoding.TextUnmarshaler, reflect.Value) { + v0 := v + haveAddr := false + + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true + v = v.Addr() + } + for { + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() && e.Elem().Kind() == reflect.Ptr { + haveAddr = false + v = e + continue + } + } + + if v.Kind() != reflect.Ptr { + break + } + + if v.Elem().Kind() != reflect.Ptr && v.CanSet() { + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 
{ + if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { + return u, reflect.Value{} + } + } + if haveAddr { + v = v0 + haveAddr = false + } else { + v = v.Elem() + } + } + return nil, v +} diff --git a/pkg/conf/dsn/query_test.go b/pkg/conf/dsn/query_test.go new file mode 100644 index 000000000..0f7ea15c1 --- /dev/null +++ b/pkg/conf/dsn/query_test.go @@ -0,0 +1,128 @@ +package dsn + +import ( + "net/url" + "reflect" + "testing" + "time" + + xtime "github.com/bilibili/Kratos/pkg/time" +) + +type cfg1 struct { + Name string `dsn:"query.name"` + Def string `dsn:"query.def,hello"` + DefSlice []int `dsn:"query.defslice,1,2,3,4"` + Ignore string `dsn:"-"` + FloatNum float64 `dsn:"query.floatNum"` +} + +type cfg2 struct { + Timeout xtime.Duration `dsn:"query.timeout"` +} + +type cfg3 struct { + Username string `dsn:"username"` + Timeout xtime.Duration `dsn:"query.timeout"` +} + +type cfg4 struct { + Timeout xtime.Duration `dsn:"query.timeout,1s"` +} + +func TestDecodeQuery(t *testing.T) { + type args struct { + query url.Values + v interface{} + assignFuncs map[string]assignFunc + } + tests := []struct { + name string + args args + want url.Values + cfg interface{} + wantErr bool + }{ + { + name: "test generic", + args: args{ + query: url.Values{ + "name": {"hello"}, + "Ignore": {"test"}, + "floatNum": {"22.33"}, + "adb": {"123"}, + }, + v: &cfg1{}, + }, + want: url.Values{ + "Ignore": {"test"}, + "adb": {"123"}, + }, + cfg: &cfg1{ + Name: "hello", + Def: "hello", + DefSlice: []int{1, 2, 3, 4}, + FloatNum: 22.33, + }, + }, + { + name: "test github.com/bilibili/Kratos/pkg/time", + args: args{ + query: url.Values{ + "timeout": {"1s"}, + }, + v: &cfg2{}, + }, + want: url.Values{}, + cfg: &cfg2{xtime.Duration(time.Second)}, + }, + { + name: "test empty github.com/bilibili/Kratos/pkg/time", + args: args{ + query: url.Values{}, + v: &cfg2{}, + }, + want: url.Values{}, + cfg: &cfg2{}, + }, + { + name: "test github.com/bilibili/Kratos/pkg/time", + args: args{ + query: 
url.Values{}, + v: &cfg4{}, + }, + want: url.Values{}, + cfg: &cfg4{xtime.Duration(time.Second)}, + }, + { + name: "test build-in value", + args: args{ + query: url.Values{ + "timeout": {"1s"}, + }, + v: &cfg3{}, + assignFuncs: map[string]assignFunc{"username": stringsAssignFunc("hello")}, + }, + want: url.Values{}, + cfg: &cfg3{ + Timeout: xtime.Duration(time.Second), + Username: "hello", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := bindQuery(tt.args.query, tt.args.v, tt.args.assignFuncs) + if (err != nil) != tt.wantErr { + t.Errorf("DecodeQuery() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("DecodeQuery() = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(tt.args.v, tt.cfg) { + t.Errorf("DecodeQuery() = %v, want %v", tt.args.v, tt.cfg) + } + }) + } +} diff --git a/pkg/conf/env/README.md b/pkg/conf/env/README.md new file mode 100644 index 000000000..5afa4ba9b --- /dev/null +++ b/pkg/conf/env/README.md @@ -0,0 +1,5 @@ +# env + +## 项目简介 + +全局公用环境变量 diff --git a/pkg/conf/env/env.go b/pkg/conf/env/env.go new file mode 100644 index 000000000..bca8d81b6 --- /dev/null +++ b/pkg/conf/env/env.go @@ -0,0 +1,68 @@ +// Package env get env & app config, all the public field must after init() +// finished and flag.Parse(). +package env + +import ( + "flag" + "os" +) + +// deploy env. +const ( + DeployEnvDev = "dev" + DeployEnvFat = "fat" + DeployEnvUat = "uat" + DeployEnvPre = "pre" + DeployEnvProd = "prod" +) + +// env default value. +const ( + // env + _region = "region01" + _zone = "zone01" + _deployEnv = "dev" +) + +// env configuration. +var ( + // Region avaliable region where app at. + Region string + // Zone avaliable zone where app at. + Zone string + // Hostname machine hostname. + Hostname string + // DeployEnv deploy env where app at. + DeployEnv string + // AppID is global unique application id, register by service tree. 
+ // such as main.arch.disocvery. + AppID string + // Color is the identification of different experimental group in one caster cluster. + Color string +) + +func init() { + var err error + if Hostname, err = os.Hostname(); err != nil || Hostname == "" { + Hostname = os.Getenv("HOSTNAME") + } + + addFlag(flag.CommandLine) +} + +func addFlag(fs *flag.FlagSet) { + // env + fs.StringVar(&Region, "region", defaultString("REGION", _region), "avaliable region. or use REGION env variable, value: sh etc.") + fs.StringVar(&Zone, "zone", defaultString("ZONE", _zone), "avaliable zone. or use ZONE env variable, value: sh001/sh002 etc.") + fs.StringVar(&AppID, "appid", os.Getenv("APP_ID"), "appid is global unique application id, register by service tree. or use APP_ID env variable.") + fs.StringVar(&DeployEnv, "deploy.env", defaultString("DEPLOY_ENV", _deployEnv), "deploy env. or use DEPLOY_ENV env variable, value: dev/fat1/uat/pre/prod etc.") + fs.StringVar(&Color, "deploy.color", os.Getenv("DEPLOY_COLOR"), "deploy.color is the identification of different experimental group.") +} + +func defaultString(env, value string) string { + v := os.Getenv(env) + if v == "" { + return value + } + return v +} diff --git a/pkg/conf/env/env_test.go b/pkg/conf/env/env_test.go new file mode 100644 index 000000000..a99de0caf --- /dev/null +++ b/pkg/conf/env/env_test.go @@ -0,0 +1,104 @@ +package env + +import ( + "flag" + "fmt" + "os" + "testing" +) + +func TestDefaultString(t *testing.T) { + v := defaultString("a", "test") + if v != "test" { + t.Fatal("v must be test") + } + if err := os.Setenv("a", "test1"); err != nil { + t.Fatal(err) + } + v = defaultString("a", "test") + if v != "test1" { + t.Fatal("v must be test1") + } +} + +func TestEnv(t *testing.T) { + tests := []struct { + flag string + env string + def string + val *string + }{ + { + "region", + "REGION", + _region, + &Region, + }, + { + "zone", + "ZONE", + _zone, + &Zone, + }, + { + "deploy.env", + "DEPLOY_ENV", + _deployEnv, + 
&DeployEnv, + }, + { + "appid", + "APP_ID", + "", + &AppID, + }, + { + "deploy.color", + "DEPLOY_COLOR", + "", + &Color, + }, + } + for _, test := range tests { + // flag set value + t.Run(fmt.Sprintf("%s: flag set", test.env), func(t *testing.T) { + fs := flag.NewFlagSet("", flag.ContinueOnError) + addFlag(fs) + err := fs.Parse([]string{fmt.Sprintf("-%s=%s", test.flag, "test")}) + if err != nil { + t.Fatal(err) + } + if *test.val != "test" { + t.Fatal("val must be test") + } + }) + // flag not set, env set + t.Run(fmt.Sprintf("%s: flag not set, env set", test.env), func(t *testing.T) { + *test.val = "" + os.Setenv(test.env, "test2") + fs := flag.NewFlagSet("", flag.ContinueOnError) + addFlag(fs) + err := fs.Parse([]string{}) + if err != nil { + t.Fatal(err) + } + if *test.val != "test2" { + t.Fatal("val must be test") + } + }) + // flag not set, env not set + t.Run(fmt.Sprintf("%s: flag not set, env not set", test.env), func(t *testing.T) { + *test.val = "" + os.Setenv(test.env, "") + fs := flag.NewFlagSet("", flag.ContinueOnError) + addFlag(fs) + err := fs.Parse([]string{}) + if err != nil { + t.Fatal(err) + } + if *test.val != test.def { + t.Fatal("val must be test") + } + }) + } +} diff --git a/pkg/container/pool/README.md b/pkg/container/pool/README.md new file mode 100644 index 000000000..05d5198f9 --- /dev/null +++ b/pkg/container/pool/README.md @@ -0,0 +1,5 @@ +# pool + +## 项目简介 + +通用连接池实现 diff --git a/pkg/container/pool/list.go b/pkg/container/pool/list.go new file mode 100644 index 000000000..97e9afc5c --- /dev/null +++ b/pkg/container/pool/list.go @@ -0,0 +1,227 @@ +package pool + +import ( + "container/list" + "context" + "io" + "sync" + "time" +) + +var _ Pool = &List{} + +// List . +type List struct { + // New is an application supplied function for creating and configuring a + // item. + // + // The item returned from new must not be in a special state + // (subscribed to pubsub channel, transaction started, ...). 
+ New func(ctx context.Context) (io.Closer, error) + + // mu protects fields defined below. + mu sync.Mutex + cond chan struct{} + closed bool + active int + // clean stale items + cleanerCh chan struct{} + + // Stack of item with most recently used at the front. + idles list.List + + // Config pool configuration + conf *Config +} + +// NewList creates a new pool. +func NewList(c *Config) *List { + // check Config + if c == nil || c.Active < c.Idle { + panic("config nil or Idle Must <= Active") + } + // new pool + p := &List{conf: c} + p.cond = make(chan struct{}) + p.startCleanerLocked(time.Duration(c.IdleTimeout)) + return p +} + +// Reload reload config. +func (p *List) Reload(c *Config) error { + p.mu.Lock() + p.startCleanerLocked(time.Duration(c.IdleTimeout)) + p.conf = c + p.mu.Unlock() + return nil +} + +// startCleanerLocked +func (p *List) startCleanerLocked(d time.Duration) { + if d <= 0 { + // if set 0, staleCleaner() will return directly + return + } + if d < time.Duration(p.conf.IdleTimeout) && p.cleanerCh != nil { + select { + case p.cleanerCh <- struct{}{}: + default: + } + } + // run only one, clean stale items. + if p.cleanerCh == nil { + p.cleanerCh = make(chan struct{}, 1) + go p.staleCleaner() + } +} + +// staleCleaner clean stale items proc. +func (p *List) staleCleaner() { + ticker := time.NewTicker(100 * time.Millisecond) + for { + select { + case <-ticker.C: + case <-p.cleanerCh: // maxLifetime was changed or db was closed. + } + p.mu.Lock() + if p.closed || p.conf.IdleTimeout <= 0 { + p.mu.Unlock() + return + } + for i, n := 0, p.idles.Len(); i < n; i++ { + e := p.idles.Back() + if e == nil { + // no possible + break + } + ic := e.Value.(item) + if !ic.expired(time.Duration(p.conf.IdleTimeout)) { + // not need continue. + break + } + p.idles.Remove(e) + p.release() + p.mu.Unlock() + ic.c.Close() + p.mu.Lock() + } + p.mu.Unlock() + } +} + +// Get returns a item from the idles List or +// get a new item. 
+func (p *List) Get(ctx context.Context) (io.Closer, error) { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return nil, ErrPoolClosed + } + for { + // get idles item. + for i, n := 0, p.idles.Len(); i < n; i++ { + e := p.idles.Front() + if e == nil { + break + } + ic := e.Value.(item) + p.idles.Remove(e) + p.mu.Unlock() + if !ic.expired(time.Duration(p.conf.IdleTimeout)) { + return ic.c, nil + } + ic.c.Close() + p.mu.Lock() + p.release() + } + // Check for pool closed before dialing a new item. + if p.closed { + p.mu.Unlock() + return nil, ErrPoolClosed + } + // new item if under limit. + if p.conf.Active == 0 || p.active < p.conf.Active { + newItem := p.New + p.active++ + p.mu.Unlock() + c, err := newItem(ctx) + if err != nil { + p.mu.Lock() + p.release() + p.mu.Unlock() + c = nil + } + return c, err + } + if p.conf.WaitTimeout == 0 && !p.conf.Wait { + p.mu.Unlock() + return nil, ErrPoolExhausted + } + wt := p.conf.WaitTimeout + p.mu.Unlock() + + // slowpath: reset context timeout + nctx := ctx + cancel := func() {} + if wt > 0 { + _, nctx, cancel = wt.Shrink(ctx) + } + select { + case <-nctx.Done(): + cancel() + return nil, nctx.Err() + case <-p.cond: + } + cancel() + p.mu.Lock() + } +} + +// Put put item into pool. +func (p *List) Put(ctx context.Context, c io.Closer, forceClose bool) error { + p.mu.Lock() + if !p.closed && !forceClose { + p.idles.PushFront(item{createdAt: nowFunc(), c: c}) + if p.idles.Len() > p.conf.Idle { + c = p.idles.Remove(p.idles.Back()).(item).c + } else { + c = nil + } + } + if c == nil { + p.signal() + p.mu.Unlock() + return nil + } + p.release() + p.mu.Unlock() + return c.Close() +} + +// Close releases the resources used by the pool. +func (p *List) Close() error { + p.mu.Lock() + idles := p.idles + p.idles.Init() + p.closed = true + p.active -= idles.Len() + p.mu.Unlock() + for e := idles.Front(); e != nil; e = e.Next() { + e.Value.(item).c.Close() + } + return nil +} + +// release decrements the active count and signals waiters. 
The caller must +// hold p.mu during the call. +func (p *List) release() { + p.active-- + p.signal() +} + +func (p *List) signal() { + select { + default: + case p.cond <- struct{}{}: + } +} diff --git a/pkg/container/pool/list_test.go b/pkg/container/pool/list_test.go new file mode 100644 index 000000000..421e22771 --- /dev/null +++ b/pkg/container/pool/list_test.go @@ -0,0 +1,322 @@ +package pool + +import ( + "context" + "io" + "testing" + "time" + + xtime "github.com/bilibili/Kratos/pkg/time" + + "github.com/stretchr/testify/assert" +) + +func TestListGetPut(t *testing.T) { + // new pool + config := &Config{ + Active: 1, + Idle: 1, + IdleTimeout: xtime.Duration(90 * time.Second), + WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewList(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + // test Get Put + conn, err := pool.Get(context.TODO()) + assert.Nil(t, err) + c1 := connection{pool: pool, c: conn} + c1.HandleNormal() + c1.Close() +} + +func TestListPut(t *testing.T) { + var id = 0 + type connID struct { + io.Closer + id int + } + config := &Config{ + Active: 1, + Idle: 1, + IdleTimeout: xtime.Duration(1 * time.Second), + // WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewList(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + id = id + 1 + return &connID{id: id, Closer: &closer{}}, nil + } + // test Put(ctx, conn, true) + conn, err := pool.Get(context.TODO()) + assert.Nil(t, err) + conn1 := conn.(*connID) + // Put(ctx, conn, true) drop the connection. 
+ pool.Put(context.TODO(), conn, true) + conn, err = pool.Get(context.TODO()) + assert.Nil(t, err) + conn2 := conn.(*connID) + assert.NotEqual(t, conn1.id, conn2.id) +} + +func TestListIdleTimeout(t *testing.T) { + var id = 0 + type connID struct { + io.Closer + id int + } + config := &Config{ + Active: 1, + Idle: 1, + // conn timeout + IdleTimeout: xtime.Duration(1 * time.Millisecond), + } + pool := NewList(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + id = id + 1 + return &connID{id: id, Closer: &closer{}}, nil + } + // test Put(ctx, conn, true) + conn, err := pool.Get(context.TODO()) + assert.Nil(t, err) + conn1 := conn.(*connID) + // Put(ctx, conn, true) drop the connection. + pool.Put(context.TODO(), conn, false) + time.Sleep(5 * time.Millisecond) + // idletimeout and get new conn + conn, err = pool.Get(context.TODO()) + assert.Nil(t, err) + conn2 := conn.(*connID) + assert.NotEqual(t, conn1.id, conn2.id) +} + +func TestListContextTimeout(t *testing.T) { + // new pool + config := &Config{ + Active: 1, + Idle: 1, + IdleTimeout: xtime.Duration(90 * time.Second), + WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewList(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + // test context timeout + ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond) + defer cancel() + conn, err := pool.Get(ctx) + assert.Nil(t, err) + _, err = pool.Get(ctx) + // context timeout error + assert.NotNil(t, err) + pool.Put(context.TODO(), conn, false) + _, err = pool.Get(ctx) + assert.Nil(t, err) +} + +func TestListPoolExhausted(t *testing.T) { + // test pool exhausted + config := &Config{ + Active: 1, + Idle: 1, + IdleTimeout: xtime.Duration(90 * time.Second), + // WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewList(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + ctx, cancel := 
context.WithTimeout(context.TODO(), 100*time.Millisecond) + defer cancel() + conn, err := pool.Get(context.TODO()) + assert.Nil(t, err) + _, err = pool.Get(ctx) + // config active == 1, so no avaliable conns make connection exhausted. + assert.NotNil(t, err) + pool.Put(context.TODO(), conn, false) + _, err = pool.Get(ctx) + assert.Nil(t, err) +} + +func TestListStaleClean(t *testing.T) { + var id = 0 + type connID struct { + io.Closer + id int + } + config := &Config{ + Active: 1, + Idle: 1, + IdleTimeout: xtime.Duration(1 * time.Second), + // WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewList(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + id = id + 1 + return &connID{id: id, Closer: &closer{}}, nil + } + conn, err := pool.Get(context.TODO()) + assert.Nil(t, err) + conn1 := conn.(*connID) + pool.Put(context.TODO(), conn, false) + conn, err = pool.Get(context.TODO()) + assert.Nil(t, err) + conn2 := conn.(*connID) + assert.Equal(t, conn1.id, conn2.id) + pool.Put(context.TODO(), conn, false) + // sleep more than idleTimeout + time.Sleep(2 * time.Second) + conn, err = pool.Get(context.TODO()) + assert.Nil(t, err) + conn3 := conn.(*connID) + assert.NotEqual(t, conn1.id, conn3.id) +} + +func BenchmarkList1(b *testing.B) { + config := &Config{ + Active: 30, + Idle: 30, + IdleTimeout: xtime.Duration(90 * time.Second), + WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewList(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get(context.TODO()) + if err != nil { + b.Error(err) + continue + } + c1 := connection{pool: pool, c: conn} + c1.HandleQuick() + c1.Close() + } + }) +} + +func BenchmarkList2(b *testing.B) { + config := &Config{ + Active: 30, + Idle: 30, + IdleTimeout: xtime.Duration(90 * time.Second), + WaitTimeout: xtime.Duration(10 * 
time.Millisecond), + Wait: false, + } + pool := NewList(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get(context.TODO()) + if err != nil { + b.Error(err) + continue + } + c1 := connection{pool: pool, c: conn} + c1.HandleNormal() + c1.Close() + } + }) +} + +func BenchmarkPool3(b *testing.B) { + config := &Config{ + Active: 30, + Idle: 30, + IdleTimeout: xtime.Duration(90 * time.Second), + WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewList(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get(context.TODO()) + if err != nil { + b.Error(err) + continue + } + c1 := connection{pool: pool, c: conn} + c1.HandleSlow() + c1.Close() + } + }) +} + +func BenchmarkList4(b *testing.B) { + config := &Config{ + Active: 30, + Idle: 30, + IdleTimeout: xtime.Duration(90 * time.Second), + // WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewList(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get(context.TODO()) + if err != nil { + b.Error(err) + continue + } + c1 := connection{pool: pool, c: conn} + c1.HandleSlow() + c1.Close() + } + }) +} + +func BenchmarkList5(b *testing.B) { + config := &Config{ + Active: 30, + Idle: 30, + IdleTimeout: xtime.Duration(90 * time.Second), + // WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: true, + } + pool := NewList(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get(context.TODO()) + if err != nil { 
+ b.Error(err) + continue + } + c1 := connection{pool: pool, c: conn} + c1.HandleSlow() + c1.Close() + } + }) +} diff --git a/pkg/container/pool/pool.go b/pkg/container/pool/pool.go new file mode 100644 index 000000000..13ad43e6f --- /dev/null +++ b/pkg/container/pool/pool.go @@ -0,0 +1,62 @@ +package pool + +import ( + "context" + "errors" + "io" + "time" + + xtime "github.com/bilibili/Kratos/pkg/time" +) + +var ( + // ErrPoolExhausted connections are exhausted. + ErrPoolExhausted = errors.New("container/pool exhausted") + // ErrPoolClosed connection pool is closed. + ErrPoolClosed = errors.New("container/pool closed") + + // nowFunc returns the current time; it's overridden in tests. + nowFunc = time.Now +) + +// Config is the pool configuration struct. +type Config struct { + // Active number of items allocated by the pool at a given time. + // When zero, there is no limit on the number of items in the pool. + Active int + // Idle number of idle items in the pool. + Idle int + // Close items after remaining item for this duration. If the value + // is zero, then item items are not closed. Applications should set + // the timeout to a value less than the server's timeout. + IdleTimeout xtime.Duration + // If WaitTimeout is set and the pool is at the Active limit, then Get() waits WatiTimeout + // until a item to be returned to the pool before returning. + WaitTimeout xtime.Duration + // If WaitTimeout is not set, then Wait effects. + // if Wait is set true, then wait until ctx timeout, or default flase and return directly. + Wait bool +} + +type item struct { + createdAt time.Time + c io.Closer +} + +func (i *item) expired(timeout time.Duration) bool { + if timeout <= 0 { + return false + } + return i.createdAt.Add(timeout).Before(nowFunc()) +} + +func (i *item) close() error { + return i.c.Close() +} + +// Pool interface. 
+type Pool interface { + Get(ctx context.Context) (io.Closer, error) + Put(ctx context.Context, c io.Closer, forceClose bool) error + Close() error +} diff --git a/pkg/container/pool/slice.go b/pkg/container/pool/slice.go new file mode 100644 index 000000000..6876565ed --- /dev/null +++ b/pkg/container/pool/slice.go @@ -0,0 +1,423 @@ +package pool + +import ( + "context" + "io" + "sync" + "time" +) + +var _ Pool = &Slice{} + +// Slice . +type Slice struct { + // New is an application supplied function for creating and configuring a + // item. + // + // The item returned from new must not be in a special state + // (subscribed to pubsub channel, transaction started, ...). + New func(ctx context.Context) (io.Closer, error) + stop func() // stop cancels the item opener. + + // mu protects fields defined below. + mu sync.Mutex + freeItem []*item + itemRequests map[uint64]chan item + nextRequest uint64 // Next key to use in itemRequests. + active int // number of opened and pending open items + // Used to signal the need for new items + // a goroutine running itemOpener() reads on this chan and + // maybeOpenNewItems sends on the chan (one send per needed item) + // It is closed during db.Close(). The close tells the itemOpener + // goroutine to exit. + openerCh chan struct{} + closed bool + cleanerCh chan struct{} + + // Config pool configuration + conf *Config +} + +// NewSlice creates a new pool. +func NewSlice(c *Config) *Slice { + // check Config + if c == nil || c.Active < c.Idle { + panic("config nil or Idle Must <= Active") + } + ctx, cancel := context.WithCancel(context.Background()) + // new pool + p := &Slice{ + conf: c, + stop: cancel, + itemRequests: make(map[uint64]chan item), + openerCh: make(chan struct{}, 1000000), + } + p.startCleanerLocked(time.Duration(c.IdleTimeout)) + + go p.itemOpener(ctx) + return p +} + +// Reload reload config. 
+func (p *Slice) Reload(c *Config) error { + p.mu.Lock() + p.startCleanerLocked(time.Duration(c.IdleTimeout)) + p.setActive(c.Active) + p.setIdle(c.Idle) + p.conf = c + p.mu.Unlock() + return nil +} + +// Get returns a newly-opened or cached *item. +func (p *Slice) Get(ctx context.Context) (io.Closer, error) { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return nil, ErrPoolClosed + } + idleTimeout := time.Duration(p.conf.IdleTimeout) + // Prefer a free item, if possible. + numFree := len(p.freeItem) + for numFree > 0 { + i := p.freeItem[0] + copy(p.freeItem, p.freeItem[1:]) + p.freeItem = p.freeItem[:numFree-1] + p.mu.Unlock() + if i.expired(idleTimeout) { + i.close() + p.mu.Lock() + p.release() + } else { + return i.c, nil + } + numFree = len(p.freeItem) + } + + // Out of free items or we were asked not to use one. If we're not + // allowed to open any more items, make a request and wait. + if p.conf.Active > 0 && p.active >= p.conf.Active { + // check WaitTimeout and return directly + if p.conf.WaitTimeout == 0 && !p.conf.Wait { + p.mu.Unlock() + return nil, ErrPoolExhausted + } + // Make the item channel. It's buffered so that the + // itemOpener doesn't block while waiting for the req to be read. + req := make(chan item, 1) + reqKey := p.nextRequestKeyLocked() + p.itemRequests[reqKey] = req + wt := p.conf.WaitTimeout + p.mu.Unlock() + + // reset context timeout + if wt > 0 { + var cancel func() + _, ctx, cancel = wt.Shrink(ctx) + defer cancel() + } + // Timeout the item request with the context. + select { + case <-ctx.Done(): + // Remove the item request and ensure no value has been sent + // on it after removing. 
+ p.mu.Lock() + delete(p.itemRequests, reqKey) + p.mu.Unlock() + return nil, ctx.Err() + case ret, ok := <-req: + if !ok { + return nil, ErrPoolClosed + } + if ret.expired(idleTimeout) { + ret.close() + p.mu.Lock() + p.release() + } else { + return ret.c, nil + } + } + } + + p.active++ // optimistically + p.mu.Unlock() + c, err := p.New(ctx) + if err != nil { + p.mu.Lock() + p.release() + p.mu.Unlock() + return nil, err + } + return c, nil +} + +// Put adds a item to the p's free pool. +// err is optionally the last error that occurred on this item. +func (p *Slice) Put(ctx context.Context, c io.Closer, forceClose bool) error { + p.mu.Lock() + defer p.mu.Unlock() + if forceClose { + p.release() + return c.Close() + } + added := p.putItemLocked(c) + if !added { + p.active-- + return c.Close() + } + return nil +} + +// Satisfy a item or put the item in the idle pool and return true +// or return false. +// putItemLocked will satisfy a item if there is one, or it will +// return the *item to the freeItem list if err == nil and the idle +// item limit will not be exceeded. +// If err != nil, the value of i is ignored. +// If err == nil, then i must not equal nil. +// If a item was fulfilled or the *item was placed in the +// freeItem list, then true is returned, otherwise false is returned. +func (p *Slice) putItemLocked(c io.Closer) bool { + if p.closed { + return false + } + if p.conf.Active > 0 && p.active > p.conf.Active { + return false + } + i := item{ + c: c, + createdAt: nowFunc(), + } + if l := len(p.itemRequests); l > 0 { + var req chan item + var reqKey uint64 + for reqKey, req = range p.itemRequests { + break + } + delete(p.itemRequests, reqKey) // Remove from pending requests. + req <- i + return true + } else if !p.closed && p.maxIdleItemsLocked() > len(p.freeItem) { + p.freeItem = append(p.freeItem, &i) + return true + } + return false +} + +// Runs in a separate goroutine, opens new item when requested. 
+func (p *Slice) itemOpener(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-p.openerCh: + p.openNewItem(ctx) + } + } +} + +func (p *Slice) maybeOpenNewItems() { + numRequests := len(p.itemRequests) + if p.conf.Active > 0 { + numCanOpen := p.conf.Active - p.active + if numRequests > numCanOpen { + numRequests = numCanOpen + } + } + for numRequests > 0 { + p.active++ // optimistically + numRequests-- + if p.closed { + return + } + p.openerCh <- struct{}{} + } +} + +// openNewItem one new item +func (p *Slice) openNewItem(ctx context.Context) { + // maybeOpenNewConnctions has already executed p.active++ before it sent + // on p.openerCh. This function must execute p.active-- if the + // item fails or is closed before returning. + c, err := p.New(ctx) + p.mu.Lock() + defer p.mu.Unlock() + if err != nil { + p.release() + return + } + if !p.putItemLocked(c) { + p.active-- + c.Close() + } +} + +// setIdle sets the maximum number of items in the idle +// item pool. +// +// If MaxOpenConns is greater than 0 but less than the new IdleConns +// then the new IdleConns will be reduced to match the MaxOpenConns limit +// +// If n <= 0, no idle items are retained. +func (p *Slice) setIdle(n int) { + p.mu.Lock() + if n > 0 { + p.conf.Idle = n + } else { + // No idle items. + p.conf.Idle = -1 + } + // Make sure maxIdle doesn't exceed maxOpen + if p.conf.Active > 0 && p.maxIdleItemsLocked() > p.conf.Active { + p.conf.Idle = p.conf.Active + } + var closing []*item + idleCount := len(p.freeItem) + maxIdle := p.maxIdleItemsLocked() + if idleCount > maxIdle { + closing = p.freeItem[maxIdle:] + p.freeItem = p.freeItem[:maxIdle] + } + p.mu.Unlock() + for _, c := range closing { + c.close() + } +} + +// setActive sets the maximum number of open items to the database. 
+// +// If IdleConns is greater than 0 and the new MaxOpenConns is less than +// IdleConns, then IdleConns will be reduced to match the new +// MaxOpenConns limit +// +// If n <= 0, then there is no limit on the number of open items. +// The default is 0 (unlimited). +func (p *Slice) setActive(n int) { + p.mu.Lock() + p.conf.Active = n + if n < 0 { + p.conf.Active = 0 + } + syncIdle := p.conf.Active > 0 && p.maxIdleItemsLocked() > p.conf.Active + p.mu.Unlock() + if syncIdle { + p.setIdle(n) + } +} + +// startCleanerLocked starts itemCleaner if needed. +func (p *Slice) startCleanerLocked(d time.Duration) { + if d <= 0 { + // if set 0, staleCleaner() will return directly + return + } + if d < time.Duration(p.conf.IdleTimeout) && p.cleanerCh != nil { + select { + case p.cleanerCh <- struct{}{}: + default: + } + } + // run only one, clean stale items. + if p.cleanerCh == nil { + p.cleanerCh = make(chan struct{}, 1) + go p.staleCleaner(time.Duration(p.conf.IdleTimeout)) + } +} + +func (p *Slice) staleCleaner(d time.Duration) { + const minInterval = 100 * time.Millisecond + + if d < minInterval { + d = minInterval + } + t := time.NewTimer(d) + + for { + select { + case <-t.C: + case <-p.cleanerCh: // maxLifetime was changed or db was closed. + } + p.mu.Lock() + d = time.Duration(p.conf.IdleTimeout) + if p.closed || d <= 0 { + p.mu.Unlock() + return + } + + expiredSince := nowFunc().Add(-d) + var closing []*item + for i := 0; i < len(p.freeItem); i++ { + c := p.freeItem[i] + if c.createdAt.Before(expiredSince) { + closing = append(closing, c) + p.active-- + last := len(p.freeItem) - 1 + p.freeItem[i] = p.freeItem[last] + p.freeItem[last] = nil + p.freeItem = p.freeItem[:last] + i-- + } + } + p.mu.Unlock() + + for _, c := range closing { + c.close() + } + + if d < minInterval { + d = minInterval + } + t.Reset(d) + } +} + +// nextRequestKeyLocked returns the next item request key. +// It is assumed that nextRequest will not overflow. 
+func (p *Slice) nextRequestKeyLocked() uint64 { + next := p.nextRequest + p.nextRequest++ + return next +} + +const defaultIdleItems = 2 + +func (p *Slice) maxIdleItemsLocked() int { + n := p.conf.Idle + switch { + case n == 0: + return defaultIdleItems + case n < 0: + return 0 + default: + return n + } +} + +func (p *Slice) release() { + p.active-- + p.maybeOpenNewItems() +} + +// Close close pool. +func (p *Slice) Close() error { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return nil + } + if p.cleanerCh != nil { + close(p.cleanerCh) + } + var err error + for _, i := range p.freeItem { + i.close() + } + p.freeItem = nil + p.closed = true + for _, req := range p.itemRequests { + close(req) + } + p.mu.Unlock() + p.stop() + return err +} diff --git a/pkg/container/pool/slice_test.go b/pkg/container/pool/slice_test.go new file mode 100644 index 000000000..4fa19930d --- /dev/null +++ b/pkg/container/pool/slice_test.go @@ -0,0 +1,350 @@ +package pool + +import ( + "context" + "io" + "testing" + "time" + + xtime "github.com/bilibili/Kratos/pkg/time" + + "github.com/stretchr/testify/assert" +) + +type closer struct { +} + +func (c *closer) Close() error { + return nil +} + +type connection struct { + c io.Closer + pool Pool +} + +func (c *connection) HandleQuick() { + // time.Sleep(1 * time.Millisecond) +} + +func (c *connection) HandleNormal() { + time.Sleep(20 * time.Millisecond) +} + +func (c *connection) HandleSlow() { + time.Sleep(500 * time.Millisecond) +} + +func (c *connection) Close() { + c.pool.Put(context.Background(), c.c, false) +} + +func TestSliceGetPut(t *testing.T) { + // new pool + config := &Config{ + Active: 1, + Idle: 1, + IdleTimeout: xtime.Duration(90 * time.Second), + WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewSlice(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + // test Get Put + conn, err := pool.Get(context.TODO()) + assert.Nil(t, err) + c1 := 
connection{pool: pool, c: conn} + c1.HandleNormal() + c1.Close() +} + +func TestSlicePut(t *testing.T) { + var id = 0 + type connID struct { + io.Closer + id int + } + config := &Config{ + Active: 1, + Idle: 1, + IdleTimeout: xtime.Duration(1 * time.Second), + // WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewSlice(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + id = id + 1 + return &connID{id: id, Closer: &closer{}}, nil + } + // test Put(ctx, conn, true) + conn, err := pool.Get(context.TODO()) + assert.Nil(t, err) + conn1 := conn.(*connID) + // Put(ctx, conn, true) drop the connection. + pool.Put(context.TODO(), conn, true) + conn, err = pool.Get(context.TODO()) + assert.Nil(t, err) + conn2 := conn.(*connID) + assert.NotEqual(t, conn1.id, conn2.id) +} + +func TestSliceIdleTimeout(t *testing.T) { + var id = 0 + type connID struct { + io.Closer + id int + } + config := &Config{ + Active: 1, + Idle: 1, + // conn timeout + IdleTimeout: xtime.Duration(1 * time.Millisecond), + } + pool := NewSlice(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + id = id + 1 + return &connID{id: id, Closer: &closer{}}, nil + } + // test Put(ctx, conn, true) + conn, err := pool.Get(context.TODO()) + assert.Nil(t, err) + conn1 := conn.(*connID) + // Put(ctx, conn, true) drop the connection. 
+ pool.Put(context.TODO(), conn, false) + time.Sleep(5 * time.Millisecond) + // idletimeout and get new conn + conn, err = pool.Get(context.TODO()) + assert.Nil(t, err) + conn2 := conn.(*connID) + assert.NotEqual(t, conn1.id, conn2.id) +} + +func TestSliceContextTimeout(t *testing.T) { + // new pool + config := &Config{ + Active: 1, + Idle: 1, + IdleTimeout: xtime.Duration(90 * time.Second), + WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewSlice(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + // test context timeout + ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond) + defer cancel() + conn, err := pool.Get(ctx) + assert.Nil(t, err) + _, err = pool.Get(ctx) + // context timeout error + assert.NotNil(t, err) + pool.Put(context.TODO(), conn, false) + _, err = pool.Get(ctx) + assert.Nil(t, err) +} + +func TestSlicePoolExhausted(t *testing.T) { + // test pool exhausted + config := &Config{ + Active: 1, + Idle: 1, + IdleTimeout: xtime.Duration(90 * time.Second), + // WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewSlice(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond) + defer cancel() + conn, err := pool.Get(context.TODO()) + assert.Nil(t, err) + _, err = pool.Get(ctx) + // config active == 1, so no avaliable conns make connection exhausted. 
+ assert.NotNil(t, err) + pool.Put(context.TODO(), conn, false) + _, err = pool.Get(ctx) + assert.Nil(t, err) +} + +func TestSliceStaleClean(t *testing.T) { + var id = 0 + type connID struct { + io.Closer + id int + } + config := &Config{ + Active: 1, + Idle: 1, + IdleTimeout: xtime.Duration(1 * time.Second), + // WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewList(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + id = id + 1 + return &connID{id: id, Closer: &closer{}}, nil + } + conn, err := pool.Get(context.TODO()) + assert.Nil(t, err) + conn1 := conn.(*connID) + pool.Put(context.TODO(), conn, false) + conn, err = pool.Get(context.TODO()) + assert.Nil(t, err) + conn2 := conn.(*connID) + assert.Equal(t, conn1.id, conn2.id) + pool.Put(context.TODO(), conn, false) + // sleep more than idleTimeout + time.Sleep(2 * time.Second) + conn, err = pool.Get(context.TODO()) + assert.Nil(t, err) + conn3 := conn.(*connID) + assert.NotEqual(t, conn1.id, conn3.id) +} + +func BenchmarkSlice1(b *testing.B) { + config := &Config{ + Active: 30, + Idle: 30, + IdleTimeout: xtime.Duration(90 * time.Second), + WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewSlice(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get(context.TODO()) + if err != nil { + b.Error(err) + continue + } + c1 := connection{pool: pool, c: conn} + c1.HandleQuick() + c1.Close() + } + }) +} + +func BenchmarkSlice2(b *testing.B) { + config := &Config{ + Active: 30, + Idle: 30, + IdleTimeout: xtime.Duration(90 * time.Second), + WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewSlice(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + 
conn, err := pool.Get(context.TODO()) + if err != nil { + b.Error(err) + continue + } + c1 := connection{pool: pool, c: conn} + c1.HandleNormal() + c1.Close() + } + }) +} + +func BenchmarkSlice3(b *testing.B) { + config := &Config{ + Active: 30, + Idle: 30, + IdleTimeout: xtime.Duration(90 * time.Second), + WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewSlice(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get(context.TODO()) + if err != nil { + b.Error(err) + continue + } + c1 := connection{pool: pool, c: conn} + c1.HandleSlow() + c1.Close() + } + }) +} + +func BenchmarkSlice4(b *testing.B) { + config := &Config{ + Active: 30, + Idle: 30, + IdleTimeout: xtime.Duration(90 * time.Second), + // WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: false, + } + pool := NewSlice(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get(context.TODO()) + if err != nil { + b.Error(err) + continue + } + c1 := connection{pool: pool, c: conn} + c1.HandleSlow() + c1.Close() + } + }) +} + +func BenchmarkSlice5(b *testing.B) { + config := &Config{ + Active: 30, + Idle: 30, + IdleTimeout: xtime.Duration(90 * time.Second), + // WaitTimeout: xtime.Duration(10 * time.Millisecond), + Wait: true, + } + pool := NewSlice(config) + pool.New = func(ctx context.Context) (io.Closer, error) { + return &closer{}, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get(context.TODO()) + if err != nil { + b.Error(err) + continue + } + c1 := connection{pool: pool, c: conn} + c1.HandleSlow() + c1.Close() + } + }) +} diff --git a/pkg/container/queue/aqm/README.md b/pkg/container/queue/aqm/README.md new file mode 100644 index 
000000000..273de9b8d --- /dev/null +++ b/pkg/container/queue/aqm/README.md @@ -0,0 +1,5 @@ +# aqm + +## 项目简介 + +队列管理算法 diff --git a/pkg/container/queue/aqm/codel.go b/pkg/container/queue/aqm/codel.go new file mode 100644 index 000000000..4f0866251 --- /dev/null +++ b/pkg/container/queue/aqm/codel.go @@ -0,0 +1,185 @@ +package aqm + +import ( + "context" + "math" + "sync" + "time" + + "github.com/bilibili/Kratos/pkg/ecode" +) + +// Config codel config. +type Config struct { + Target int64 // target queue delay (default 20 ms). + Internal int64 // sliding minimum time window width (default 500 ms) +} + +// Stat is the Statistics of codel. +type Stat struct { + Dropping bool + FaTime int64 + DropNext int64 + Packets int +} + +type packet struct { + ch chan bool + ts int64 +} + +var defaultConf = &Config{ + Target: 50, + Internal: 500, +} + +// Queue queue is CoDel req buffer queue. +type Queue struct { + pool sync.Pool + packets chan packet + + mux sync.RWMutex + conf *Config + count int64 + dropping bool // Equal to 1 if in drop state + faTime int64 // Time when we'll declare we're above target (0 if below) + dropNext int64 // Packets dropped since going into drop state +} + +// Default new a default codel queue. +func Default() *Queue { + return New(defaultConf) +} + +// New new codel queue. +func New(conf *Config) *Queue { + if conf == nil { + conf = defaultConf + } + q := &Queue{ + packets: make(chan packet, 2048), + conf: conf, + } + q.pool.New = func() interface{} { + return make(chan bool) + } + return q +} + +// Reload set queue config. 
+func (q *Queue) Reload(c *Config) { + if c == nil || c.Internal <= 0 || c.Target <= 0 { + return + } + // TODO codel queue size + q.mux.Lock() + q.conf = c + q.mux.Unlock() +} + +// Stat return the statistics of codel +func (q *Queue) Stat() Stat { + q.mux.Lock() + defer q.mux.Unlock() + return Stat{ + Dropping: q.dropping, + FaTime: q.faTime, + DropNext: q.dropNext, + Packets: len(q.packets), + } +} + +// Push req into CoDel request buffer queue. +// if return error is nil,the caller must call q.Done() after finish request handling +func (q *Queue) Push(ctx context.Context) (err error) { + r := packet{ + ch: q.pool.Get().(chan bool), + ts: time.Now().UnixNano() / int64(time.Millisecond), + } + select { + case q.packets <- r: + default: + err = ecode.LimitExceed + q.pool.Put(r.ch) + } + if err == nil { + select { + case drop := <-r.ch: + if drop { + err = ecode.LimitExceed + } + q.pool.Put(r.ch) + case <-ctx.Done(): + err = ecode.Deadline + } + } + return +} + +// Pop req from CoDel request buffer queue. +func (q *Queue) Pop() { + for { + select { + case p := <-q.packets: + drop := q.judge(p) + select { + case p.ch <- drop: + if !drop { + return + } + default: + q.pool.Put(p.ch) + } + default: + return + } + } +} + +func (q *Queue) controlLaw(now int64) int64 { + q.dropNext = now + int64(float64(q.conf.Internal)/math.Sqrt(float64(q.count))) + return q.dropNext +} + +// judge decide if the packet should drop or not. 
+func (q *Queue) judge(p packet) (drop bool) { + now := time.Now().UnixNano() / int64(time.Millisecond) + sojurn := now - p.ts + q.mux.Lock() + defer q.mux.Unlock() + if sojurn < q.conf.Target { + q.faTime = 0 + } else if q.faTime == 0 { + q.faTime = now + q.conf.Internal + } else if now >= q.faTime { + drop = true + } + if q.dropping { + if !drop { + // sojourn time below target - leave dropping state + q.dropping = false + } else if now > q.dropNext { + q.count++ + q.dropNext = q.controlLaw(q.dropNext) + drop = true + return + } + } else if drop && (now-q.dropNext < q.conf.Internal || now-q.faTime >= q.conf.Internal) { + q.dropping = true + // If we're in a drop cycle, the drop rate that controlled the queue + // on the last cycle is a good starting point to control it now. + if now-q.dropNext < q.conf.Internal { + if q.count > 2 { + q.count = q.count - 2 + } else { + q.count = 1 + } + } else { + q.count = 1 + } + q.dropNext = q.controlLaw(now) + drop = true + return + } + return +} diff --git a/pkg/container/queue/aqm/codel_test.go b/pkg/container/queue/aqm/codel_test.go new file mode 100644 index 000000000..5857108a3 --- /dev/null +++ b/pkg/container/queue/aqm/codel_test.go @@ -0,0 +1,97 @@ +package aqm + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/bilibili/Kratos/pkg/ecode" +) + +var testConf = &Config{ + Target: 20, + Internal: 500, +} + +var qps = time.Microsecond * 2000 + +func TestCoDel1200(t *testing.T) { + q := New(testConf) + drop := new(int64) + tm := new(int64) + delay := time.Millisecond * 3000 + testPush(q, qps, delay, drop, tm) + fmt.Printf("qps %v process time %v drop %d timeout %d \n", int64(time.Second/qps), delay, *drop, *tm) + time.Sleep(time.Second) +} + +func TestCoDel200(t *testing.T) { + q := New(testConf) + drop := new(int64) + tm := new(int64) + delay := time.Millisecond * 2000 + testPush(q, qps, delay, drop, tm) + fmt.Printf("qps %v process time %v drop %d timeout %d \n", 
int64(time.Second/qps), delay, *drop, *tm) + time.Sleep(time.Second) + +} + +func TestCoDel100(t *testing.T) { + q := New(testConf) + drop := new(int64) + tm := new(int64) + delay := time.Millisecond * 1000 + testPush(q, qps, delay, drop, tm) + fmt.Printf("qps %v process time %v drop %d timeout %d \n", int64(time.Second/qps), delay, *drop, *tm) + +} + +func TestCoDel50(t *testing.T) { + q := New(testConf) + drop := new(int64) + tm := new(int64) + delay := time.Millisecond * 500 + testPush(q, qps, delay, drop, tm) + fmt.Printf("qps %v process time %v drop %d timeout %d \n", int64(time.Second/qps), delay, *drop, *tm) +} + +func testPush(q *Queue, sleep time.Duration, delay time.Duration, drop *int64, tm *int64) { + var group sync.WaitGroup + for i := 0; i < 5000; i++ { + time.Sleep(sleep) + group.Add(1) + go func() { + defer group.Done() + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*1000)) + defer cancel() + if err := q.Push(ctx); err != nil { + if err == ecode.LimitExceed { + atomic.AddInt64(drop, 1) + } else { + atomic.AddInt64(tm, 1) + } + } else { + time.Sleep(delay) + q.Pop() + } + }() + } + group.Wait() +} + +func BenchmarkAQM(b *testing.B) { + q := Default() + b.RunParallel(func(p *testing.PB) { + for p.Next() { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*5)) + err := q.Push(ctx) + if err == nil { + q.Pop() + } + cancel() + } + }) +} diff --git a/pkg/ecode/common_ecode.go b/pkg/ecode/common_ecode.go new file mode 100644 index 000000000..00f4e6047 --- /dev/null +++ b/pkg/ecode/common_ecode.go @@ -0,0 +1,19 @@ +package ecode + +// All common ecode +var ( + OK = add(0) // 正确 + + NotModified = add(-304) // 木有改动 + TemporaryRedirect = add(-307) // 撞车跳转 + RequestErr = add(-400) // 请求错误 + Unauthorized = add(-401) // 未认证 + AccessDenied = add(-403) // 访问权限不足 + NothingFound = add(-404) // 啥都木有 + MethodNotAllowed = add(-405) // 不支持该方法 + Conflict = add(-409) // 冲突 + ServerErr = 
add(-500) // 服务器错误 + ServiceUnavailable = add(-503) // 过载保护,服务暂不可用 + Deadline = add(-504) // 服务调用超时 + LimitExceed = add(-509) // 超出限制 +) diff --git a/pkg/ecode/ecode.go b/pkg/ecode/ecode.go new file mode 100644 index 000000000..3e627970f --- /dev/null +++ b/pkg/ecode/ecode.go @@ -0,0 +1,120 @@ +package ecode + +import ( + "fmt" + "strconv" + "sync/atomic" + + "github.com/pkg/errors" +) + +var ( + _messages atomic.Value // NOTE: stored map[string]map[int]string + _codes = map[int]struct{}{} // register codes. +) + +// Register register ecode message map. +func Register(cm map[int]string) { + _messages.Store(cm) +} + +// New new a ecode.Codes by int value. +// NOTE: ecode must unique in global, the New will check repeat and then panic. +func New(e int) Code { + if e <= 0 { + panic("business ecode must greater than zero") + } + return add(e) +} + +func add(e int) Code { + if _, ok := _codes[e]; ok { + panic(fmt.Sprintf("ecode: %d already exist", e)) + } + _codes[e] = struct{}{} + return Int(e) +} + +// Codes ecode error interface which has a code & message. +type Codes interface { + // sometimes Error return Code in string form + // NOTE: don't use Error in monitor report even it also work for now + Error() string + // Code get error code. + Code() int + // Message get code message. + Message() string + //Detail get error detail,it may be nil. + Details() []interface{} +} + +// A Code is an int error code spec. +type Code int + +func (e Code) Error() string { + return strconv.FormatInt(int64(e), 10) +} + +// Code return error code +func (e Code) Code() int { return int(e) } + +// Message return error message +func (e Code) Message() string { + if cm, ok := _messages.Load().(map[int]string); ok { + if msg, ok := cm[e.Code()]; ok { + return msg + } + } + return e.Error() +} + +// Details return details. +func (e Code) Details() []interface{} { return nil } + +// Equal for compatible. +// Deprecated: please use ecode.EqualError. 
+func (e Code) Equal(err error) bool { return EqualError(e, err) } + +// Int parse code int to error. +func Int(i int) Code { return Code(i) } + +// String parse code string to error. +func String(e string) Code { + if e == "" { + return OK + } + // try error string + i, err := strconv.Atoi(e) + if err != nil { + return ServerErr + } + return Code(i) +} + +// Cause cause from error to ecode. +func Cause(e error) Codes { + if e == nil { + return OK + } + ec, ok := errors.Cause(e).(Codes) + if ok { + return ec + } + return String(e.Error()) +} + +// Equal equal a and b by code int. +func Equal(a, b Codes) bool { + if a == nil { + a = OK + } + if b == nil { + b = OK + } + return a.Code() == b.Code() +} + +// EqualError equal error +func EqualError(code Codes, err error) bool { + return Cause(err).Code() == code.Code() +} diff --git a/pkg/naming/README.md b/pkg/naming/README.md new file mode 100644 index 000000000..5c4e8ac6a --- /dev/null +++ b/pkg/naming/README.md @@ -0,0 +1,5 @@ +# naming + +## 项目简介 + +服务发现、服务注册相关的SDK集合 diff --git a/pkg/naming/naming.go b/pkg/naming/naming.go new file mode 100644 index 000000000..b2cb6b803 --- /dev/null +++ b/pkg/naming/naming.go @@ -0,0 +1,58 @@ +package naming + +import ( + "context" +) + +// metadata common key +const ( + MetaColor = "color" + MetaWeight = "weight" + MetaCluster = "cluster" + MetaZone = "zone" +) + +// Instance represents a server the client connects to. +type Instance struct { + // Region bj/sh/gz + Region string `json:"region"` + // Zone is IDC. + Zone string `json:"zone"` + // Env prod/pre、uat/fat1 + Env string `json:"env"` + // AppID is mapping servicetree appid. + AppID string `json:"appid"` + // Hostname is hostname from docker. + Hostname string `json:"hostname"` + // Addrs is the adress of app instance + // format: scheme://host + Addrs []string `json:"addrs"` + // Version is publishing version. 
+ Version string `json:"version"` + // LastTs is instance latest updated timestamp + LastTs int64 `json:"latest_timestamp"` + // Metadata is the information associated with Addr, which may be used + // to make load balancing decision. + Metadata map[string]string `json:"metadata"` + Status int64 +} + +// Resolver resolve naming service +type Resolver interface { + Fetch(context.Context) (map[string][]*Instance, bool) + //Unwatch(id string) + Watch() <-chan struct{} + Close() error +} + +// Registry Register an instance and renew automatically +type Registry interface { + Register(context.Context, *Instance) (context.CancelFunc, error) + Close() error +} + +// Builder resolver builder. +type Builder interface { + Build(id string) Resolver + Scheme() string +} diff --git a/pkg/net/ip/ip.go b/pkg/net/ip/ip.go new file mode 100644 index 000000000..54e241d1a --- /dev/null +++ b/pkg/net/ip/ip.go @@ -0,0 +1,66 @@ +package ip + +import ( + "net" + "strings" +) + +// ExternalIP get external ip. +func ExternalIP() (res []string) { + inters, err := net.Interfaces() + if err != nil { + return + } + for _, inter := range inters { + if !strings.HasPrefix(inter.Name, "lo") { + addrs, err := inter.Addrs() + if err != nil { + continue + } + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + if ipnet.IP.IsLoopback() || ipnet.IP.IsLinkLocalMulticast() || ipnet.IP.IsLinkLocalUnicast() { + continue + } + if ip4 := ipnet.IP.To4(); ip4 != nil { + switch true { + case ip4[0] == 10: + continue + case ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31: + continue + case ip4[0] == 192 && ip4[1] == 168: + continue + default: + res = append(res, ipnet.IP.String()) + } + } + } + } + } + } + return +} + +// InternalIP get internal ip. 
+func InternalIP() string { + inters, err := net.Interfaces() + if err != nil { + return "" + } + for _, inter := range inters { + if !strings.HasPrefix(inter.Name, "lo") { + addrs, err := inter.Addrs() + if err != nil { + continue + } + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() != nil { + return ipnet.IP.String() + } + } + } + } + } + return "" +} diff --git a/pkg/net/metadata/README.md b/pkg/net/metadata/README.md new file mode 100644 index 000000000..eb51734b8 --- /dev/null +++ b/pkg/net/metadata/README.md @@ -0,0 +1,5 @@ +# net/metadata + +## 项目简介 + +用于储存各种元信息 diff --git a/pkg/net/metadata/key.go b/pkg/net/metadata/key.go new file mode 100644 index 000000000..0b86eecd3 --- /dev/null +++ b/pkg/net/metadata/key.go @@ -0,0 +1,39 @@ +package metadata + +// metadata common key +const ( + + // Network + RemoteIP = "remote_ip" + RemotePort = "remote_port" + ServerAddr = "server_addr" + ClientAddr = "client_addr" + + // Router + Cluster = "cluster" + Color = "color" + + // Trace + Trace = "trace" + Caller = "caller" + + // Timeout + Timeout = "timeout" + + // Dispatch + CPUUsage = "cpu_usage" + Errors = "errors" + Requests = "requests" + + // Mirror + Mirror = "mirror" + + // Mid 外网账户用户id + Mid = "mid" // NOTE: !!!业务可重新修改key名!!! + + // Username LDAP平台的username + Username = "username" + + // Device 客户端信息 + Device = "device" +) diff --git a/pkg/net/metadata/metadata.go b/pkg/net/metadata/metadata.go new file mode 100644 index 000000000..fb46f3d57 --- /dev/null +++ b/pkg/net/metadata/metadata.go @@ -0,0 +1,134 @@ +package metadata + +import ( + "context" + "fmt" + "strconv" +) + +// MD is a mapping from metadata keys to values. +type MD map[string]interface{} + +type mdKey struct{} + +// Len returns the number of items in md. +func (md MD) Len() int { + return len(md) +} + +// Copy returns a copy of md. 
+func (md MD) Copy() MD { + return Join(md) +} + +// New creates an MD from a given key-value map. +func New(m map[string]interface{}) MD { + md := MD{} + for k, val := range m { + md[k] = val + } + return md +} + +// Join joins any number of mds into a single MD. +// The order of values for each key is determined by the order in which +// the mds containing those values are presented to Join. +func Join(mds ...MD) MD { + out := MD{} + for _, md := range mds { + for k, v := range md { + out[k] = v + } + } + return out +} + +// Pairs returns an MD formed by the mapping of key, value ... +// Pairs panics if len(kv) is odd. +func Pairs(kv ...interface{}) MD { + if len(kv)%2 == 1 { + panic(fmt.Sprintf("metadata: Pairs got the odd number of input pairs for metadata: %d", len(kv))) + } + md := MD{} + var key string + for i, s := range kv { + if i%2 == 0 { + key = s.(string) + continue + } + md[key] = s + } + return md +} + +// NewContext creates a new context with md attached. +func NewContext(ctx context.Context, md MD) context.Context { + return context.WithValue(ctx, mdKey{}, md) +} + +// FromContext returns the incoming metadata in ctx if it exists. The +// returned MD should not be modified. Writing to it may cause races. +// Modification should be made to copies of the returned MD. 
+func FromContext(ctx context.Context) (md MD, ok bool) { + md, ok = ctx.Value(mdKey{}).(MD) + return +} + +// String get string value from metadata in context +func String(ctx context.Context, key string) string { + md, ok := ctx.Value(mdKey{}).(MD) + if !ok { + return "" + } + str, _ := md[key].(string) + return str +} + +// Int64 get int64 value from metadata in context +func Int64(ctx context.Context, key string) int64 { + md, ok := ctx.Value(mdKey{}).(MD) + if !ok { + return 0 + } + i64, _ := md[key].(int64) + return i64 +} + +// Value get value from metadata in context return nil if not found +func Value(ctx context.Context, key string) interface{} { + md, ok := ctx.Value(mdKey{}).(MD) + if !ok { + return nil + } + return md[key] +} + +// WithContext return no deadline context and retain metadata. +func WithContext(c context.Context) context.Context { + md, ok := FromContext(c) + if ok { + nmd := md.Copy() + // NOTE: temporary delete prevent asynchronous task reuse finished task + delete(nmd, Trace) + return NewContext(context.Background(), nmd) + } + return context.Background() +} + +// Bool get boolean from metadata in context use strconv.Parse. 
+func Bool(ctx context.Context, key string) bool { + md, ok := ctx.Value(mdKey{}).(MD) + if !ok { + return false + } + + switch md[key].(type) { + case bool: + return md[key].(bool) + case string: + ok, _ = strconv.ParseBool(md[key].(string)) + return ok + default: + return false + } +} diff --git a/pkg/net/metadata/metadata_test.go b/pkg/net/metadata/metadata_test.go new file mode 100644 index 000000000..db1e73647 --- /dev/null +++ b/pkg/net/metadata/metadata_test.go @@ -0,0 +1,96 @@ +package metadata + +import ( + "context" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPairsMD(t *testing.T) { + for _, test := range []struct { + // input + kv []interface{} + // output + md MD + }{ + {[]interface{}{}, MD{}}, + {[]interface{}{"k1", "v1", "k1", "v2"}, MD{"k1": "v2"}}, + } { + md := Pairs(test.kv...) + if !reflect.DeepEqual(md, test.md) { + t.Fatalf("Pairs(%v) = %v, want %v", test.kv, md, test.md) + } + } +} +func TestCopy(t *testing.T) { + const key, val = "key", "val" + orig := Pairs(key, val) + copy := orig.Copy() + if !reflect.DeepEqual(orig, copy) { + t.Errorf("copied value not equal to the original, got %v, want %v", copy, orig) + } + orig[key] = "foo" + if v := copy[key]; v != val { + t.Errorf("change in original should not affect copy, got %q, want %q", v, val) + } +} +func TestJoin(t *testing.T) { + for _, test := range []struct { + mds []MD + want MD + }{ + {[]MD{}, MD{}}, + {[]MD{Pairs("foo", "bar")}, Pairs("foo", "bar")}, + {[]MD{Pairs("foo", "bar"), Pairs("foo", "baz")}, Pairs("foo", "bar", "foo", "baz")}, + {[]MD{Pairs("foo", "bar"), Pairs("foo", "baz"), Pairs("zip", "zap")}, Pairs("foo", "bar", "foo", "baz", "zip", "zap")}, + } { + md := Join(test.mds...) 
+ if !reflect.DeepEqual(md, test.want) { + t.Errorf("context's metadata is %v, want %v", md, test.want) + } + } +} + +func TestWithContext(t *testing.T) { + md := MD(map[string]interface{}{RemoteIP: "127.0.0.1", Color: "red", Mirror: true}) + c := NewContext(context.Background(), md) + ctx := WithContext(c) + md1, ok := FromContext(ctx) + if !ok { + t.Errorf("expect ok be true") + t.FailNow() + } + if !reflect.DeepEqual(md1, md) { + t.Errorf("expect md1 equal to md") + t.FailNow() + } +} + +func TestBool(t *testing.T) { + md := MD{RemoteIP: "127.0.0.1", Color: "red"} + mdcontext := NewContext(context.Background(), md) + assert.Equal(t, false, Bool(mdcontext, Mirror)) + + mdcontext = NewContext(context.Background(), MD{Mirror: true}) + assert.Equal(t, true, Bool(mdcontext, Mirror)) + + mdcontext = NewContext(context.Background(), MD{Mirror: "true"}) + assert.Equal(t, true, Bool(mdcontext, Mirror)) + + mdcontext = NewContext(context.Background(), MD{Mirror: "1"}) + assert.Equal(t, true, Bool(mdcontext, Mirror)) + + mdcontext = NewContext(context.Background(), MD{Mirror: "0"}) + assert.Equal(t, false, Bool(mdcontext, Mirror)) + +} +func TestInt64(t *testing.T) { + mdcontext := NewContext(context.Background(), MD{Mid: int64(1)}) + assert.Equal(t, int64(1), Int64(mdcontext, Mid)) + mdcontext = NewContext(context.Background(), MD{Mid: int64(2)}) + assert.NotEqual(t, int64(1), Int64(mdcontext, Mid)) + mdcontext = NewContext(context.Background(), MD{Mid: 10}) + assert.NotEqual(t, int64(10), Int64(mdcontext, Mid)) +} diff --git a/pkg/net/trace/README.md b/pkg/net/trace/README.md new file mode 100644 index 000000000..ce87fa101 --- /dev/null +++ b/pkg/net/trace/README.md @@ -0,0 +1,20 @@ +# net/trace + +## 项目简介 +1. 提供Trace的接口规范 +2. 提供 trace 对Tracer接口的实现,供业务接入使用 + +## 接入示例 +1. 启动接入示例 + ```go + trace.Init(traceConfig) // traceConfig is Config object with value. + ``` +2. 
配置参考 + ```toml + [tracer] + network = "unixgram" + addr = "/var/run/dapper-collect/dapper-collect.sock" + ``` + +## 测试 +1. 执行当前目录下所有测试文件,测试所有功能 diff --git a/pkg/net/trace/config.go b/pkg/net/trace/config.go new file mode 100644 index 000000000..d6f918acd --- /dev/null +++ b/pkg/net/trace/config.go @@ -0,0 +1,75 @@ +package trace + +import ( + "flag" + "fmt" + "os" + "time" + + "github.com/pkg/errors" + + "github.com/bilibili/Kratos/pkg/conf/dsn" + "github.com/bilibili/Kratos/pkg/conf/env" + xtime "github.com/bilibili/Kratos/pkg/time" +) + +var _traceDSN = "unixgram:///var/run/dapper-collect/dapper-collect.sock" + +func init() { + if v := os.Getenv("TRACE"); v != "" { + _traceDSN = v + } + flag.StringVar(&_traceDSN, "trace", _traceDSN, "trace report dsn, or use TRACE env.") +} + +// Config config. +type Config struct { + // Report network e.g. unixgram, tcp, udp + Network string `dsn:"network"` + // For TCP and UDP networks, the addr has the form "host:port". + // For Unix networks, the address must be a file system path. 
+ Addr string `dsn:"address"` + // Report timeout + Timeout xtime.Duration `dsn:"query.timeout,200ms"` + // DisableSample + DisableSample bool `dsn:"query.disable_sample"` + // ProtocolVersion + ProtocolVersion int32 `dsn:"query.protocol_version,1"` + // Probability probability sampling + Probability float32 `dsn:"-"` +} + +func parseDSN(rawdsn string) (*Config, error) { + d, err := dsn.Parse(rawdsn) + if err != nil { + return nil, errors.Wrapf(err, "trace: invalid dsn: %s", rawdsn) + } + cfg := new(Config) + if _, err = d.Bind(cfg); err != nil { + return nil, errors.Wrapf(err, "trace: invalid dsn: %s", rawdsn) + } + return cfg, nil +} + +// TracerFromEnvFlag new tracer from env and flag +func TracerFromEnvFlag() (Tracer, error) { + cfg, err := parseDSN(_traceDSN) + if err != nil { + return nil, err + } + report := newReport(cfg.Network, cfg.Addr, time.Duration(cfg.Timeout), cfg.ProtocolVersion) + return newTracer(env.AppID, report, cfg), nil +} + +// Init init trace report. +func Init(cfg *Config) { + if cfg == nil { + // paser config from env + var err error + if cfg, err = parseDSN(_traceDSN); err != nil { + panic(fmt.Errorf("parse trace dsn error: %s", err)) + } + } + report := newReport(cfg.Network, cfg.Addr, time.Duration(cfg.Timeout), cfg.ProtocolVersion) + SetGlobalTracer(newTracer(env.AppID, report, cfg)) +} diff --git a/pkg/net/trace/config_test.go b/pkg/net/trace/config_test.go new file mode 100644 index 000000000..1d7726806 --- /dev/null +++ b/pkg/net/trace/config_test.go @@ -0,0 +1,33 @@ +package trace + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseDSN(t *testing.T) { + _, err := parseDSN(_traceDSN) + if err != nil { + t.Error(err) + } +} + +func TestTraceFromEnvFlag(t *testing.T) { + _, err := TracerFromEnvFlag() + if err != nil { + t.Error(err) + } +} + +func TestInit(t *testing.T) { + Init(nil) + _, ok := _tracer.(nooptracer) + assert.False(t, ok) + + _tracer = nooptracer{} + + Init(&Config{Network: "unixgram", 
Addr: "unixgram:///var/run/dapper-collect/dapper-collect.sock"}) + _, ok = _tracer.(nooptracer) + assert.False(t, ok) +} diff --git a/pkg/net/trace/const.go b/pkg/net/trace/const.go new file mode 100644 index 000000000..616b0fe08 --- /dev/null +++ b/pkg/net/trace/const.go @@ -0,0 +1,7 @@ +package trace + +// Trace key +const ( + KratosTraceID = "kratos-trace-id" + KratosTraceDebug = "kratos-trace-debug" +) diff --git a/pkg/net/trace/context.go b/pkg/net/trace/context.go new file mode 100644 index 000000000..e3d20bd94 --- /dev/null +++ b/pkg/net/trace/context.go @@ -0,0 +1,110 @@ +package trace + +import ( + "strconv" + "strings" + + "github.com/pkg/errors" +) + +const ( + flagSampled = 0x01 + flagDebug = 0x02 +) + +var ( + errEmptyTracerString = errors.New("trace: cannot convert empty string to spancontext") + errInvalidTracerString = errors.New("trace: string does not match spancontext string format") +) + +// SpanContext implements opentracing.SpanContext +type spanContext struct { + // traceID represents globally unique ID of the trace. + // Usually generated as a random number. + traceID uint64 + + // spanID represents span ID that must be unique within its trace, + // but does not have to be globally unique. + spanID uint64 + + // parentID refers to the ID of the parent span. + // Should be 0 if the current span is a root span. + parentID uint64 + + // flags is a bitmap containing such bits as 'sampled' and 'debug'. 
+ flags byte + + // probability + probability float32 + + // current level + level int +} + +func (c spanContext) isSampled() bool { + return (c.flags & flagSampled) == flagSampled +} + +func (c spanContext) isDebug() bool { + return (c.flags & flagDebug) == flagDebug +} + +// IsValid check spanContext valid +func (c spanContext) IsValid() bool { + return c.traceID != 0 && c.spanID != 0 +} + +// emptyContext emptyContext +var emptyContext = spanContext{} + +// String convert spanContext to String +// {TraceID}:{SpanID}:{ParentID}:{flags}:[extend...] +// TraceID: uint64 base16 +// SpanID: uint64 base16 +// ParentID: uint64 base16 +// flags: +// - :0 sampled flag +// - :1 debug flag +// extend: +// sample-rate: s-{base16(BigEndian(float32))} +func (c spanContext) String() string { + base := make([]string, 4) + base[0] = strconv.FormatUint(uint64(c.traceID), 16) + base[1] = strconv.FormatUint(uint64(c.spanID), 16) + base[2] = strconv.FormatUint(uint64(c.parentID), 16) + base[3] = strconv.FormatUint(uint64(c.flags), 16) + return strings.Join(base, ":") +} + +// ContextFromString parse spanContext form string +func contextFromString(value string) (spanContext, error) { + if value == "" { + return emptyContext, errEmptyTracerString + } + items := strings.Split(value, ":") + if len(items) < 4 { + return emptyContext, errInvalidTracerString + } + parseHexUint64 := func(hexs []string) ([]uint64, error) { + rets := make([]uint64, len(hexs)) + var err error + for i, hex := range hexs { + rets[i], err = strconv.ParseUint(hex, 16, 64) + if err != nil { + break + } + } + return rets, err + } + rets, err := parseHexUint64(items[0:4]) + if err != nil { + return emptyContext, errInvalidTracerString + } + sctx := spanContext{ + traceID: rets[0], + spanID: rets[1], + parentID: rets[2], + flags: byte(rets[3]), + } + return sctx, nil +} diff --git a/pkg/net/trace/context_test.go b/pkg/net/trace/context_test.go new file mode 100644 index 000000000..d2d85baf9 --- /dev/null +++ 
b/pkg/net/trace/context_test.go @@ -0,0 +1,26 @@ +package trace + +import ( + "testing" +) + +func TestSpanContext(t *testing.T) { + pctx := &spanContext{ + parentID: genID(), + spanID: genID(), + traceID: genID(), + flags: flagSampled, + } + if !pctx.isSampled() { + t.Error("expect sampled") + } + value := pctx.String() + t.Logf("bili-trace-id: %s", value) + pctx2, err := contextFromString(value) + if err != nil { + t.Error(err) + } + if pctx2.parentID != pctx.parentID || pctx2.spanID != pctx.spanID || pctx2.traceID != pctx.traceID || pctx2.flags != pctx.flags { + t.Errorf("wrong spancontext get %+v -> %+v", pctx, pctx2) + } +} diff --git a/pkg/net/trace/dapper.go b/pkg/net/trace/dapper.go new file mode 100644 index 000000000..5a1675ddf --- /dev/null +++ b/pkg/net/trace/dapper.go @@ -0,0 +1,189 @@ +package trace + +import ( + "log" + "os" + "sync" + "time" +) + +const ( + _maxLevel = 64 + _probability = 0.00025 +) + +func newTracer(serviceName string, report reporter, cfg *Config) Tracer { + // hard code reset probability at 0.00025, 1/4000 + cfg.Probability = _probability + sampler := newSampler(cfg.Probability) + + // default internal tags + tags := extendTag() + stdlog := log.New(os.Stderr, "trace", log.LstdFlags) + return &dapper{ + cfg: cfg, + serviceName: serviceName, + propagators: map[interface{}]propagator{ + HTTPFormat: httpPropagator{}, + GRPCFormat: grpcPropagator{}, + }, + reporter: report, + sampler: sampler, + tags: tags, + pool: &sync.Pool{New: func() interface{} { return new(span) }}, + stdlog: stdlog, + } +} + +type dapper struct { + cfg *Config + serviceName string + tags []Tag + reporter reporter + propagators map[interface{}]propagator + pool *sync.Pool + stdlog *log.Logger + sampler sampler +} + +func (d *dapper) New(operationName string, opts ...Option) Trace { + opt := defaultOption + for _, fn := range opts { + fn(&opt) + } + traceID := genID() + var sampled bool + var probability float32 + if d.cfg.DisableSample { + sampled = true + 
probability = 1 + } else { + sampled, probability = d.sampler.IsSampled(traceID, operationName) + } + pctx := spanContext{traceID: traceID} + if sampled { + pctx.flags = flagSampled + pctx.probability = probability + } + if opt.Debug { + pctx.flags |= flagDebug + return d.newSpanWithContext(operationName, pctx).SetTag(TagString(TagSpanKind, "server")).SetTag(TagBool("debug", true)) + } + // 为了兼容临时为 New 的 Span 设置 span.kind + return d.newSpanWithContext(operationName, pctx).SetTag(TagString(TagSpanKind, "server")) +} + +func (d *dapper) newSpanWithContext(operationName string, pctx spanContext) Trace { + sp := d.getSpan() + // is span is not sampled just return a span with this context, no need clear it + //if !pctx.isSampled() { + // sp.context = pctx + // return sp + //} + if pctx.level > _maxLevel { + // if span reach max limit level return noopspan + return noopspan{} + } + level := pctx.level + 1 + nctx := spanContext{ + traceID: pctx.traceID, + parentID: pctx.spanID, + flags: pctx.flags, + level: level, + } + if pctx.spanID == 0 { + nctx.spanID = pctx.traceID + } else { + nctx.spanID = genID() + } + sp.operationName = operationName + sp.context = nctx + sp.startTime = time.Now() + sp.tags = append(sp.tags, d.tags...) 
+ return sp +} + +func (d *dapper) Inject(t Trace, format interface{}, carrier interface{}) error { + // if carrier implement Carrier use direct, ignore format + carr, ok := carrier.(Carrier) + if ok { + t.Visit(carr.Set) + return nil + } + // use Built-in propagators + pp, ok := d.propagators[format] + if !ok { + return ErrUnsupportedFormat + } + carr, err := pp.Inject(carrier) + if err != nil { + return err + } + if t != nil { + t.Visit(carr.Set) + } + return nil +} + +func (d *dapper) Extract(format interface{}, carrier interface{}) (Trace, error) { + sp, err := d.extract(format, carrier) + if err != nil { + return sp, err + } + // 为了兼容临时为 New 的 Span 设置 span.kind + return sp.SetTag(TagString(TagSpanKind, "server")), nil +} + +func (d *dapper) extract(format interface{}, carrier interface{}) (Trace, error) { + // if carrier implement Carrier use direct, ignore format + carr, ok := carrier.(Carrier) + if !ok { + // use Built-in propagators + pp, ok := d.propagators[format] + if !ok { + return nil, ErrUnsupportedFormat + } + var err error + if carr, err = pp.Extract(carrier); err != nil { + return nil, err + } + } + pctx, err := contextFromString(carr.Get(KratosTraceID)) + if err != nil { + return nil, err + } + // NOTE: call SetTitle after extract trace + return d.newSpanWithContext("", pctx), nil +} + +func (d *dapper) Close() error { + return d.reporter.Close() +} + +func (d *dapper) report(sp *span) { + if sp.context.isSampled() { + if err := d.reporter.WriteSpan(sp); err != nil { + d.stdlog.Printf("marshal trace span error: %s", err) + } + } + d.putSpan(sp) +} + +func (d *dapper) putSpan(sp *span) { + if len(sp.tags) > 32 { + sp.tags = nil + } + if len(sp.logs) > 32 { + sp.logs = nil + } + d.pool.Put(sp) +} + +func (d *dapper) getSpan() *span { + sp := d.pool.Get().(*span) + sp.dapper = d + sp.childs = 0 + sp.tags = sp.tags[:0] + sp.logs = sp.logs[:0] + return sp +} diff --git a/pkg/net/trace/dapper_test.go b/pkg/net/trace/dapper_test.go new file mode 100644 
index 000000000..f5e3ab7da --- /dev/null +++ b/pkg/net/trace/dapper_test.go @@ -0,0 +1,136 @@ +package trace + +import ( + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" +) + +type mockReport struct { + sps []*span +} + +func (m *mockReport) WriteSpan(sp *span) error { + m.sps = append(m.sps, sp) + return nil +} + +func (m *mockReport) Close() error { + return nil +} + +func TestDapperPropagation(t *testing.T) { + t.Run("test HTTP progagation", func(t *testing.T) { + report := &mockReport{} + t1 := newTracer("service1", report, &Config{DisableSample: true}) + t2 := newTracer("service2", report, &Config{DisableSample: true}) + sp1 := t1.New("opt_1") + sp2 := sp1.Fork("", "opt_client") + header := make(http.Header) + t1.Inject(sp2, HTTPFormat, header) + sp3, err := t2.Extract(HTTPFormat, header) + if err != nil { + t.Fatal(err) + } + sp3.Finish(nil) + sp2.Finish(nil) + sp1.Finish(nil) + + assert.Len(t, report.sps, 3) + assert.Equal(t, report.sps[2].context.parentID, uint64(0)) + assert.Equal(t, report.sps[0].context.traceID, report.sps[1].context.traceID) + assert.Equal(t, report.sps[2].context.traceID, report.sps[1].context.traceID) + + assert.Equal(t, report.sps[1].context.parentID, report.sps[2].context.spanID) + assert.Equal(t, report.sps[0].context.parentID, report.sps[1].context.spanID) + }) + t.Run("test gRPC progagation", func(t *testing.T) { + report := &mockReport{} + t1 := newTracer("service1", report, &Config{DisableSample: true}) + t2 := newTracer("service2", report, &Config{DisableSample: true}) + sp1 := t1.New("opt_1") + sp2 := sp1.Fork("", "opt_client") + md := make(metadata.MD) + t1.Inject(sp2, GRPCFormat, md) + sp3, err := t2.Extract(GRPCFormat, md) + if err != nil { + t.Fatal(err) + } + sp3.Finish(nil) + sp2.Finish(nil) + sp1.Finish(nil) + + assert.Len(t, report.sps, 3) + assert.Equal(t, report.sps[2].context.parentID, uint64(0)) + assert.Equal(t, report.sps[0].context.traceID, 
report.sps[1].context.traceID) + assert.Equal(t, report.sps[2].context.traceID, report.sps[1].context.traceID) + + assert.Equal(t, report.sps[1].context.parentID, report.sps[2].context.spanID) + assert.Equal(t, report.sps[0].context.parentID, report.sps[1].context.spanID) + }) + t.Run("test normal", func(t *testing.T) { + report := &mockReport{} + t1 := newTracer("service1", report, &Config{Probability: 0.000000001}) + sp1 := t1.New("test123") + sp1.Finish(nil) + }) + t.Run("test debug progagation", func(t *testing.T) { + report := &mockReport{} + t1 := newTracer("service1", report, &Config{}) + t2 := newTracer("service2", report, &Config{}) + sp1 := t1.New("opt_1", EnableDebug()) + sp2 := sp1.Fork("", "opt_client") + header := make(http.Header) + t1.Inject(sp2, HTTPFormat, header) + sp3, err := t2.Extract(HTTPFormat, header) + if err != nil { + t.Fatal(err) + } + sp3.Finish(nil) + sp2.Finish(nil) + sp1.Finish(nil) + + assert.Len(t, report.sps, 3) + assert.Equal(t, report.sps[2].context.parentID, uint64(0)) + assert.Equal(t, report.sps[0].context.traceID, report.sps[1].context.traceID) + assert.Equal(t, report.sps[2].context.traceID, report.sps[1].context.traceID) + + assert.Equal(t, report.sps[1].context.parentID, report.sps[2].context.spanID) + assert.Equal(t, report.sps[0].context.parentID, report.sps[1].context.spanID) + }) +} + +func BenchmarkSample(b *testing.B) { + err := fmt.Errorf("test error") + report := &mockReport{} + t1 := newTracer("service1", report, &Config{}) + for i := 0; i < b.N; i++ { + sp1 := t1.New("test_opt1") + sp1.SetTag(TagString("test", "123")) + sp2 := sp1.Fork("", "opt2") + sp3 := sp2.Fork("", "opt3") + sp3.SetTag(TagString("test", "123")) + sp3.Finish(nil) + sp2.Finish(&err) + sp1.Finish(nil) + } +} + +func BenchmarkDisableSample(b *testing.B) { + err := fmt.Errorf("test error") + report := &mockReport{} + t1 := newTracer("service1", report, &Config{DisableSample: true}) + for i := 0; i < b.N; i++ { + sp1 := t1.New("test_opt1") + 
sp1.SetTag(TagString("test", "123")) + sp2 := sp1.Fork("", "opt2") + sp3 := sp2.Fork("", "opt3") + sp3.SetTag(TagString("test", "123")) + sp3.Finish(nil) + sp2.Finish(&err) + sp1.Finish(nil) + } +} diff --git a/pkg/net/trace/marshal.go b/pkg/net/trace/marshal.go new file mode 100644 index 000000000..414dfb59c --- /dev/null +++ b/pkg/net/trace/marshal.go @@ -0,0 +1,106 @@ +package trace + +import ( + "encoding/binary" + errs "errors" + "fmt" + "math" + "time" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes/duration" + "github.com/golang/protobuf/ptypes/timestamp" + + protogen "github.com/bilibili/Kratos/pkg/net/trace/proto" +) + +const protoVersion1 int32 = 1 + +var ( + errSpanVersion = errs.New("trace: marshal not support version") +) + +func marshalSpan(sp *span, version int32) ([]byte, error) { + if version == protoVersion1 { + return marshalSpanV1(sp) + } + return nil, errSpanVersion +} + +func marshalSpanV1(sp *span) ([]byte, error) { + protoSpan := new(protogen.Span) + protoSpan.Version = protoVersion1 + protoSpan.ServiceName = sp.dapper.serviceName + protoSpan.OperationName = sp.operationName + protoSpan.TraceId = sp.context.traceID + protoSpan.SpanId = sp.context.spanID + protoSpan.ParentId = sp.context.parentID + protoSpan.SamplingProbability = sp.context.probability + protoSpan.StartTime = ×tamp.Timestamp{ + Seconds: sp.startTime.Unix(), + Nanos: int32(sp.startTime.Nanosecond()), + } + protoSpan.Duration = &duration.Duration{ + Seconds: int64(sp.duration / time.Second), + Nanos: int32(sp.duration % time.Second), + } + protoSpan.Tags = make([]*protogen.Tag, len(sp.tags)) + for i := range sp.tags { + protoSpan.Tags[i] = toProtoTag(sp.tags[i]) + } + protoSpan.Logs = sp.logs + return proto.Marshal(protoSpan) +} + +func toProtoTag(tag Tag) *protogen.Tag { + ptag := &protogen.Tag{Key: tag.Key} + switch value := tag.Value.(type) { + case string: + ptag.Kind = protogen.Tag_STRING + ptag.Value = []byte(value) + case int: + ptag.Kind = 
protogen.Tag_INT + ptag.Value = serializeInt64(int64(value)) + case int32: + ptag.Kind = protogen.Tag_INT + ptag.Value = serializeInt64(int64(value)) + case int64: + ptag.Kind = protogen.Tag_INT + ptag.Value = serializeInt64(value) + case bool: + ptag.Kind = protogen.Tag_BOOL + ptag.Value = serializeBool(value) + case float32: + ptag.Kind = protogen.Tag_BOOL + ptag.Value = serializeFloat64(float64(value)) + case float64: + ptag.Kind = protogen.Tag_BOOL + ptag.Value = serializeFloat64(value) + default: + ptag.Kind = protogen.Tag_STRING + ptag.Value = []byte((fmt.Sprintf("%v", tag.Value))) + } + return ptag +} + +func serializeInt64(v int64) []byte { + data := make([]byte, 8) + binary.BigEndian.PutUint64(data, uint64(v)) + return data +} + +func serializeFloat64(v float64) []byte { + data := make([]byte, 8) + binary.BigEndian.PutUint64(data, math.Float64bits(v)) + return data +} + +func serializeBool(v bool) []byte { + data := make([]byte, 1) + if v { + data[0] = byte(1) + } else { + data[0] = byte(0) + } + return data +} diff --git a/pkg/net/trace/marshal_test.go b/pkg/net/trace/marshal_test.go new file mode 100644 index 000000000..30fe50cc9 --- /dev/null +++ b/pkg/net/trace/marshal_test.go @@ -0,0 +1,18 @@ +package trace + +import ( + "testing" +) + +func TestMarshalSpanV1(t *testing.T) { + report := &mockReport{} + t1 := newTracer("service1", report, &Config{DisableSample: true}) + sp1 := t1.New("opt_test").(*span) + sp1.SetLog(Log("hello", "test123")) + sp1.SetTag(TagString("tag1", "hell"), TagBool("booltag", true), TagFloat64("float64tag", 3.14159)) + sp1.Finish(nil) + _, err := marshalSpanV1(sp1) + if err != nil { + t.Error(err) + } +} diff --git a/pkg/net/trace/noop.go b/pkg/net/trace/noop.go new file mode 100644 index 000000000..239c29c45 --- /dev/null +++ b/pkg/net/trace/noop.go @@ -0,0 +1,45 @@ +package trace + +var ( + _ Tracer = nooptracer{} +) + +type nooptracer struct{} + +func (n nooptracer) New(title string, opts ...Option) Trace { + return noopspan{} 
+} + +func (n nooptracer) Inject(t Trace, format interface{}, carrier interface{}) error { + return nil +} + +func (n nooptracer) Extract(format interface{}, carrier interface{}) (Trace, error) { + return noopspan{}, nil +} + +type noopspan struct{} + +func (n noopspan) Fork(string, string) Trace { + return noopspan{} +} + +func (n noopspan) Follow(string, string) Trace { + return noopspan{} +} + +func (n noopspan) Finish(err *error) {} + +func (n noopspan) SetTag(tags ...Tag) Trace { + return noopspan{} +} + +func (n noopspan) SetLog(logs ...LogField) Trace { + return noopspan{} +} + +func (n noopspan) Visit(func(k, v string)) {} + +func (n noopspan) SetTitle(string) {} + +func (n noopspan) String() string { return "" } diff --git a/pkg/net/trace/option.go b/pkg/net/trace/option.go new file mode 100644 index 000000000..f0865c2d7 --- /dev/null +++ b/pkg/net/trace/option.go @@ -0,0 +1,17 @@ +package trace + +var defaultOption = option{} + +type option struct { + Debug bool +} + +// Option dapper Option +type Option func(*option) + +// EnableDebug enable debug mode +func EnableDebug() Option { + return func(opt *option) { + opt.Debug = true + } +} diff --git a/pkg/net/trace/propagation.go b/pkg/net/trace/propagation.go new file mode 100644 index 000000000..0eb45bd23 --- /dev/null +++ b/pkg/net/trace/propagation.go @@ -0,0 +1,177 @@ +package trace + +import ( + errs "errors" + "net/http" + + "google.golang.org/grpc/metadata" +) + +var ( + // ErrUnsupportedFormat occurs when the `format` passed to Tracer.Inject() or + // Tracer.Extract() is not recognized by the Tracer implementation. + ErrUnsupportedFormat = errs.New("trace: Unknown or unsupported Inject/Extract format") + + // ErrTraceNotFound occurs when the `carrier` passed to + // Tracer.Extract() is valid and uncorrupted but has insufficient + // information to extract a Trace. 
+ ErrTraceNotFound = errs.New("trace: Trace not found in Extract carrier") + + // ErrInvalidTrace errors occur when Tracer.Inject() is asked to + // operate on a Trace which it is not prepared to handle (for + // example, since it was created by a different tracer implementation). + ErrInvalidTrace = errs.New("trace: Trace type incompatible with tracer") + + // ErrInvalidCarrier errors occur when Tracer.Inject() or Tracer.Extract() + // implementations expect a different type of `carrier` than they are + // given. + ErrInvalidCarrier = errs.New("trace: Invalid Inject/Extract carrier") + + // ErrTraceCorrupted occurs when the `carrier` passed to + // Tracer.Extract() is of the expected type but is corrupted. + ErrTraceCorrupted = errs.New("trace: Trace data corrupted in Extract carrier") +) + +// BuiltinFormat is used to demarcate the values within package `trace` +// that are intended for use with the Tracer.Inject() and Tracer.Extract() +// methods. +type BuiltinFormat byte + +// support format list +const ( + // HTTPFormat represents Trace as HTTP header string pairs. + // + // the HTTPFormat format requires that the keys and values + // be valid as HTTP headers as-is (i.e., character casing may be unstable + // and special characters are disallowed in keys, values should be + // URL-escaped, etc). + // + // the carrier must be a `http.Header`. + HTTPFormat BuiltinFormat = iota + // GRPCFormat represents Trace as gRPC metadata. + // + // the carrier must be a `google.golang.org/grpc/metadata.MD`. + GRPCFormat +) + +// Carrier propagator must convert generic interface{} to something this +// implement Carrier interface, Trace can use Carrier to represents itself. 
+type Carrier interface { + Set(key, val string) + Get(key string) string +} + +// propagator is responsible for injecting and extracting `Trace` instances +// from a format-specific "carrier" +type propagator interface { + Inject(carrier interface{}) (Carrier, error) + Extract(carrier interface{}) (Carrier, error) +} + +type httpPropagator struct{} + +type httpCarrier http.Header + +func (h httpCarrier) Set(key, val string) { + http.Header(h).Set(key, val) +} + +func (h httpCarrier) Get(key string) string { + return http.Header(h).Get(key) +} + +func (httpPropagator) Inject(carrier interface{}) (Carrier, error) { + header, ok := carrier.(http.Header) + if !ok { + return nil, ErrInvalidCarrier + } + if header == nil { + return nil, ErrInvalidTrace + } + return httpCarrier(header), nil +} + +func (httpPropagator) Extract(carrier interface{}) (Carrier, error) { + header, ok := carrier.(http.Header) + if !ok { + return nil, ErrInvalidCarrier + } + if header == nil { + return nil, ErrTraceNotFound + } + return httpCarrier(header), nil +} + +const legacyGRPCKey = "trace" + +type grpcPropagator struct{} + +type grpcCarrier map[string][]string + +func (g grpcCarrier) Get(key string) string { + if v, ok := g[key]; ok && len(v) > 0 { + return v[0] + } + // ts := g[legacyGRPCKey] + // if len(ts) != 8 { + // return "" + // } + // switch key { + // case KeyTraceID: + // return ts[0] + // case KeyTraceSpanID: + // return ts[1] + // case KeyTraceParentID: + // return ts[2] + // case KeyTraceLevel: + // return ts[3] + // case KeyTraceSampled: + // return ts[4] + // case KeyTraceCaller: + // return ts[5] + // } + return "" +} + +func (g grpcCarrier) Set(key, val string) { + // ts := make([]string, 8) + // g[legacyGRPCKey] = ts + // switch key { + // case KeyTraceID: + // ts[0] = val + // case KeyTraceSpanID: + // ts[1] = val + // case KeyTraceParentID: + // ts[2] = val + // case KeyTraceLevel: + // ts[3] = val + // case KeyTraceSampled: + // ts[4] = val + // case KeyTraceCaller: + 
// ts[5] = val + // default: + g[key] = append(g[key], val) + // } +} + +func (grpcPropagator) Inject(carrier interface{}) (Carrier, error) { + md, ok := carrier.(metadata.MD) + if !ok { + return nil, ErrInvalidCarrier + } + if md == nil { + return nil, ErrInvalidTrace + } + return grpcCarrier(md), nil +} + +func (grpcPropagator) Extract(carrier interface{}) (Carrier, error) { + md, ok := carrier.(metadata.MD) + if !ok { + return nil, ErrInvalidCarrier + } + if md == nil { + return nil, ErrTraceNotFound + } + return grpcCarrier(md), nil +} diff --git a/pkg/net/trace/proto/span.pb.go b/pkg/net/trace/proto/span.pb.go new file mode 100644 index 000000000..55f612fc9 --- /dev/null +++ b/pkg/net/trace/proto/span.pb.go @@ -0,0 +1,557 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: proto/span.proto + +package protogen + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" +import duration "github.com/golang/protobuf/ptypes/duration" +import timestamp "github.com/golang/protobuf/ptypes/timestamp" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. 
+const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type Tag_Kind int32 + +const ( + Tag_STRING Tag_Kind = 0 + Tag_INT Tag_Kind = 1 + Tag_BOOL Tag_Kind = 2 + Tag_FLOAT Tag_Kind = 3 +) + +var Tag_Kind_name = map[int32]string{ + 0: "STRING", + 1: "INT", + 2: "BOOL", + 3: "FLOAT", +} +var Tag_Kind_value = map[string]int32{ + "STRING": 0, + "INT": 1, + "BOOL": 2, + "FLOAT": 3, +} + +func (x Tag_Kind) String() string { + return proto.EnumName(Tag_Kind_name, int32(x)) +} +func (Tag_Kind) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_span_68a8dae26ef502a2, []int{0, 0} +} + +type Log_Kind int32 + +const ( + Log_STRING Log_Kind = 0 + Log_INT Log_Kind = 1 + Log_BOOL Log_Kind = 2 + Log_FLOAT Log_Kind = 3 +) + +var Log_Kind_name = map[int32]string{ + 0: "STRING", + 1: "INT", + 2: "BOOL", + 3: "FLOAT", +} +var Log_Kind_value = map[string]int32{ + "STRING": 0, + "INT": 1, + "BOOL": 2, + "FLOAT": 3, +} + +func (x Log_Kind) String() string { + return proto.EnumName(Log_Kind_name, int32(x)) +} +func (Log_Kind) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_span_68a8dae26ef502a2, []int{2, 0} +} + +type SpanRef_RefType int32 + +const ( + SpanRef_CHILD_OF SpanRef_RefType = 0 + SpanRef_FOLLOWS_FROM SpanRef_RefType = 1 +) + +var SpanRef_RefType_name = map[int32]string{ + 0: "CHILD_OF", + 1: "FOLLOWS_FROM", +} +var SpanRef_RefType_value = map[string]int32{ + "CHILD_OF": 0, + "FOLLOWS_FROM": 1, +} + +func (x SpanRef_RefType) String() string { + return proto.EnumName(SpanRef_RefType_name, int32(x)) +} +func (SpanRef_RefType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_span_68a8dae26ef502a2, []int{3, 0} +} + +type Tag struct { + Key string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty"` + Kind Tag_Kind `protobuf:"varint,2,opt,name=kind,enum=dapper.trace.Tag_Kind" json:"kind,omitempty"` + Value []byte `protobuf:"bytes,3,opt,name=value,proto3" json:"value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + 
XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Tag) Reset() { *m = Tag{} } +func (m *Tag) String() string { return proto.CompactTextString(m) } +func (*Tag) ProtoMessage() {} +func (*Tag) Descriptor() ([]byte, []int) { + return fileDescriptor_span_68a8dae26ef502a2, []int{0} +} +func (m *Tag) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Tag.Unmarshal(m, b) +} +func (m *Tag) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Tag.Marshal(b, m, deterministic) +} +func (dst *Tag) XXX_Merge(src proto.Message) { + xxx_messageInfo_Tag.Merge(dst, src) +} +func (m *Tag) XXX_Size() int { + return xxx_messageInfo_Tag.Size(m) +} +func (m *Tag) XXX_DiscardUnknown() { + xxx_messageInfo_Tag.DiscardUnknown(m) +} + +var xxx_messageInfo_Tag proto.InternalMessageInfo + +func (m *Tag) GetKey() string { + if m != nil { + return m.Key + } + return "" +} + +func (m *Tag) GetKind() Tag_Kind { + if m != nil { + return m.Kind + } + return Tag_STRING +} + +func (m *Tag) GetValue() []byte { + if m != nil { + return m.Value + } + return nil +} + +type Field struct { + Key string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty"` + Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Field) Reset() { *m = Field{} } +func (m *Field) String() string { return proto.CompactTextString(m) } +func (*Field) ProtoMessage() {} +func (*Field) Descriptor() ([]byte, []int) { + return fileDescriptor_span_68a8dae26ef502a2, []int{1} +} +func (m *Field) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Field.Unmarshal(m, b) +} +func (m *Field) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Field.Marshal(b, m, deterministic) +} +func (dst *Field) XXX_Merge(src proto.Message) { + xxx_messageInfo_Field.Merge(dst, src) +} +func (m 
*Field) XXX_Size() int { + return xxx_messageInfo_Field.Size(m) +} +func (m *Field) XXX_DiscardUnknown() { + xxx_messageInfo_Field.DiscardUnknown(m) +} + +var xxx_messageInfo_Field proto.InternalMessageInfo + +func (m *Field) GetKey() string { + if m != nil { + return m.Key + } + return "" +} + +func (m *Field) GetValue() []byte { + if m != nil { + return m.Value + } + return nil +} + +type Log struct { + Key string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty"` + Kind Log_Kind `protobuf:"varint,2,opt,name=kind,enum=dapper.trace.Log_Kind" json:"kind,omitempty"` + Value []byte `protobuf:"bytes,3,opt,name=value,proto3" json:"value,omitempty"` + Timestamp int64 `protobuf:"varint,4,opt,name=timestamp" json:"timestamp,omitempty"` + Fields []*Field `protobuf:"bytes,5,rep,name=fields" json:"fields,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Log) Reset() { *m = Log{} } +func (m *Log) String() string { return proto.CompactTextString(m) } +func (*Log) ProtoMessage() {} +func (*Log) Descriptor() ([]byte, []int) { + return fileDescriptor_span_68a8dae26ef502a2, []int{2} +} +func (m *Log) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Log.Unmarshal(m, b) +} +func (m *Log) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Log.Marshal(b, m, deterministic) +} +func (dst *Log) XXX_Merge(src proto.Message) { + xxx_messageInfo_Log.Merge(dst, src) +} +func (m *Log) XXX_Size() int { + return xxx_messageInfo_Log.Size(m) +} +func (m *Log) XXX_DiscardUnknown() { + xxx_messageInfo_Log.DiscardUnknown(m) +} + +var xxx_messageInfo_Log proto.InternalMessageInfo + +func (m *Log) GetKey() string { + if m != nil { + return m.Key + } + return "" +} + +func (m *Log) GetKind() Log_Kind { + if m != nil { + return m.Kind + } + return Log_STRING +} + +func (m *Log) GetValue() []byte { + if m != nil { + return m.Value + } + return nil +} + +func (m *Log) 
GetTimestamp() int64 { + if m != nil { + return m.Timestamp + } + return 0 +} + +func (m *Log) GetFields() []*Field { + if m != nil { + return m.Fields + } + return nil +} + +// SpanRef describes causal relationship of the current span to another span (e.g. 'child-of') +type SpanRef struct { + RefType SpanRef_RefType `protobuf:"varint,1,opt,name=ref_type,json=refType,enum=dapper.trace.SpanRef_RefType" json:"ref_type,omitempty"` + TraceId uint64 `protobuf:"varint,2,opt,name=trace_id,json=traceId" json:"trace_id,omitempty"` + SpanId uint64 `protobuf:"varint,3,opt,name=span_id,json=spanId" json:"span_id,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *SpanRef) Reset() { *m = SpanRef{} } +func (m *SpanRef) String() string { return proto.CompactTextString(m) } +func (*SpanRef) ProtoMessage() {} +func (*SpanRef) Descriptor() ([]byte, []int) { + return fileDescriptor_span_68a8dae26ef502a2, []int{3} +} +func (m *SpanRef) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_SpanRef.Unmarshal(m, b) +} +func (m *SpanRef) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_SpanRef.Marshal(b, m, deterministic) +} +func (dst *SpanRef) XXX_Merge(src proto.Message) { + xxx_messageInfo_SpanRef.Merge(dst, src) +} +func (m *SpanRef) XXX_Size() int { + return xxx_messageInfo_SpanRef.Size(m) +} +func (m *SpanRef) XXX_DiscardUnknown() { + xxx_messageInfo_SpanRef.DiscardUnknown(m) +} + +var xxx_messageInfo_SpanRef proto.InternalMessageInfo + +func (m *SpanRef) GetRefType() SpanRef_RefType { + if m != nil { + return m.RefType + } + return SpanRef_CHILD_OF +} + +func (m *SpanRef) GetTraceId() uint64 { + if m != nil { + return m.TraceId + } + return 0 +} + +func (m *SpanRef) GetSpanId() uint64 { + if m != nil { + return m.SpanId + } + return 0 +} + +// Span represents a named unit of work performed by a service. 
+type Span struct { + Version int32 `protobuf:"varint,99,opt,name=version" json:"version,omitempty"` + ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName" json:"service_name,omitempty"` + OperationName string `protobuf:"bytes,2,opt,name=operation_name,json=operationName" json:"operation_name,omitempty"` + // Deprecated: caller no long required + Caller string `protobuf:"bytes,3,opt,name=caller" json:"caller,omitempty"` + TraceId uint64 `protobuf:"varint,4,opt,name=trace_id,json=traceId" json:"trace_id,omitempty"` + SpanId uint64 `protobuf:"varint,5,opt,name=span_id,json=spanId" json:"span_id,omitempty"` + ParentId uint64 `protobuf:"varint,6,opt,name=parent_id,json=parentId" json:"parent_id,omitempty"` + // Deprecated: level no long required + Level int32 `protobuf:"varint,7,opt,name=level" json:"level,omitempty"` + // Deprecated: use start_time instead instead of start_at + StartAt int64 `protobuf:"varint,8,opt,name=start_at,json=startAt" json:"start_at,omitempty"` + // Deprecated: use duration instead instead of finish_at + FinishAt int64 `protobuf:"varint,9,opt,name=finish_at,json=finishAt" json:"finish_at,omitempty"` + SamplingProbability float32 `protobuf:"fixed32,10,opt,name=sampling_probability,json=samplingProbability" json:"sampling_probability,omitempty"` + Env string `protobuf:"bytes,19,opt,name=env" json:"env,omitempty"` + StartTime *timestamp.Timestamp `protobuf:"bytes,20,opt,name=start_time,json=startTime" json:"start_time,omitempty"` + Duration *duration.Duration `protobuf:"bytes,21,opt,name=duration" json:"duration,omitempty"` + References []*SpanRef `protobuf:"bytes,22,rep,name=references" json:"references,omitempty"` + Tags []*Tag `protobuf:"bytes,11,rep,name=tags" json:"tags,omitempty"` + Logs []*Log `protobuf:"bytes,12,rep,name=logs" json:"logs,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Span) Reset() { *m = Span{} } +func (m 
*Span) String() string { return proto.CompactTextString(m) } +func (*Span) ProtoMessage() {} +func (*Span) Descriptor() ([]byte, []int) { + return fileDescriptor_span_68a8dae26ef502a2, []int{4} +} +func (m *Span) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Span.Unmarshal(m, b) +} +func (m *Span) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Span.Marshal(b, m, deterministic) +} +func (dst *Span) XXX_Merge(src proto.Message) { + xxx_messageInfo_Span.Merge(dst, src) +} +func (m *Span) XXX_Size() int { + return xxx_messageInfo_Span.Size(m) +} +func (m *Span) XXX_DiscardUnknown() { + xxx_messageInfo_Span.DiscardUnknown(m) +} + +var xxx_messageInfo_Span proto.InternalMessageInfo + +func (m *Span) GetVersion() int32 { + if m != nil { + return m.Version + } + return 0 +} + +func (m *Span) GetServiceName() string { + if m != nil { + return m.ServiceName + } + return "" +} + +func (m *Span) GetOperationName() string { + if m != nil { + return m.OperationName + } + return "" +} + +func (m *Span) GetCaller() string { + if m != nil { + return m.Caller + } + return "" +} + +func (m *Span) GetTraceId() uint64 { + if m != nil { + return m.TraceId + } + return 0 +} + +func (m *Span) GetSpanId() uint64 { + if m != nil { + return m.SpanId + } + return 0 +} + +func (m *Span) GetParentId() uint64 { + if m != nil { + return m.ParentId + } + return 0 +} + +func (m *Span) GetLevel() int32 { + if m != nil { + return m.Level + } + return 0 +} + +func (m *Span) GetStartAt() int64 { + if m != nil { + return m.StartAt + } + return 0 +} + +func (m *Span) GetFinishAt() int64 { + if m != nil { + return m.FinishAt + } + return 0 +} + +func (m *Span) GetSamplingProbability() float32 { + if m != nil { + return m.SamplingProbability + } + return 0 +} + +func (m *Span) GetEnv() string { + if m != nil { + return m.Env + } + return "" +} + +func (m *Span) GetStartTime() *timestamp.Timestamp { + if m != nil { + return m.StartTime + } + return nil +} + 
+func (m *Span) GetDuration() *duration.Duration { + if m != nil { + return m.Duration + } + return nil +} + +func (m *Span) GetReferences() []*SpanRef { + if m != nil { + return m.References + } + return nil +} + +func (m *Span) GetTags() []*Tag { + if m != nil { + return m.Tags + } + return nil +} + +func (m *Span) GetLogs() []*Log { + if m != nil { + return m.Logs + } + return nil +} + +func init() { + proto.RegisterType((*Tag)(nil), "dapper.trace.Tag") + proto.RegisterType((*Field)(nil), "dapper.trace.Field") + proto.RegisterType((*Log)(nil), "dapper.trace.Log") + proto.RegisterType((*SpanRef)(nil), "dapper.trace.SpanRef") + proto.RegisterType((*Span)(nil), "dapper.trace.Span") + proto.RegisterEnum("dapper.trace.Tag_Kind", Tag_Kind_name, Tag_Kind_value) + proto.RegisterEnum("dapper.trace.Log_Kind", Log_Kind_name, Log_Kind_value) + proto.RegisterEnum("dapper.trace.SpanRef_RefType", SpanRef_RefType_name, SpanRef_RefType_value) +} + +func init() { proto.RegisterFile("proto/span.proto", fileDescriptor_span_68a8dae26ef502a2) } + +var fileDescriptor_span_68a8dae26ef502a2 = []byte{ + // 669 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x94, 0xdd, 0x6e, 0xd3, 0x4a, + 0x10, 0xc7, 0xeb, 0xd8, 0x89, 0x9d, 0x49, 0x4e, 0xe5, 0xb3, 0xfd, 0x38, 0xdb, 0x1e, 0x3e, 0x4c, + 0xa4, 0x4a, 0x06, 0x24, 0x07, 0x82, 0x2a, 0xc1, 0x65, 0x4b, 0x15, 0x88, 0x30, 0x0d, 0xda, 0x46, + 0x42, 0xe2, 0x26, 0xda, 0x24, 0x63, 0x63, 0xd5, 0xb1, 0x2d, 0x7b, 0x1b, 0x29, 0xcf, 0xc0, 0x5b, + 0xf0, 0x50, 0xdc, 0xf1, 0x2e, 0x68, 0xd7, 0x4e, 0x9a, 0xd2, 0x22, 0x04, 0x77, 0x3b, 0xf3, 0xff, + 0xed, 0xce, 0xcc, 0xfa, 0xbf, 0x06, 0x3b, 0xcb, 0x53, 0x91, 0x76, 0x8b, 0x8c, 0x27, 0x9e, 0x5a, + 0x92, 0xf6, 0x8c, 0x67, 0x19, 0xe6, 0x9e, 0xc8, 0xf9, 0x14, 0x0f, 0x1f, 0x86, 0x69, 0x1a, 0xc6, + 0xd8, 0x55, 0xda, 0xe4, 0x2a, 0xe8, 0x8a, 0x68, 0x8e, 0x85, 0xe0, 0xf3, 0xac, 0xc4, 0x0f, 0x1f, + 0xfc, 0x0c, 0xcc, 0xae, 0x72, 0x2e, 0xa2, 0xb4, 0x3a, 0xae, 0xf3, 
0x45, 0x03, 0x7d, 0xc4, 0x43, + 0x62, 0x83, 0x7e, 0x89, 0x4b, 0xaa, 0x39, 0x9a, 0xdb, 0x64, 0x72, 0x49, 0x9e, 0x80, 0x71, 0x19, + 0x25, 0x33, 0x5a, 0x73, 0x34, 0x77, 0xbb, 0xb7, 0xef, 0x6d, 0xd6, 0xf5, 0x46, 0x3c, 0xf4, 0xde, + 0x45, 0xc9, 0x8c, 0x29, 0x86, 0xec, 0x42, 0x7d, 0xc1, 0xe3, 0x2b, 0xa4, 0xba, 0xa3, 0xb9, 0x6d, + 0x56, 0x06, 0x9d, 0x67, 0x60, 0x48, 0x86, 0x00, 0x34, 0x2e, 0x46, 0x6c, 0x70, 0xfe, 0xc6, 0xde, + 0x22, 0x26, 0xe8, 0x83, 0xf3, 0x91, 0xad, 0x11, 0x0b, 0x8c, 0xd3, 0xe1, 0xd0, 0xb7, 0x6b, 0xa4, + 0x09, 0xf5, 0xbe, 0x3f, 0x3c, 0x19, 0xd9, 0x7a, 0xa7, 0x0b, 0xf5, 0x7e, 0x84, 0xf1, 0xec, 0x8e, + 0x76, 0xd6, 0x25, 0x6a, 0x9b, 0x25, 0xbe, 0x69, 0xa0, 0xfb, 0xe9, 0x1f, 0xb7, 0xef, 0xa7, 0xbf, + 0x6f, 0x9f, 0xdc, 0x83, 0xe6, 0xfa, 0x36, 0xa9, 0xe1, 0x68, 0xae, 0xce, 0xae, 0x13, 0xe4, 0x29, + 0x34, 0x02, 0xd9, 0x6a, 0x41, 0xeb, 0x8e, 0xee, 0xb6, 0x7a, 0x3b, 0x37, 0x2b, 0xa8, 0x31, 0x58, + 0x85, 0xfc, 0xc5, 0x4d, 0x7c, 0xd5, 0xc0, 0xbc, 0xc8, 0x78, 0xc2, 0x30, 0x20, 0x2f, 0xc1, 0xca, + 0x31, 0x18, 0x8b, 0x65, 0x86, 0x6a, 0xc2, 0xed, 0xde, 0xfd, 0x9b, 0xc5, 0x2a, 0xd0, 0x63, 0x18, + 0x8c, 0x96, 0x19, 0x32, 0x33, 0x2f, 0x17, 0xe4, 0x00, 0x2c, 0x45, 0x8c, 0xa3, 0xf2, 0x22, 0x0c, + 0x66, 0xaa, 0x78, 0x30, 0x23, 0xff, 0x81, 0x29, 0x5d, 0x25, 0x15, 0x5d, 0x29, 0x0d, 0x19, 0x0e, + 0x66, 0x9d, 0xc7, 0x60, 0x56, 0xe7, 0x90, 0x36, 0x58, 0xaf, 0xdf, 0x0e, 0xfc, 0xb3, 0xf1, 0xb0, + 0x6f, 0x6f, 0x11, 0x1b, 0xda, 0xfd, 0xa1, 0xef, 0x0f, 0x3f, 0x5e, 0x8c, 0xfb, 0x6c, 0xf8, 0xde, + 0xd6, 0x3a, 0xdf, 0x0d, 0x30, 0x64, 0x6d, 0x42, 0xc1, 0x5c, 0x60, 0x5e, 0x44, 0x69, 0x42, 0xa7, + 0x8e, 0xe6, 0xd6, 0xd9, 0x2a, 0x24, 0x8f, 0xa0, 0x5d, 0x60, 0xbe, 0x88, 0xa6, 0x38, 0x4e, 0xf8, + 0x1c, 0xab, 0x2f, 0xd4, 0xaa, 0x72, 0xe7, 0x7c, 0x8e, 0xe4, 0x08, 0xb6, 0xd3, 0x0c, 0x4b, 0x57, + 0x96, 0x50, 0x4d, 0x41, 0xff, 0xac, 0xb3, 0x0a, 0xdb, 0x87, 0xc6, 0x94, 0xc7, 0x31, 0xe6, 0xaa, + 0xdf, 0x26, 0xab, 0xa2, 0x1b, 0x33, 0x1a, 0xbf, 0x9c, 0xb1, 0xbe, 0x39, 0x23, 0xf9, 0x1f, 0x9a, + 0x19, 
0xcf, 0x31, 0x11, 0x52, 0x6a, 0x28, 0xc9, 0x2a, 0x13, 0x03, 0xe5, 0x86, 0x18, 0x17, 0x18, + 0x53, 0x53, 0x8d, 0x52, 0x06, 0xb2, 0x4c, 0x21, 0x78, 0x2e, 0xc6, 0x5c, 0x50, 0x4b, 0x99, 0xc1, + 0x54, 0xf1, 0x89, 0x90, 0xa7, 0x05, 0x51, 0x12, 0x15, 0x9f, 0xa5, 0xd6, 0x54, 0x9a, 0x55, 0x26, + 0x4e, 0x04, 0x79, 0x0e, 0xbb, 0x05, 0x9f, 0x67, 0x71, 0x94, 0x84, 0xe3, 0x2c, 0x4f, 0x27, 0x7c, + 0x12, 0xc5, 0x91, 0x58, 0x52, 0x70, 0x34, 0xb7, 0xc6, 0x76, 0x56, 0xda, 0x87, 0x6b, 0x49, 0x9a, + 0x19, 0x93, 0x05, 0xdd, 0x29, 0xcd, 0x8c, 0xc9, 0x82, 0xbc, 0x02, 0x28, 0x8b, 0x4b, 0xff, 0xd1, + 0x5d, 0x47, 0x73, 0x5b, 0xbd, 0x43, 0xaf, 0x7c, 0xda, 0xde, 0xea, 0x69, 0x7b, 0xa3, 0x95, 0x39, + 0x59, 0x53, 0xd1, 0x32, 0x26, 0xc7, 0x60, 0xad, 0x9e, 0x3c, 0xdd, 0x53, 0x1b, 0x0f, 0x6e, 0x6d, + 0x3c, 0xab, 0x00, 0xb6, 0x46, 0xc9, 0x31, 0x40, 0x8e, 0x01, 0xe6, 0x98, 0x4c, 0xb1, 0xa0, 0xfb, + 0xca, 0xe2, 0x7b, 0x77, 0xba, 0x8e, 0x6d, 0x80, 0xe4, 0x08, 0x0c, 0xc1, 0xc3, 0x82, 0xb6, 0xd4, + 0x86, 0x7f, 0x6f, 0xfd, 0x34, 0x98, 0x92, 0x25, 0x16, 0xa7, 0x61, 0x41, 0xdb, 0x77, 0x61, 0x7e, + 0x1a, 0x32, 0x25, 0x9f, 0xc2, 0x27, 0x4b, 0xf5, 0x18, 0x62, 0x32, 0x69, 0xa8, 0xd5, 0x8b, 0x1f, + 0x01, 0x00, 0x00, 0xff, 0xff, 0xfe, 0x7b, 0x57, 0x93, 0x12, 0x05, 0x00, 0x00, +} diff --git a/pkg/net/trace/proto/span.proto b/pkg/net/trace/proto/span.proto new file mode 100644 index 000000000..6e9e2ed76 --- /dev/null +++ b/pkg/net/trace/proto/span.proto @@ -0,0 +1,77 @@ +syntax = "proto3"; +package dapper.trace; + +import "google/protobuf/timestamp.proto"; +import "google/protobuf/duration.proto"; + +option go_package = "protogen"; + +message Tag { + enum Kind { + STRING = 0; + INT = 1; + BOOL = 2; + FLOAT = 3; + } + string key = 1; + Kind kind = 2; + bytes value = 3; +} + +message Field { + string key = 1; + bytes value = 2; +} + +message Log { + // Deprecated: Kind no long use + enum Kind { + STRING = 0; + INT = 1; + BOOL = 2; + FLOAT = 3; + } + string key = 1; + // Deprecated: Kind no long use + Kind kind = 2; + 
// Deprecated: Value no long use + bytes value = 3; + int64 timestamp = 4; + repeated Field fields = 5; +} + +// SpanRef describes causal relationship of the current span to another span (e.g. 'child-of') +message SpanRef { + enum RefType { + CHILD_OF = 0; + FOLLOWS_FROM = 1; + } + RefType ref_type = 1; + uint64 trace_id = 2; + uint64 span_id = 3; +} + +// Span represents a named unit of work performed by a service. +message Span { + int32 version = 99; + string service_name = 1; + string operation_name = 2; + // Deprecated: caller no long required + string caller = 3; + uint64 trace_id = 4; + uint64 span_id = 5; + uint64 parent_id = 6; + // Deprecated: level no long required + int32 level = 7; + // Deprecated: use start_time instead instead of start_at + int64 start_at = 8; + // Deprecated: use duration instead instead of finish_at + int64 finish_at = 9; + float sampling_probability = 10; + string env = 19; + google.protobuf.Timestamp start_time = 20; + google.protobuf.Duration duration = 21; + repeated SpanRef references = 22; + repeated Tag tags = 11; + repeated Log logs = 12; +} diff --git a/pkg/net/trace/report.go b/pkg/net/trace/report.go new file mode 100644 index 000000000..e9b8e4122 --- /dev/null +++ b/pkg/net/trace/report.go @@ -0,0 +1,138 @@ +package trace + +import ( + "fmt" + "net" + "os" + "sync" + "time" +) + +const ( + // MaxPackageSize . + _maxPackageSize = 1024 * 32 + // safe udp package size // MaxPackageSize = 508 _dataChSize = 4096) + // max memory usage 1024 * 32 * 4096 -> 128MB + _dataChSize = 4096 + _defaultWriteChannalTimeout = 50 * time.Millisecond + _defaultWriteTimeout = 200 * time.Millisecond +) + +// reporter trace reporter. 
+type reporter interface { + WriteSpan(sp *span) error + Close() error +} + +// newReport with network address +func newReport(network, address string, timeout time.Duration, protocolVersion int32) reporter { + if timeout == 0 { + timeout = _defaultWriteTimeout + } + report := &connReport{ + network: network, + address: address, + dataCh: make(chan []byte, _dataChSize), + done: make(chan struct{}), + timeout: timeout, + version: protocolVersion, + } + go report.daemon() + return report +} + +type connReport struct { + version int32 + rmx sync.RWMutex + closed bool + + network, address string + + dataCh chan []byte + + conn net.Conn + + done chan struct{} + + timeout time.Duration +} + +func (c *connReport) daemon() { + for b := range c.dataCh { + c.send(b) + } + c.done <- struct{}{} +} + +func (c *connReport) WriteSpan(sp *span) error { + data, err := marshalSpan(sp, c.version) + if err != nil { + return err + } + return c.writePackage(data) +} + +func (c *connReport) writePackage(data []byte) error { + c.rmx.RLock() + defer c.rmx.RUnlock() + if c.closed { + return fmt.Errorf("report already closed") + } + if len(data) > _maxPackageSize { + return fmt.Errorf("package too large length %d > %d", len(data), _maxPackageSize) + } + select { + case c.dataCh <- data: + return nil + case <-time.After(_defaultWriteChannalTimeout): + return fmt.Errorf("write to data channel timeout") + } +} + +func (c *connReport) Close() error { + c.rmx.Lock() + c.closed = true + c.rmx.Unlock() + + t := time.NewTimer(time.Second) + close(c.dataCh) + select { + case <-t.C: + c.closeConn() + return fmt.Errorf("close report timeout force close") + case <-c.done: + return c.closeConn() + } +} + +func (c *connReport) send(data []byte) { + if c.conn == nil { + if err := c.reconnect(); err != nil { + c.Errorf("connect error: %s retry after second", err) + time.Sleep(time.Second) + return + } + } + c.conn.SetWriteDeadline(time.Now().Add(100 * time.Microsecond)) + if _, err := c.conn.Write(data); 
err != nil { + c.Errorf("write to conn error: %s, close connect", err) + c.conn.Close() + c.conn = nil + } +} + +func (c *connReport) reconnect() (err error) { + c.conn, err = net.DialTimeout(c.network, c.address, c.timeout) + return +} + +func (c *connReport) closeConn() error { + if c.conn != nil { + return c.conn.Close() + } + return nil +} + +func (c *connReport) Errorf(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, format+"\n", args...) +} diff --git a/pkg/net/trace/report_test.go b/pkg/net/trace/report_test.go new file mode 100644 index 000000000..4d0652921 --- /dev/null +++ b/pkg/net/trace/report_test.go @@ -0,0 +1,88 @@ +package trace + +import ( + "bytes" + "io" + "log" + "net" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func newServer(w io.Writer, network, address string) (func() error, error) { + lis, err := net.Listen(network, address) + if err != nil { + return nil, err + } + done := make(chan struct{}) + go func() { + conn, err := lis.Accept() + if err != nil { + lis.Close() + log.Fatal(err) + } + io.Copy(w, conn) + conn.Close() + done <- struct{}{} + }() + return func() error { + <-done + return lis.Close() + }, nil +} + +func TestReportTCP(t *testing.T) { + buf := &bytes.Buffer{} + cancel, err := newServer(buf, "tcp", "127.0.0.1:6077") + if err != nil { + t.Fatal(err) + } + report := newReport("tcp", "127.0.0.1:6077", 0, 0).(*connReport) + data := []byte("hello, world") + report.writePackage(data) + if err := report.Close(); err != nil { + t.Error(err) + } + cancel() + assert.Equal(t, data, buf.Bytes(), "receive data") +} + +func newUnixgramServer(w io.Writer, address string) (func() error, error) { + conn, err := net.ListenPacket("unixgram", address) + if err != nil { + return nil, err + } + done := make(chan struct{}) + go func() { + p := make([]byte, 4096) + n, _, err := conn.ReadFrom(p) + if err != nil { + log.Fatal(err) + } + w.Write(p[:n]) + done <- struct{}{} + }() + return func() error { + <-done + 
return conn.Close() + }, nil +} + +func TestReportUnixgram(t *testing.T) { + os.Remove("/tmp/trace.sock") + buf := &bytes.Buffer{} + cancel, err := newUnixgramServer(buf, "/tmp/trace.sock") + if err != nil { + t.Fatal(err) + } + report := newReport("unixgram", "/tmp/trace.sock", 0, 0).(*connReport) + data := []byte("hello, world") + report.writePackage(data) + if err := report.Close(); err != nil { + t.Error(err) + } + cancel() + assert.Equal(t, data, buf.Bytes(), "receive data") +} diff --git a/pkg/net/trace/sample.go b/pkg/net/trace/sample.go new file mode 100644 index 000000000..57f6c53fa --- /dev/null +++ b/pkg/net/trace/sample.go @@ -0,0 +1,67 @@ +package trace + +import ( + "math/rand" + "sync/atomic" + "time" +) + +const ( + slotLength = 2048 +) + +var ignoreds = []string{"/metrics", "/monitor/ping"} + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func oneAtTimeHash(s string) (hash uint32) { + b := []byte(s) + for i := range b { + hash += uint32(b[i]) + hash += hash << 10 + hash ^= hash >> 6 + } + hash += hash << 3 + hash ^= hash >> 11 + hash += hash << 15 + return +} + +// sampler decides whether a new trace should be sampled or not. 
+type sampler interface { + IsSampled(traceID uint64, operationName string) (bool, float32) + Close() error +} + +type probabilitySampling struct { + probability float32 + slot [slotLength]int64 +} + +func (p *probabilitySampling) IsSampled(traceID uint64, operationName string) (bool, float32) { + for _, ignored := range ignoreds { + if operationName == ignored { + return false, 0 + } + } + now := time.Now().Unix() + idx := oneAtTimeHash(operationName) % slotLength + old := atomic.LoadInt64(&p.slot[idx]) + if old != now { + atomic.SwapInt64(&p.slot[idx], now) + return true, 1 + } + return rand.Float32() < float32(p.probability), float32(p.probability) +} + +func (p *probabilitySampling) Close() error { return nil } + +// newSampler new probability sampler +func newSampler(probability float32) sampler { + if probability <= 0 || probability > 1 { + panic("probability P ∈ (0, 1]") + } + return &probabilitySampling{probability: probability} +} diff --git a/pkg/net/trace/sample_test.go b/pkg/net/trace/sample_test.go new file mode 100644 index 000000000..f3543622b --- /dev/null +++ b/pkg/net/trace/sample_test.go @@ -0,0 +1,35 @@ +package trace + +import ( + "testing" +) + +func TestProbabilitySampling(t *testing.T) { + sampler := newSampler(0.001) + t.Run("test one operationName", func(t *testing.T) { + sampled, probability := sampler.IsSampled(0, "test123") + if !sampled || probability != 1 { + t.Errorf("expect sampled and probability == 1 get: %v %f", sampled, probability) + } + }) + t.Run("test probability", func(t *testing.T) { + sampler.IsSampled(0, "test_opt_2") + count := 0 + for i := 0; i < 100000; i++ { + sampled, _ := sampler.IsSampled(0, "test_opt_2") + if sampled { + count++ + } + } + if count < 80 || count > 120 { + t.Errorf("expect count between 80~120 get %d", count) + } + }) +} + +func BenchmarkProbabilitySampling(b *testing.B) { + sampler := newSampler(0.001) + for i := 0; i < b.N; i++ { + sampler.IsSampled(0, "test_opt_xxx") + } +} diff --git 
a/pkg/net/trace/span.go b/pkg/net/trace/span.go new file mode 100644 index 000000000..1cd2eb101 --- /dev/null +++ b/pkg/net/trace/span.go @@ -0,0 +1,108 @@ +package trace + +import ( + "fmt" + "time" + + protogen "github.com/bilibili/Kratos/pkg/net/trace/proto" +) + +const ( + _maxChilds = 1024 + _maxTags = 128 + _maxLogs = 256 +) + +var _ Trace = &span{} + +type span struct { + dapper *dapper + context spanContext + operationName string + startTime time.Time + duration time.Duration + tags []Tag + logs []*protogen.Log + childs int +} + +func (s *span) Fork(serviceName, operationName string) Trace { + if s.childs > _maxChilds { + // if child span more than max childs set return noopspan + return noopspan{} + } + s.childs++ + // 为了兼容临时为 New 的 Span 设置 span.kind + return s.dapper.newSpanWithContext(operationName, s.context).SetTag(TagString(TagSpanKind, "client")) +} + +func (s *span) Follow(serviceName, operationName string) Trace { + return s.Fork(serviceName, operationName).SetTag(TagString(TagSpanKind, "producer")) +} + +func (s *span) Finish(perr *error) { + s.duration = time.Since(s.startTime) + if perr != nil && *perr != nil { + err := *perr + s.SetTag(TagBool(TagError, true)) + s.SetLog(Log(LogMessage, err.Error())) + if err, ok := err.(stackTracer); ok { + s.SetLog(Log(LogStack, fmt.Sprintf("%+v", err.StackTrace()))) + } + } + s.dapper.report(s) +} + +func (s *span) SetTag(tags ...Tag) Trace { + if !s.context.isSampled() && !s.context.isDebug() { + return s + } + if len(s.tags) < _maxTags { + s.tags = append(s.tags, tags...) + } + if len(s.tags) == _maxTags { + s.tags = append(s.tags, Tag{Key: "trace.error", Value: "too many tags"}) + } + return s +} + +// LogFields is an efficient and type-checked way to record key:value +// NOTE current unsupport +func (s *span) SetLog(logs ...LogField) Trace { + if !s.context.isSampled() && !s.context.isDebug() { + return s + } + if len(s.logs) < _maxLogs { + s.setLog(logs...) 
+ } + if len(s.logs) == _maxLogs { + s.setLog(LogField{Key: "trace.error", Value: "too many logs"}) + } + return s +} + +func (s *span) setLog(logs ...LogField) Trace { + protoLog := &protogen.Log{ + Timestamp: time.Now().UnixNano(), + Fields: make([]*protogen.Field, len(logs)), + } + for i := range logs { + protoLog.Fields[i] = &protogen.Field{Key: logs[i].Key, Value: []byte(logs[i].Value)} + } + s.logs = append(s.logs, protoLog) + return s +} + +// Visit visits the k-v pair in trace, calling fn for each. +func (s *span) Visit(fn func(k, v string)) { + fn(KratosTraceID, s.context.String()) +} + +// SetTitle reset trace title +func (s *span) SetTitle(operationName string) { + s.operationName = operationName +} + +func (s *span) String() string { + return s.context.String() +} diff --git a/pkg/net/trace/span_test.go b/pkg/net/trace/span_test.go new file mode 100644 index 000000000..17a3cef4c --- /dev/null +++ b/pkg/net/trace/span_test.go @@ -0,0 +1,108 @@ +package trace + +import ( + "fmt" + "strconv" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestSpan(t *testing.T) { + report := &mockReport{} + t1 := newTracer("service1", report, &Config{DisableSample: true}) + t.Run("test span string", func(t *testing.T) { + sp1 := t1.New("testfinish").(*span) + assert.NotEmpty(t, fmt.Sprint(sp1)) + }) + t.Run("test fork", func(t *testing.T) { + sp1 := t1.New("testfork").(*span) + sp2 := sp1.Fork("xxx", "opt_2").(*span) + assert.Equal(t, sp1.context.traceID, sp2.context.traceID) + assert.Equal(t, sp1.context.spanID, sp2.context.parentID) + t.Run("test max fork", func(t *testing.T) { + sp3 := sp2.Fork("xx", "xxx") + for i := 0; i < 100; i++ { + sp3 = sp3.Fork("", "xxx") + } + assert.Equal(t, noopspan{}, sp3) + }) + t.Run("test max childs", func(t *testing.T) { + sp3 := sp2.Fork("xx", "xxx") + for i := 0; i < 4096; i++ { + sp3.Fork("", "xxx") + } + assert.Equal(t, noopspan{}, sp3.Fork("xx", "xx")) + }) + }) + t.Run("test 
finish", func(t *testing.T) { + t.Run("test finish ok", func(t *testing.T) { + sp1 := t1.New("testfinish").(*span) + time.Sleep(time.Millisecond) + sp1.Finish(nil) + assert.True(t, sp1.startTime.Unix() > 0) + assert.True(t, sp1.duration > time.Microsecond) + }) + t.Run("test finish error", func(t *testing.T) { + sp1 := t1.New("testfinish").(*span) + time.Sleep(time.Millisecond) + err := fmt.Errorf("🍻") + sp1.Finish(&err) + assert.True(t, sp1.startTime.Unix() > 0) + assert.True(t, sp1.duration > time.Microsecond) + errorTag := false + for _, tag := range sp1.tags { + if tag.Key == TagError && tag.Value != nil { + errorTag = true + } + } + assert.True(t, errorTag) + messageLog := false + for _, log := range sp1.logs { + assert.True(t, log.Timestamp != 0) + for _, field := range log.Fields { + if field.Key == LogMessage && len(field.Value) != 0 { + messageLog = true + } + } + } + assert.True(t, messageLog) + }) + t.Run("test finish error stack", func(t *testing.T) { + sp1 := t1.New("testfinish").(*span) + time.Sleep(time.Millisecond) + err := fmt.Errorf("🍻") + err = errors.WithStack(err) + sp1.Finish(&err) + ok := false + for _, log := range sp1.logs { + for _, field := range log.Fields { + if field.Key == LogStack && len(field.Value) != 0 { + ok = true + } + } + } + assert.True(t, ok, "LogStack set") + }) + t.Run("test too many tags", func(t *testing.T) { + sp1 := t1.New("testfinish").(*span) + for i := 0; i < 1024; i++ { + sp1.SetTag(Tag{Key: strconv.Itoa(i), Value: "hello"}) + } + assert.Len(t, sp1.tags, _maxTags+1) + assert.Equal(t, sp1.tags[_maxTags].Key, "trace.error") + assert.Equal(t, sp1.tags[_maxTags].Value, "too many tags") + }) + t.Run("test too many logs", func(t *testing.T) { + sp1 := t1.New("testfinish").(*span) + for i := 0; i < 1024; i++ { + sp1.SetLog(LogField{Key: strconv.Itoa(i), Value: "hello"}) + } + assert.Len(t, sp1.logs, _maxLogs+1) + assert.Equal(t, sp1.logs[_maxLogs].Fields[0].Key, "trace.error") + assert.Equal(t, 
sp1.logs[_maxLogs].Fields[0].Value, []byte("too many logs")) + }) + }) +} diff --git a/pkg/net/trace/tag.go b/pkg/net/trace/tag.go new file mode 100644 index 000000000..3c7329842 --- /dev/null +++ b/pkg/net/trace/tag.go @@ -0,0 +1,182 @@ +package trace + +// Standard Span tags https://github.com/opentracing/specification/blob/master/semantic_conventions.md#span-tags-table +const ( + // The software package, framework, library, or module that generated the associated Span. + // E.g., "grpc", "django", "JDBI". + // type string + TagComponent = "component" + + // Database instance name. + // E.g., In java, if the jdbc.url="jdbc:mysql://127.0.0.1:3306/customers", the instance name is "customers". + // type string + TagDBInstance = "db.instance" + + // A database statement for the given database type. + // E.g., for db.type="sql", "SELECT * FROM wuser_table"; for db.type="redis", "SET mykey 'WuValue'". + TagDBStatement = "db.statement" + + // Database type. For any SQL database, "sql". For others, the lower-case database category, + // e.g. "cassandra", "hbase", or "redis". + // type string + TagDBType = "db.type" + + // Username for accessing database. E.g., "readonly_user" or "reporting_user" + // type string + TagDBUser = "db.user" + + // true if and only if the application considers the operation represented by the Span to have failed + // type bool + TagError = "error" + + // HTTP method of the request for the associated Span. E.g., "GET", "POST" + // type string + TagHTTPMethod = "http.method" + + // HTTP response status code for the associated Span. E.g., 200, 503, 404 + // type integer + TagHTTPStatusCode = "http.status_code" + + // URL of the request being handled in this segment of the trace, in standard URI format. + // E.g., "https://domain.net/path/to?resource=here" + // type string + TagHTTPURL = "http.url" + + // An address at which messages can be exchanged. + // E.g. 
A Kafka record has an associated "topic name" that can be extracted by the instrumented producer or consumer and stored using this tag. + // type string + TagMessageBusDestination = "message_bus.destination" + + // Remote "address", suitable for use in a networking client library. + // This may be a "ip:port", a bare "hostname", a FQDN, or even a JDBC substring like "mysql://prod-db:3306" + // type string + TagPeerAddress = "peer.address" + + // Remote hostname. E.g., "opentracing.io", "internal.dns.name" + // type string + TagPeerHostname = "peer.hostname" + + // Remote IPv4 address as a .-separated tuple. E.g., "127.0.0.1" + // type string + TagPeerIPv4 = "peer.ipv4" + + // Remote IPv6 address as a string of colon-separated 4-char hex tuples. + // E.g., "2001:0db8:85a3:0000:0000:8a2e:0370:7334" + // type string + TagPeerIPv6 = "peer.ipv6" + + // Remote port. E.g., 80 + // type integer + TagPeerPort = "peer.port" + + // Remote service name (for some unspecified definition of "service"). + // E.g., "elasticsearch", "a_custom_microservice", "memcache" + // type string + TagPeerService = "peer.service" + + // If greater than 0, a hint to the Tracer to do its best to capture the trace. + // If 0, a hint to the trace to not-capture the trace. If absent, the Tracer should use its default sampling mechanism. + // type string + TagSamplingPriority = "sampling.priority" + + // Either "client" or "server" for the appropriate roles in an RPC, + // and "producer" or "consumer" for the appropriate roles in a messaging scenario. + // type string + TagSpanKind = "span.kind" + + // legacy tag + TagAnnotation = "legacy.annotation" + TagAddress = "legacy.address" + TagComment = "legacy.comment" +) + +// Standard log tags +const ( + // The type or "kind" of an error (only for event="error" logs). 
E.g., "Exception", "OSError" + // type string + LogErrorKind = "error.kind" + + // For languages that support such a thing (e.g., Java, Python), + // the actual Throwable/Exception/Error object instance itself. + // E.g., A java.lang.UnsupportedOperationException instance, a python exceptions.NameError instance + // type string + LogErrorObject = "error.object" + + // A stable identifier for some notable moment in the lifetime of a Span. For instance, a mutex lock acquisition or release or the sorts of lifetime events in a browser page load described in the Performance.timing specification. E.g., from Zipkin, "cs", "sr", "ss", or "cr". Or, more generally, "initialized" or "timed out". For errors, "error" + // type string + LogEvent = "event" + + // A concise, human-readable, one-line message explaining the event. + // E.g., "Could not connect to backend", "Cache invalidation succeeded" + // type string + LogMessage = "message" + + // A stack trace in platform-conventional format; may or may not pertain to an error. E.g., "File \"example.py\", line 7, in \\ncaller()\nFile \"example.py\", line 5, in caller\ncallee()\nFile \"example.py\", line 2, in callee\nraise Exception(\"Yikes\")\n" + // type string + LogStack = "stack" +) + +// Tag interface +type Tag struct { + Key string + Value interface{} +} + +// TagString new string tag. +func TagString(key string, val string) Tag { + return Tag{Key: key, Value: val} +} + +// TagInt64 new int64 tag. 
+func TagInt64(key string, val int64) Tag { + return Tag{Key: key, Value: val} +} + +// TagInt new int tag +func TagInt(key string, val int) Tag { + return Tag{Key: key, Value: val} +} + +// TagBool new bool tag +func TagBool(key string, val bool) Tag { + return Tag{Key: key, Value: val} +} + +// TagFloat64 new float64 tag +func TagFloat64(key string, val float64) Tag { + return Tag{Key: key, Value: val} +} + +// TagFloat32 new float64 tag +func TagFloat32(key string, val float32) Tag { + return Tag{Key: key, Value: val} +} + +// String new tag String. +// NOTE: use TagString +func String(key string, val string) Tag { + return TagString(key, val) +} + +// Int new tag Int. +// NOTE: use TagInt +func Int(key string, val int) Tag { + return TagInt(key, val) +} + +// Bool new tagBool +// NOTE: use TagBool +func Bool(key string, val bool) Tag { + return TagBool(key, val) +} + +// Log new log. +func Log(key string, val string) LogField { + return LogField{Key: key, Value: val} +} + +// LogField LogField +type LogField struct { + Key string + Value string +} diff --git a/pkg/net/trace/tag_test.go b/pkg/net/trace/tag_test.go new file mode 100644 index 000000000..15c5a94c3 --- /dev/null +++ b/pkg/net/trace/tag_test.go @@ -0,0 +1 @@ +package trace diff --git a/pkg/net/trace/tracer.go b/pkg/net/trace/tracer.go new file mode 100644 index 000000000..825ebe2e6 --- /dev/null +++ b/pkg/net/trace/tracer.go @@ -0,0 +1,92 @@ +package trace + +import ( + "io" +) + +var ( + // global tracer + _tracer Tracer = nooptracer{} +) + +// SetGlobalTracer SetGlobalTracer +func SetGlobalTracer(tracer Tracer) { + _tracer = tracer +} + +// Tracer is a simple, thin interface for Trace creation and propagation. +type Tracer interface { + // New trace instance with given title. + New(operationName string, opts ...Option) Trace + // Inject takes the Trace instance and injects it for + // propagation within `carrier`. The actual type of `carrier` depends on + // the value of `format`. 
+ Inject(t Trace, format interface{}, carrier interface{}) error + // Extract returns a Trace instance given `format` and `carrier`. + // return `ErrTraceNotFound` if trace not found. + Extract(format interface{}, carrier interface{}) (Trace, error) +} + +// New trace instance with given operationName. +func New(operationName string, opts ...Option) Trace { + return _tracer.New(operationName, opts...) +} + +// Inject takes the Trace instance and injects it for +// propagation within `carrier`. The actual type of `carrier` depends on +// the value of `format`. +func Inject(t Trace, format interface{}, carrier interface{}) error { + return _tracer.Inject(t, format, carrier) +} + +// Extract returns a Trace instance given `format` and `carrier`. +// return `ErrTraceNotFound` if trace not found. +func Extract(format interface{}, carrier interface{}) (Trace, error) { + return _tracer.Extract(format, carrier) +} + +// Close trace flush data. +func Close() error { + if closer, ok := _tracer.(io.Closer); ok { + return closer.Close() + } + return nil +} + +// Trace trace common interface. +type Trace interface { + // Fork fork a trace with client trace. + Fork(serviceName, operationName string) Trace + + // Follow + Follow(serviceName, operationName string) Trace + + // Finish when trace finish call it. + Finish(err *error) + + // Scan scan trace into info. + // Deprecated: method Scan is deprecated, use Inject instead of Scan + // Scan(ti *Info) + + // Adds a tag to the trace. + // + // If there is a pre-existing tag set for `key`, it is overwritten. + // + // Tag values can be numeric types, strings, or bools. The behavior of + // other tag value types is undefined at the OpenTracing level. If a + // tracing system does not know how to handle a particular value type, it + // may ignore the tag, but shall not panic. 
+ // NOTE current only support legacy tag: TagAnnotation TagAddress TagComment + // other will be ignore + SetTag(tags ...Tag) Trace + + // LogFields is an efficient and type-checked way to record key:value + // NOTE current unsupport + SetLog(logs ...LogField) Trace + + // Visit visits the k-v pair in trace, calling fn for each. + Visit(fn func(k, v string)) + + // SetTitle reset trace title + SetTitle(title string) +} diff --git a/pkg/net/trace/util.go b/pkg/net/trace/util.go new file mode 100644 index 000000000..8345e4c54 --- /dev/null +++ b/pkg/net/trace/util.go @@ -0,0 +1,68 @@ +package trace + +import ( + "context" + "encoding/binary" + "math/rand" + "time" + + "github.com/bilibili/Kratos/pkg/conf/env" + "github.com/bilibili/Kratos/pkg/net/ip" + "github.com/bilibili/Kratos/pkg/net/metadata" + + "github.com/pkg/errors" +) + +var _hostHash byte + +func init() { + rand.Seed(time.Now().UnixNano()) + _hostHash = byte(oneAtTimeHash(env.Hostname)) +} + +func extendTag() (tags []Tag) { + tags = append(tags, + TagString("region", env.Region), + TagString("zone", env.Zone), + TagString("hostname", env.Hostname), + TagString("ip", ip.InternalIP()), + ) + return +} + +func genID() uint64 { + var b [8]byte + // i think this code will not survive to 2106-02-07 + binary.BigEndian.PutUint32(b[4:], uint32(time.Now().Unix())>>8) + b[4] = _hostHash + binary.BigEndian.PutUint32(b[:4], uint32(rand.Int31())) + return binary.BigEndian.Uint64(b[:]) +} + +type stackTracer interface { + StackTrace() errors.StackTrace +} + +type ctxKey string + +var _ctxkey ctxKey = "Kratos/pkg/net/trace.trace" + +// FromContext returns the trace bound to the context, if any. +func FromContext(ctx context.Context) (t Trace, ok bool) { + if v := metadata.Value(ctx, metadata.Trace); v != nil { + t, ok = v.(Trace) + return + } + t, ok = ctx.Value(_ctxkey).(Trace) + return +} + +// NewContext new a trace context. +// NOTE: This method is not thread safe. 
+func NewContext(ctx context.Context, t Trace) context.Context { + if md, ok := metadata.FromContext(ctx); ok { + md[metadata.Trace] = t + return ctx + } + return context.WithValue(ctx, _ctxkey, t) +} diff --git a/pkg/net/trace/util_test.go b/pkg/net/trace/util_test.go new file mode 100644 index 000000000..c6223bd84 --- /dev/null +++ b/pkg/net/trace/util_test.go @@ -0,0 +1,21 @@ +package trace + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFromContext(t *testing.T) { + report := &mockReport{} + t1 := newTracer("service1", report, &Config{DisableSample: true}) + sp1 := t1.New("test123") + ctx := context.Background() + ctx = NewContext(ctx, sp1) + sp2, ok := FromContext(ctx) + if !ok { + t.Fatal("nothing from context") + } + assert.Equal(t, sp1, sp2) +} diff --git a/pkg/stat/README.md b/pkg/stat/README.md new file mode 100644 index 000000000..676067e2f --- /dev/null +++ b/pkg/stat/README.md @@ -0,0 +1,5 @@ +# stat + +## 项目简介 + +数据统计、监控采集等 diff --git a/pkg/stat/prom/README.md b/pkg/stat/prom/README.md new file mode 100644 index 000000000..4bdf9ee3d --- /dev/null +++ b/pkg/stat/prom/README.md @@ -0,0 +1,5 @@ +# prom + +## 项目简介 + +封装prometheus类。TODO:补充grafana通用面板json文件!!! diff --git a/pkg/stat/prom/prometheus.go b/pkg/stat/prom/prometheus.go new file mode 100644 index 000000000..a2dec6090 --- /dev/null +++ b/pkg/stat/prom/prometheus.go @@ -0,0 +1,163 @@ +package prom + +import ( + "flag" + "os" + "sync" + + "github.com/prometheus/client_golang/prometheus" +) + +var ( + // LibClient for mc redis and db client. 
+ LibClient = New().WithTimer("go_lib_client", []string{"method"}).WithCounter("go_lib_client_code", []string{"method", "code"}) + // RPCClient rpc client + RPCClient = New().WithTimer("go_rpc_client", []string{"method"}).WithCounter("go_rpc_client_code", []string{"method", "code"}) + // HTTPClient http client + HTTPClient = New().WithTimer("go_http_client", []string{"method"}).WithCounter("go_http_client_code", []string{"method", "code"}) + // HTTPServer for http server + HTTPServer = New().WithTimer("go_http_server", []string{"user", "method"}).WithCounter("go_http_server_code", []string{"user", "method", "code"}) + // RPCServer for rpc server + RPCServer = New().WithTimer("go_rpc_server", []string{"user", "method"}).WithCounter("go_rpc_server_code", []string{"user", "method", "code"}) + // BusinessErrCount for business err count + BusinessErrCount = New().WithCounter("go_business_err_count", []string{"name"}).WithState("go_business_err_state", []string{"name"}) + // BusinessInfoCount for business info count + BusinessInfoCount = New().WithCounter("go_business_info_count", []string{"name"}).WithState("go_business_info_state", []string{"name"}) + // CacheHit for cache hit + CacheHit = New().WithCounter("go_cache_hit", []string{"name"}) + // CacheMiss for cache miss + CacheMiss = New().WithCounter("go_cache_miss", []string{"name"}) + + // UseSummary use summary for Objectives that defines the quantile rank estimates. + _useSummary bool +) + +// Prom struct info +type Prom struct { + histogram *prometheus.HistogramVec + summary *prometheus.SummaryVec + counter *prometheus.GaugeVec + state *prometheus.GaugeVec + once sync.Once +} + +// New creates a Prom instance. 
+func New() *Prom { + return &Prom{} +} + +func init() { + addFlag(flag.CommandLine) +} + +func addFlag(fs *flag.FlagSet) { + v := os.Getenv("PROM_SUMMARY") + if v == "true" { + _useSummary = true + } + fs.BoolVar(&_useSummary, "prom_summary", _useSummary, "use summary in prometheus") +} + +// WithTimer with summary timer +func (p *Prom) WithTimer(name string, labels []string) *Prom { + if p == nil { + return p + } + if p.histogram == nil { + p.histogram = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: name, + Help: name, + }, labels) + } + if p.summary == nil { + p.summary = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Name: name, + Help: name, + Objectives: map[float64]float64{0.99: 0.001, 0.9: 0.01}, + }, labels) + } + return p +} + +// WithCounter sets counter. +func (p *Prom) WithCounter(name string, labels []string) *Prom { + if p == nil || p.counter != nil { + return p + } + p.counter = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: name, + Help: name, + }, labels) + prometheus.MustRegister(p.counter) + return p +} + +// WithState sets state. +func (p *Prom) WithState(name string, labels []string) *Prom { + if p == nil || p.state != nil { + return p + } + p.state = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: name, + Help: name, + }, labels) + prometheus.MustRegister(p.state) + return p +} + +// Timing log timing information (in milliseconds) without sampling +func (p *Prom) Timing(name string, time int64, extra ...string) { + p.once.Do(func() { + if _useSummary && p.summary != nil { + prometheus.MustRegister(p.summary) + return + } + if !_useSummary && p.histogram != nil { + prometheus.MustRegister(p.histogram) + } + }) + label := append([]string{name}, extra...) 
+ if _useSummary && p.summary != nil { + p.summary.WithLabelValues(label...).Observe(float64(time)) + return + } + if !_useSummary && p.histogram != nil { + p.histogram.WithLabelValues(label...).Observe(float64(time)) + } +} + +// Incr increments one stat counter without sampling +func (p *Prom) Incr(name string, extra ...string) { + label := append([]string{name}, extra...) + if p.counter != nil { + p.counter.WithLabelValues(label...).Inc() + } +} + +// Decr decrements one stat counter without sampling +func (p *Prom) Decr(name string, extra ...string) { + if p.counter != nil { + label := append([]string{name}, extra...) + p.counter.WithLabelValues(label...).Dec() + } +} + +// State set state +func (p *Prom) State(name string, v int64, extra ...string) { + if p.state != nil { + label := append([]string{name}, extra...) + p.state.WithLabelValues(label...).Set(float64(v)) + } +} + +// Add add count v must > 0 +func (p *Prom) Add(name string, v int64, extra ...string) { + label := append([]string{name}, extra...) + if p.counter != nil { + p.counter.WithLabelValues(label...).Add(float64(v)) + } +} diff --git a/pkg/stat/stat.go b/pkg/stat/stat.go new file mode 100644 index 000000000..7f8c98fb7 --- /dev/null +++ b/pkg/stat/stat.go @@ -0,0 +1,25 @@ +package stat + +import ( + "github.com/bilibili/Kratos/pkg/stat/prom" +) + +// Stat interface. +type Stat interface { + Timing(name string, time int64, extra ...string) + Incr(name string, extra ...string) // name,ext...,code + State(name string, val int64, extra ...string) +} + +// default stat struct. 
+var ( + // http + HTTPClient Stat = prom.HTTPClient + HTTPServer Stat = prom.HTTPServer + // storage + Cache Stat = prom.LibClient + DB Stat = prom.LibClient + // rpc + RPCClient Stat = prom.RPCClient + RPCServer Stat = prom.RPCServer +) diff --git a/pkg/stat/summary/README.md b/pkg/stat/summary/README.md new file mode 100644 index 000000000..dd9c473ac --- /dev/null +++ b/pkg/stat/summary/README.md @@ -0,0 +1,5 @@ +# summary + +## 项目简介 + +summary计数器 diff --git a/pkg/stat/summary/summary.go b/pkg/stat/summary/summary.go new file mode 100644 index 000000000..a6afc61c9 --- /dev/null +++ b/pkg/stat/summary/summary.go @@ -0,0 +1,129 @@ +package summary + +import ( + "sync" + "time" +) + +type bucket struct { + val int64 + count int64 + next *bucket +} + +func (b *bucket) Add(val int64) { + b.val += val + b.count++ +} + +func (b *bucket) Value() (int64, int64) { + return b.val, b.count +} + +func (b *bucket) Reset() { + b.val = 0 + b.count = 0 +} + +// Summary is a summary interface. +type Summary interface { + Add(int64) + Reset() + Value() (val int64, cnt int64) +} + +type summary struct { + mu sync.RWMutex + buckets []bucket + bucketTime int64 + lastAccess int64 + cur *bucket +} + +// New new a summary. +// +// use RollingCounter creates a new window. windowTime is the time covering the entire +// window. windowBuckets is the number of buckets the window is divided into. +// An example: a 10 second window with 10 buckets will have 10 buckets covering +// 1 second each. +func New(window time.Duration, winBucket int) Summary { + buckets := make([]bucket, winBucket) + bucket := &buckets[0] + for i := 1; i < winBucket; i++ { + bucket.next = &buckets[i] + bucket = bucket.next + } + bucket.next = &buckets[0] + bucketTime := time.Duration(window.Nanoseconds() / int64(winBucket)) + return &summary{ + cur: &buckets[0], + buckets: buckets, + bucketTime: int64(bucketTime), + lastAccess: time.Now().UnixNano(), + } +} + +// Add increments the summary by value. 
+func (s *summary) Add(val int64) { + s.mu.Lock() + s.lastBucket().Add(val) + s.mu.Unlock() +} + +// Value get the summary value and count. +func (s *summary) Value() (val int64, cnt int64) { + now := time.Now().UnixNano() + s.mu.RLock() + b := s.cur + i := s.elapsed(now) + for j := 0; j < len(s.buckets); j++ { + // skip all future reset bucket. + if i > 0 { + i-- + } else { + v, c := b.Value() + val += v + cnt += c + } + b = b.next + } + s.mu.RUnlock() + return +} + +// Reset reset the counter. +func (s *summary) Reset() { + s.mu.Lock() + for i := range s.buckets { + s.buckets[i].Reset() + } + s.mu.Unlock() +} + +func (s *summary) elapsed(now int64) (i int) { + var e int64 + if e = now - s.lastAccess; e <= s.bucketTime { + return + } + if i = int(e / s.bucketTime); i > len(s.buckets) { + i = len(s.buckets) + } + return +} + +func (s *summary) lastBucket() (b *bucket) { + now := time.Now().UnixNano() + b = s.cur + // reset the buckets between now and number of buckets ago. If + // that is more that the existing buckets, reset all. + if i := s.elapsed(now); i > 0 { + s.lastAccess = now + for ; i > 0; i-- { + // replace the next used bucket. 
+ b = b.next + b.Reset() + } + } + s.cur = b + return +} diff --git a/pkg/stat/summary/summary_test.go b/pkg/stat/summary/summary_test.go new file mode 100644 index 000000000..83d6043c3 --- /dev/null +++ b/pkg/stat/summary/summary_test.go @@ -0,0 +1,69 @@ +package summary + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSummaryMinInterval(t *testing.T) { + count := New(time.Second/2, 10) + tk1 := time.NewTicker(5 * time.Millisecond) + defer tk1.Stop() + for i := 0; i < 100; i++ { + <-tk1.C + count.Add(2) + } + + v, c := count.Value() + t.Logf("count value: %d, %d\n", v, c) + // 10% of error when bucket is 10 + if v < 190 || v > 210 { + t.Errorf("expect value in [90-110] get %d", v) + } + // 10% of error when bucket is 10 + if c < 90 || c > 110 { + t.Errorf("expect value in [90-110] get %d", v) + } +} + +func TestSummary(t *testing.T) { + s := New(time.Second, 10) + + t.Run("add", func(t *testing.T) { + s.Add(1) + v, c := s.Value() + assert.Equal(t, v, int64(1)) + assert.Equal(t, c, int64(1)) + }) + time.Sleep(time.Millisecond * 110) + t.Run("add2", func(t *testing.T) { + s.Add(1) + v, c := s.Value() + assert.Equal(t, v, int64(2)) + assert.Equal(t, c, int64(2)) + }) + time.Sleep(time.Millisecond * 900) // expire one bucket, 110 + 900 + t.Run("expire", func(t *testing.T) { + v, c := s.Value() + assert.Equal(t, v, int64(1)) + assert.Equal(t, c, int64(1)) + s.Add(1) + v, c = s.Value() + assert.Equal(t, v, int64(2)) // expire one bucket + assert.Equal(t, c, int64(2)) // expire one bucket + }) + time.Sleep(time.Millisecond * 1100) + t.Run("expire_all", func(t *testing.T) { + v, c := s.Value() + assert.Equal(t, v, int64(0)) + assert.Equal(t, c, int64(0)) + }) + t.Run("reset", func(t *testing.T) { + s.Reset() + v, c := s.Value() + assert.Equal(t, v, int64(0)) + assert.Equal(t, c, int64(0)) + }) +} diff --git a/pkg/stat/sys/cpu/README.md b/pkg/stat/sys/cpu/README.md new file mode 100644 index 000000000..1028e9e03 --- /dev/null +++ 
b/pkg/stat/sys/cpu/README.md @@ -0,0 +1,7 @@ +## stat/sys + +System Information + +## 项目简介 + +获取Linux平台下的系统信息,包括cpu主频、cpu使用率等 diff --git a/pkg/stat/sys/cpu/cgroup.go b/pkg/stat/sys/cpu/cgroup.go new file mode 100644 index 000000000..99e80bb98 --- /dev/null +++ b/pkg/stat/sys/cpu/cgroup.go @@ -0,0 +1,125 @@ +// +build linux + +package cpu + +import ( + "bufio" + "fmt" + "io" + "os" + "path" + "strconv" + "strings" +) + +const cgroupRootDir = "/sys/fs/cgroup" + +// cgroup Linux cgroup +type cgroup struct { + cgroupSet map[string]string +} + +// CPUCFSQuotaUs cpu.cfs_quota_us +func (c *cgroup) CPUCFSQuotaUs() (int64, error) { + data, err := readFile(path.Join(c.cgroupSet["cpu"], "cpu.cfs_quota_us")) + if err != nil { + return 0, err + } + return strconv.ParseInt(data, 10, 64) +} + +// CPUCFSPeriodUs cpu.cfs_period_us +func (c *cgroup) CPUCFSPeriodUs() (uint64, error) { + data, err := readFile(path.Join(c.cgroupSet["cpu"], "cpu.cfs_period_us")) + if err != nil { + return 0, err + } + return parseUint(data) +} + +// CPUAcctUsage cpuacct.usage +func (c *cgroup) CPUAcctUsage() (uint64, error) { + data, err := readFile(path.Join(c.cgroupSet["cpuacct"], "cpuacct.usage")) + if err != nil { + return 0, err + } + return parseUint(data) +} + +// CPUAcctUsagePerCPU cpuacct.usage_percpu +func (c *cgroup) CPUAcctUsagePerCPU() ([]uint64, error) { + data, err := readFile(path.Join(c.cgroupSet["cpuacct"], "cpuacct.usage_percpu")) + if err != nil { + return nil, err + } + var usage []uint64 + for _, v := range strings.Fields(string(data)) { + var u uint64 + if u, err = parseUint(v); err != nil { + return nil, err + } + usage = append(usage, u) + } + return usage, nil +} + +// CPUSetCPUs cpuset.cpus +func (c *cgroup) CPUSetCPUs() ([]uint64, error) { + data, err := readFile(path.Join(c.cgroupSet["cpuset"], "cpuset.cpus")) + if err != nil { + return nil, err + } + cpus, err := ParseUintList(data) + if err != nil { + return nil, err + } + var sets []uint64 + for k := range cpus { + sets = 
append(sets, uint64(k)) + } + return sets, nil +} + +// CurrentcGroup get current process cgroup +func currentcGroup() (*cgroup, error) { + pid := os.Getpid() + cgroupFile := fmt.Sprintf("/proc/%d/cgroup", pid) + cgroupSet := make(map[string]string) + fp, err := os.Open(cgroupFile) + if err != nil { + return nil, err + } + defer fp.Close() + buf := bufio.NewReader(fp) + for { + line, err := buf.ReadString('\n') + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + col := strings.Split(strings.TrimSpace(line), ":") + if len(col) != 3 { + return nil, fmt.Errorf("invalid cgroup format %s", line) + } + dir := col[2] + // When dir is not equal to /, it must be in docker + if dir != "/" { + cgroupSet[col[1]] = path.Join(cgroupRootDir, col[1]) + if strings.Contains(col[1], ",") { + for _, k := range strings.Split(col[1], ",") { + cgroupSet[k] = path.Join(cgroupRootDir, k) + } + } + } else { + cgroupSet[col[1]] = path.Join(cgroupRootDir, col[1], col[2]) + if strings.Contains(col[1], ",") { + for _, k := range strings.Split(col[1], ",") { + cgroupSet[k] = path.Join(cgroupRootDir, k, col[2]) + } + } + } + } + return &cgroup{cgroupSet: cgroupSet}, nil +} diff --git a/pkg/stat/sys/cpu/cgroup_test.go b/pkg/stat/sys/cpu/cgroup_test.go new file mode 100644 index 000000000..9fbb1d151 --- /dev/null +++ b/pkg/stat/sys/cpu/cgroup_test.go @@ -0,0 +1,11 @@ +// +build linux + +package cpu + +import ( + "testing" +) + +func TestCGroup(t *testing.T) { + // TODO +} diff --git a/pkg/stat/sys/cpu/cpu.go b/pkg/stat/sys/cpu/cpu.go new file mode 100644 index 000000000..f9bf95461 --- /dev/null +++ b/pkg/stat/sys/cpu/cpu.go @@ -0,0 +1,107 @@ +package cpu + +import ( + "fmt" + "sync/atomic" + "time" +) + +var ( + cores uint64 + maxFreq uint64 + quota float64 + usage uint64 + + preSystem uint64 + preTotal uint64 +) + +func init() { + cpus, err := perCPUUsage() + if err != nil { + panic(fmt.Errorf("stat/sys/cpu: perCPUUsage() failed!err:=%v", err)) + } + cores = 
uint64(len(cpus)) + + sets, err := cpuSets() + if err != nil { + panic(fmt.Errorf("stat/sys/cpu: cpuSets() failed!err:=%v", err)) + } + quota = float64(len(sets)) + cq, err := cpuQuota() + if err == nil { + if cq != -1 { + var period uint64 + if period, err = cpuPeriod(); err != nil { + panic(fmt.Errorf("stat/sys/cpu: cpuPeriod() failed!err:=%v", err)) + } + limit := float64(cq) / float64(period) + if limit < quota { + quota = limit + } + } + } + maxFreq = cpuMaxFreq() + + preSystem, err = systemCPUUsage() + if err != nil { + panic(fmt.Errorf("sys/cpu: systemCPUUsage() failed!err:=%v", err)) + } + preTotal, err = totalCPUUsage() + if err != nil { + panic(fmt.Errorf("sys/cpu: totalCPUUsage() failed!err:=%v", err)) + } + + go func() { + ticker := time.NewTicker(time.Millisecond * 250) + defer ticker.Stop() + for { + <-ticker.C + cpu := refreshCPU() + if cpu != 0 { + atomic.StoreUint64(&usage, cpu) + } + } + }() +} + +func refreshCPU() (u uint64) { + total, err := totalCPUUsage() + if err != nil { + return + } + system, err := systemCPUUsage() + if err != nil { + return + } + if system != preSystem { + u = uint64(float64((total-preTotal)*cores*1e3) / (float64(system-preSystem) * quota)) + } + preSystem = system + preTotal = total + return u +} + +// Stat cpu stat. +type Stat struct { + Usage uint64 // cpu use ratio. +} + +// Info cpu info. +type Info struct { + Frequency uint64 + Quota float64 +} + +// ReadStat read cpu stat. +func ReadStat(stat *Stat) { + stat.Usage = atomic.LoadUint64(&usage) +} + +// GetInfo get cpu info. 
+func GetInfo() Info { + return Info{ + Frequency: maxFreq, + Quota: quota, + } +} diff --git a/pkg/stat/sys/cpu/cpu_darwin.go b/pkg/stat/sys/cpu/cpu_darwin.go new file mode 100644 index 000000000..3081507b0 --- /dev/null +++ b/pkg/stat/sys/cpu/cpu_darwin.go @@ -0,0 +1,20 @@ +// +build darwin + +package cpu + +var su uint64 = 10 +var tu uint64 = 10 + +func systemCPUUsage() (usage uint64, err error) { + su += 1000 + return su, nil +} +func totalCPUUsage() (usage uint64, err error) { + tu += 500 + return tu, nil +} +func perCPUUsage() (usage []uint64, err error) { return []uint64{10, 10, 10, 10}, nil } +func cpuSets() (sets []uint64, err error) { return []uint64{0, 1, 2, 3}, nil } +func cpuQuota() (quota int64, err error) { return 100, nil } +func cpuPeriod() (peroid uint64, err error) { return 10, nil } +func cpuMaxFreq() (feq uint64) { return 10 } diff --git a/pkg/stat/sys/cpu/cpu_linux.go b/pkg/stat/sys/cpu/cpu_linux.go new file mode 100644 index 000000000..7f536eb4a --- /dev/null +++ b/pkg/stat/sys/cpu/cpu_linux.go @@ -0,0 +1,147 @@ +// +build linux + +package cpu + +import ( + "bufio" + "fmt" + "os" + "strconv" + "strings" + + "github.com/pkg/errors" +) + +const nanoSecondsPerSecond = 1e9 + +// ErrNoCFSLimit is no quota limit +var ErrNoCFSLimit = errors.Errorf("no quota limit") + +var clockTicksPerSecond = uint64(GetClockTicks()) + +// systemCPUUsage returns the host system's cpu usage in +// nanoseconds. An error is returned if the format of the underlying +// file does not match. +// +// Uses /proc/stat defined by POSIX. Looks for the cpu +// statistics line and then sums up the first seven fields +// provided. See man 5 proc for details on specific field +// information. 
+func systemCPUUsage() (usage uint64, err error) { + var ( + line string + f *os.File + ) + if f, err = os.Open("/proc/stat"); err != nil { + return + } + bufReader := bufio.NewReaderSize(nil, 128) + defer func() { + bufReader.Reset(nil) + f.Close() + }() + bufReader.Reset(f) + for err == nil { + if line, err = bufReader.ReadString('\n'); err != nil { + err = errors.WithStack(err) + return + } + parts := strings.Fields(line) + switch parts[0] { + case "cpu": + if len(parts) < 8 { + err = errors.WithStack(fmt.Errorf("bad format of cpu stats")) + return + } + var totalClockTicks uint64 + for _, i := range parts[1:8] { + var v uint64 + if v, err = strconv.ParseUint(i, 10, 64); err != nil { + err = errors.WithStack(fmt.Errorf("error parsing cpu stats")) + return + } + totalClockTicks += v + } + usage = (totalClockTicks * nanoSecondsPerSecond) / clockTicksPerSecond + return + } + } + err = errors.Errorf("bad stats format") + return +} + +func totalCPUUsage() (usage uint64, err error) { + var cg *cgroup + if cg, err = currentcGroup(); err != nil { + return + } + return cg.CPUAcctUsage() +} + +func perCPUUsage() (usage []uint64, err error) { + var cg *cgroup + if cg, err = currentcGroup(); err != nil { + return + } + return cg.CPUAcctUsagePerCPU() +} + +func cpuSets() (sets []uint64, err error) { + var cg *cgroup + if cg, err = currentcGroup(); err != nil { + return + } + return cg.CPUSetCPUs() +} + +func cpuQuota() (quota int64, err error) { + var cg *cgroup + if cg, err = currentcGroup(); err != nil { + return + } + return cg.CPUCFSQuotaUs() +} + +func cpuPeriod() (peroid uint64, err error) { + var cg *cgroup + if cg, err = currentcGroup(); err != nil { + return + } + return cg.CPUCFSPeriodUs() +} + +func cpuFreq() uint64 { + lines, err := readLines("/proc/cpuinfo") + if err != nil { + return 0 + } + for _, line := range lines { + fields := strings.Split(line, ":") + if len(fields) < 2 { + continue + } + key := strings.TrimSpace(fields[0]) + value := 
strings.TrimSpace(fields[1]) + if key == "cpu MHz" || key == "clock" { + // treat this as the fallback value, thus we ignore error + if t, err := strconv.ParseFloat(strings.Replace(value, "MHz", "", 1), 64); err == nil { + return uint64(t * 1000.0 * 1000.0) + } + } + } + return 0 +} + +func cpuMaxFreq() uint64 { + feq := cpuFreq() + data, err := readFile("/sys/devices/system/cpu/cpu0/cpufreq/cpuinfo_max_freq") + if err != nil { + return feq + } + // override the max freq from /proc/cpuinfo + cfeq, err := parseUint(data) + if err == nil { + feq = cfeq + } + return feq +} diff --git a/pkg/stat/sys/cpu/cpu_other.go b/pkg/stat/sys/cpu/cpu_other.go new file mode 100644 index 000000000..681976ab2 --- /dev/null +++ b/pkg/stat/sys/cpu/cpu_other.go @@ -0,0 +1,11 @@ +// +build windows + +package cpu + +func systemCPUUsage() (usage uint64, err error) { return 10, nil } +func totalCPUUsage() (usage uint64, err error) { return 10, nil } +func perCPUUsage() (usage []uint64, err error) { return []uint64{10, 10, 10, 10}, nil } +func cpuSets() (sets []uint64, err error) { return []uint64{0, 1, 2, 3}, nil } +func cpuQuota() (quota int64, err error) { return 100, nil } +func cpuPeriod() (peroid uint64, err error) { return 10, nil } +func cpuMaxFreq() (feq uint64) { return 10 } diff --git a/pkg/stat/sys/cpu/stat_test.go b/pkg/stat/sys/cpu/stat_test.go new file mode 100644 index 000000000..ed2783043 --- /dev/null +++ b/pkg/stat/sys/cpu/stat_test.go @@ -0,0 +1,20 @@ +package cpu + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestStat(t *testing.T) { + time.Sleep(time.Second * 2) + var s Stat + var i Info + ReadStat(&s) + i = GetInfo() + + assert.NotZero(t, s.Usage) + assert.NotZero(t, i.Frequency) + assert.NotZero(t, i.Quota) +} diff --git a/pkg/stat/sys/cpu/sysconfig_notcgo.go b/pkg/stat/sys/cpu/sysconfig_notcgo.go new file mode 100644 index 000000000..9edab7ef1 --- /dev/null +++ b/pkg/stat/sys/cpu/sysconfig_notcgo.go @@ -0,0 +1,14 @@ +package cpu 
+ +//GetClockTicks get the OS's ticks per second +func GetClockTicks() int { + // TODO figure out a better alternative for platforms where we're missing cgo + // + // TODO Windows. This could be implemented using Win32 QueryPerformanceFrequency(). + // https://msdn.microsoft.com/en-us/library/windows/desktop/ms644905(v=vs.85).aspx + // + // An example of its usage can be found here. + // https://msdn.microsoft.com/en-us/library/windows/desktop/dn553408(v=vs.85).aspx + + return 100 +} diff --git a/pkg/stat/sys/cpu/util.go b/pkg/stat/sys/cpu/util.go new file mode 100644 index 000000000..25df1f9d1 --- /dev/null +++ b/pkg/stat/sys/cpu/util.go @@ -0,0 +1,121 @@ +package cpu + +import ( + "bufio" + "io/ioutil" + "os" + "strconv" + "strings" + + "github.com/pkg/errors" +) + +func readFile(path string) (string, error) { + contents, err := ioutil.ReadFile(path) + if err != nil { + return "", errors.Wrapf(err, "os/stat: read file(%s) failed!", path) + } + return strings.TrimSpace(string(contents)), nil +} + +func parseUint(s string) (uint64, error) { + v, err := strconv.ParseUint(s, 10, 64) + if err != nil { + intValue, intErr := strconv.ParseInt(s, 10, 64) + // 1. Handle negative values greater than MinInt64 (and) + // 2. Handle negative values lesser than MinInt64 + if intErr == nil && intValue < 0 { + return 0, nil + } else if intErr != nil && + intErr.(*strconv.NumError).Err == strconv.ErrRange && + intValue < 0 { + return 0, nil + } + return 0, errors.Wrapf(err, "os/stat: parseUint(%s) failed!", s) + } + return v, nil +} + +// ParseUintList parses and validates the specified string as the value +// found in some cgroup file (e.g. cpuset.cpus, cpuset.mems), which could be +// one of the formats below. Note that duplicates are actually allowed in the +// input string. It returns a map[int]bool with available elements from val +// set to true. 
+// Supported formats: +// 7 +// 1-6 +// 0,3-4,7,8-10 +// 0-0,0,1-7 +// 03,1-3 <- this is gonna get parsed as [1,2,3] +// 3,2,1 +// 0-2,3,1 +func ParseUintList(val string) (map[int]bool, error) { + if val == "" { + return map[int]bool{}, nil + } + + availableInts := make(map[int]bool) + split := strings.Split(val, ",") + errInvalidFormat := errors.Errorf("os/stat: invalid format: %s", val) + for _, r := range split { + if !strings.Contains(r, "-") { + v, err := strconv.Atoi(r) + if err != nil { + return nil, errInvalidFormat + } + availableInts[v] = true + } else { + split := strings.SplitN(r, "-", 2) + min, err := strconv.Atoi(split[0]) + if err != nil { + return nil, errInvalidFormat + } + max, err := strconv.Atoi(split[1]) + if err != nil { + return nil, errInvalidFormat + } + if max < min { + return nil, errInvalidFormat + } + for i := min; i <= max; i++ { + availableInts[i] = true + } + } + } + return availableInts, nil +} + +// ReadLines reads contents from a file and splits them by new lines. +// A convenience wrapper to ReadLinesOffsetN(filename, 0, -1). +func readLines(filename string) ([]string, error) { + return readLinesOffsetN(filename, 0, -1) +} + +// ReadLinesOffsetN reads contents from file and splits them by new line. +// The offset tells at which line number to start. 
+// The count determines the number of lines to read (starting from offset): +// n >= 0: at most n lines +// n < 0: whole file +func readLinesOffsetN(filename string, offset uint, n int) ([]string, error) { + f, err := os.Open(filename) + if err != nil { + return []string{""}, err + } + defer f.Close() + + var ret []string + + r := bufio.NewReader(f) + for i := 0; i < n+int(offset) || n < 0; i++ { + line, err := r.ReadString('\n') + if err != nil { + break + } + if i < int(offset) { + continue + } + ret = append(ret, strings.Trim(line, "\n")) + } + + return ret, nil +} diff --git a/pkg/str/str.go b/pkg/str/str.go new file mode 100644 index 000000000..8e93028df --- /dev/null +++ b/pkg/str/str.go @@ -0,0 +1,55 @@ +package xstr + +import ( + "bytes" + "strconv" + "strings" + "sync" +) + +var ( + bfPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer([]byte{}) + }, + } +) + +// JoinInts format int64 slice like:n1,n2,n3. +func JoinInts(is []int64) string { + if len(is) == 0 { + return "" + } + if len(is) == 1 { + return strconv.FormatInt(is[0], 10) + } + buf := bfPool.Get().(*bytes.Buffer) + for _, i := range is { + buf.WriteString(strconv.FormatInt(i, 10)) + buf.WriteByte(',') + } + if buf.Len() > 0 { + buf.Truncate(buf.Len() - 1) + } + s := buf.String() + buf.Reset() + bfPool.Put(buf) + return s +} + +// SplitInts split string into int64 slice. 
+func SplitInts(s string) ([]int64, error) { + if s == "" { + return nil, nil + } + sArr := strings.Split(s, ",") + res := make([]int64, 0, len(sArr)) + for _, sc := range sArr { + i, err := strconv.ParseInt(sc, 10, 64) + if err != nil { + return nil, err + } + res = append(res, i) + } + return res, nil +} diff --git a/pkg/str/str_test.go b/pkg/str/str_test.go new file mode 100644 index 000000000..a452a5b43 --- /dev/null +++ b/pkg/str/str_test.go @@ -0,0 +1,60 @@ +package xstr + +import ( + "testing" +) + +func TestJoinInts(t *testing.T) { + // test empty slice + is := []int64{} + s := JoinInts(is) + if s != "" { + t.Errorf("input:%v,output:%s,result is incorrect", is, s) + } else { + t.Logf("input:%v,output:%s", is, s) + } + // test len(slice)==1 + is = []int64{1} + s = JoinInts(is) + if s != "1" { + t.Errorf("input:%v,output:%s,result is incorrect", is, s) + } else { + t.Logf("input:%v,output:%s", is, s) + } + // test len(slice)>1 + is = []int64{1, 2, 3} + s = JoinInts(is) + if s != "1,2,3" { + t.Errorf("input:%v,output:%s,result is incorrect", is, s) + } else { + t.Logf("input:%v,output:%s", is, s) + } +} + +func TestSplitInts(t *testing.T) { + // test empty slice + s := "" + is, err := SplitInts(s) + if err != nil || len(is) != 0 { + t.Error(err) + } + // test split int64 + s = "1,2,3" + is, err = SplitInts(s) + if err != nil || len(is) != 3 { + t.Error(err) + } +} + +func BenchmarkJoinInts(b *testing.B) { + is := make([]int64, 10000, 10000) + for i := int64(0); i < 10000; i++ { + is[i] = i + } + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + JoinInts(is) + } + }) +} diff --git a/pkg/sync/errgroup/README.md b/pkg/sync/errgroup/README.md new file mode 100644 index 000000000..d665400da --- /dev/null +++ b/pkg/sync/errgroup/README.md @@ -0,0 +1,3 @@ +# errgroup + +提供带recover和并行数的errgroup,err中包含详细堆栈信息 diff --git a/pkg/sync/errgroup/doc.go b/pkg/sync/errgroup/doc.go new file mode 100644 index 000000000..b35cff432 --- /dev/null +++ 
b/pkg/sync/errgroup/doc.go @@ -0,0 +1,47 @@ +// Package errgroup provides synchronization, error propagation, and Context +// errgroup 包为一组子任务的 goroutine 提供了 goroutine 同步,错误取消功能. +// +//errgroup 包含三种常用方式 +// +//1、直接使用 此时不会因为一个任务失败导致所有任务被 cancel: +// g := &errgroup.Group{} +// g.Go(func(ctx context.Context) { +// // NOTE: 此时 ctx 为 context.Background() +// // do something +// }) +// +//2、WithContext 使用 WithContext 时不会因为一个任务失败导致所有任务被 cancel: +// g := errgroup.WithContext(ctx) +// g.Go(func(ctx context.Context) { +// // NOTE: 此时 ctx 为 errgroup.WithContext 传递的 ctx +// // do something +// }) +// +//3、WithCancel 使用 WithCancel 时如果有一个人任务失败会导致所有*未进行或进行中*的任务被 cancel: +// g := errgroup.WithCancel(ctx) +// g.Go(func(ctx context.Context) { +// // NOTE: 此时 ctx 是从 errgroup.WithContext 传递的 ctx 派生出的 ctx +// // do something +// }) +// +//设置最大并行数 GOMAXPROCS 对以上三种使用方式均起效 +//NOTE: 由于 errgroup 实现问题,设定 GOMAXPROCS 的 errgroup 需要立即调用 Wait() 例如: +// +// g := errgroup.WithCancel(ctx) +// g.GOMAXPROCS(2) +// // task1 +// g.Go(func(ctx context.Context) { +// fmt.Println("task1") +// }) +// // task2 +// g.Go(func(ctx context.Context) { +// fmt.Println("task2") +// }) +// // task3 +// g.Go(func(ctx context.Context) { +// fmt.Println("task3") +// }) +// // NOTE: 此时设置的 GOMAXPROCS 为2, 添加了三个任务 task1, task2, task3 此时 task3 是不会运行的! +// // 只有调用了 Wait task3 才有运行的机会 +// g.Wait() // task3 运行 +package errgroup diff --git a/pkg/sync/errgroup/errgroup.go b/pkg/sync/errgroup/errgroup.go new file mode 100644 index 000000000..c795e1141 --- /dev/null +++ b/pkg/sync/errgroup/errgroup.go @@ -0,0 +1,119 @@ +package errgroup + +import ( + "context" + "fmt" + "runtime" + "sync" +) + +// A Group is a collection of goroutines working on subtasks that are part of +// the same overall task. +// +// A zero Group is valid and does not cancel on error. 
+type Group struct { + err error + wg sync.WaitGroup + errOnce sync.Once + + workerOnce sync.Once + ch chan func(ctx context.Context) error + chs []func(ctx context.Context) error + + ctx context.Context + cancel func() +} + +// WithContext create a Group. +// given function from Go will receive this context, +func WithContext(ctx context.Context) *Group { + return &Group{ctx: ctx} +} + +// WithCancel create a new Group and an associated Context derived from ctx. +// +// given function from Go will receive context derived from this ctx, +// The derived Context is canceled the first time a function passed to Go +// returns a non-nil error or the first time Wait returns, whichever occurs +// first. +func WithCancel(ctx context.Context) *Group { + ctx, cancel := context.WithCancel(ctx) + return &Group{ctx: ctx, cancel: cancel} +} + +func (g *Group) do(f func(ctx context.Context) error) { + ctx := g.ctx + if ctx == nil { + ctx = context.Background() + } + var err error + defer func() { + if r := recover(); r != nil { + buf := make([]byte, 64<<10) + buf = buf[:runtime.Stack(buf, false)] + err = fmt.Errorf("errgroup: panic recovered: %s\n%s", r, buf) + } + if err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + g.wg.Done() + }() + err = f(ctx) +} + +// GOMAXPROCS set max goroutine to work. +func (g *Group) GOMAXPROCS(n int) { + if n <= 0 { + panic("errgroup: GOMAXPROCS must great than 0") + } + g.workerOnce.Do(func() { + g.ch = make(chan func(context.Context) error, n) + for i := 0; i < n; i++ { + go func() { + for f := range g.ch { + g.do(f) + } + }() + } + }) +} + +// Go calls the given function in a new goroutine. +// +// The first call to return a non-nil error cancels the group; its error will be +// returned by Wait. 
+func (g *Group) Go(f func(ctx context.Context) error) { + g.wg.Add(1) + if g.ch != nil { + select { + case g.ch <- f: + default: + g.chs = append(g.chs, f) + } + return + } + go g.do(f) +} + +// Wait blocks until all function calls from the Go method have returned, then +// returns the first non-nil error (if any) from them. +func (g *Group) Wait() error { + if g.ch != nil { + for _, f := range g.chs { + g.ch <- f + } + } + g.wg.Wait() + if g.ch != nil { + close(g.ch) // let all receiver exit + } + if g.cancel != nil { + g.cancel() + } + return g.err +} diff --git a/pkg/sync/errgroup/errgroup_test.go b/pkg/sync/errgroup/errgroup_test.go new file mode 100644 index 000000000..bb050c160 --- /dev/null +++ b/pkg/sync/errgroup/errgroup_test.go @@ -0,0 +1,266 @@ +package errgroup + +import ( + "context" + "errors" + "fmt" + "math" + "net/http" + "os" + "testing" + "time" +) + +type ABC struct { + CBA int +} + +func TestNormal(t *testing.T) { + var ( + abcs = make(map[int]*ABC) + g Group + err error + ) + for i := 0; i < 10; i++ { + abcs[i] = &ABC{CBA: i} + } + g.Go(func(context.Context) (err error) { + abcs[1].CBA++ + return + }) + g.Go(func(context.Context) (err error) { + abcs[2].CBA++ + return + }) + if err = g.Wait(); err != nil { + t.Log(err) + } + t.Log(abcs) +} + +func sleep1s(context.Context) error { + time.Sleep(time.Second) + return nil +} + +func TestGOMAXPROCS(t *testing.T) { + // 没有并发数限制 + g := Group{} + now := time.Now() + g.Go(sleep1s) + g.Go(sleep1s) + g.Go(sleep1s) + g.Go(sleep1s) + g.Wait() + sec := math.Round(time.Since(now).Seconds()) + if sec != 1 { + t.FailNow() + } + // 限制并发数 + g2 := Group{} + g2.GOMAXPROCS(2) + now = time.Now() + g2.Go(sleep1s) + g2.Go(sleep1s) + g2.Go(sleep1s) + g2.Go(sleep1s) + g2.Wait() + sec = math.Round(time.Since(now).Seconds()) + if sec != 2 { + t.FailNow() + } + // context canceled + var canceled bool + g3 := WithCancel(context.Background()) + g3.GOMAXPROCS(2) + g3.Go(func(context.Context) error { + return 
fmt.Errorf("error for testing errgroup context") + }) + g3.Go(func(ctx context.Context) error { + time.Sleep(time.Second) + select { + case <-ctx.Done(): + canceled = true + default: + } + return nil + }) + g3.Wait() + if !canceled { + t.FailNow() + } +} + +func TestRecover(t *testing.T) { + var ( + abcs = make(map[int]*ABC) + g Group + err error + ) + g.Go(func(context.Context) (err error) { + abcs[1].CBA++ + return + }) + g.Go(func(context.Context) (err error) { + abcs[2].CBA++ + return + }) + if err = g.Wait(); err != nil { + t.Logf("error:%+v", err) + return + } + t.FailNow() +} + +func TestRecover2(t *testing.T) { + var ( + g Group + err error + ) + g.Go(func(context.Context) (err error) { + panic("2233") + }) + if err = g.Wait(); err != nil { + t.Logf("error:%+v", err) + return + } + t.FailNow() +} + +var ( + Web = fakeSearch("web") + Image = fakeSearch("image") + Video = fakeSearch("video") +) + +type Result string +type Search func(ctx context.Context, query string) (Result, error) + +func fakeSearch(kind string) Search { + return func(_ context.Context, query string) (Result, error) { + return Result(fmt.Sprintf("%s result for %q", kind, query)), nil + } +} + +// JustErrors illustrates the use of a Group in place of a sync.WaitGroup to +// simplify goroutine counting and error handling. This example is derived from +// the sync.WaitGroup example at https://golang.org/pkg/sync/#example_WaitGroup. +func ExampleGroup_justErrors() { + var g Group + var urls = []string{ + "http://www.golang.org/", + "http://www.google.com/", + "http://www.somestupidname.com/", + } + for _, url := range urls { + // Launch a goroutine to fetch the URL. + url := url // https://golang.org/doc/faq#closures_and_goroutines + g.Go(func(context.Context) error { + // Fetch the URL. + resp, err := http.Get(url) + if err == nil { + resp.Body.Close() + } + return err + }) + } + // Wait for all HTTP fetches to complete. 
+ if err := g.Wait(); err == nil { + fmt.Println("Successfully fetched all URLs.") + } +} + +// Parallel illustrates the use of a Group for synchronizing a simple parallel +// task: the "Google Search 2.0" function from +// https://talks.golang.org/2012/concurrency.slide#46, augmented with a Context +// and error-handling. +func ExampleGroup_parallel() { + Google := func(ctx context.Context, query string) ([]Result, error) { + g := WithContext(ctx) + + searches := []Search{Web, Image, Video} + results := make([]Result, len(searches)) + for i, search := range searches { + i, search := i, search // https://golang.org/doc/faq#closures_and_goroutines + g.Go(func(context.Context) error { + result, err := search(ctx, query) + if err == nil { + results[i] = result + } + return err + }) + } + if err := g.Wait(); err != nil { + return nil, err + } + return results, nil + } + + results, err := Google(context.Background(), "golang") + if err != nil { + fmt.Fprintln(os.Stderr, err) + return + } + for _, result := range results { + fmt.Println(result) + } + + // Output: + // web result for "golang" + // image result for "golang" + // video result for "golang" +} + +func TestZeroGroup(t *testing.T) { + err1 := errors.New("errgroup_test: 1") + err2 := errors.New("errgroup_test: 2") + + cases := []struct { + errs []error + }{ + {errs: []error{}}, + {errs: []error{nil}}, + {errs: []error{err1}}, + {errs: []error{err1, nil}}, + {errs: []error{err1, nil, err2}}, + } + + for _, tc := range cases { + var g Group + + var firstErr error + for i, err := range tc.errs { + err := err + g.Go(func(context.Context) error { return err }) + + if firstErr == nil && err != nil { + firstErr = err + } + + if gErr := g.Wait(); gErr != firstErr { + t.Errorf("after g.Go(func() error { return err }) for err in %v\n"+ + "g.Wait() = %v; want %v", tc.errs[:i+1], err, firstErr) + } + } + } +} + +func TestWithCancel(t *testing.T) { + g := WithCancel(context.Background()) + g.Go(func(ctx context.Context) 
error { + time.Sleep(100 * time.Millisecond) + return fmt.Errorf("boom") + }) + var doneErr error + g.Go(func(ctx context.Context) error { + select { + case <-ctx.Done(): + doneErr = ctx.Err() + } + return doneErr + }) + g.Wait() + if doneErr != context.Canceled { + t.Error("error should be Canceled") + } +} diff --git a/pkg/sync/errgroup/example_test.go b/pkg/sync/errgroup/example_test.go new file mode 100644 index 000000000..f2e5b4f47 --- /dev/null +++ b/pkg/sync/errgroup/example_test.go @@ -0,0 +1,63 @@ +package errgroup + +import ( + "context" +) + +func fakeRunTask(ctx context.Context) error { + return nil +} + +func ExampleGroup_group() { + g := Group{} + g.Go(func(context.Context) error { + return fakeRunTask(context.Background()) + }) + g.Go(func(context.Context) error { + return fakeRunTask(context.Background()) + }) + if err := g.Wait(); err != nil { + // handle err + } +} + +func ExampleGroup_ctx() { + g := WithContext(context.Background()) + g.Go(func(ctx context.Context) error { + return fakeRunTask(ctx) + }) + g.Go(func(ctx context.Context) error { + return fakeRunTask(ctx) + }) + if err := g.Wait(); err != nil { + // handle err + } +} + +func ExampleGroup_cancel() { + g := WithCancel(context.Background()) + g.Go(func(ctx context.Context) error { + return fakeRunTask(ctx) + }) + g.Go(func(ctx context.Context) error { + return fakeRunTask(ctx) + }) + if err := g.Wait(); err != nil { + // handle err + } +} + +func ExampleGroup_maxproc() { + g := Group{} + // set max concurrency + g.GOMAXPROCS(2) + g.Go(func(ctx context.Context) error { + return fakeRunTask(context.Background()) + }) + g.Go(func(ctx context.Context) error { + return fakeRunTask(context.Background()) + }) + if err := g.Wait(); err != nil { + // handle err + } +} diff --git a/pkg/time/README.md b/pkg/time/README.md new file mode 100644 index 000000000..6e099f312 --- /dev/null +++ b/pkg/time/README.md @@ -0,0 +1,5 @@ +# time + +## 项目简介 + +Kratos的时间模块,主要用于mysql时间戳转换、配置文件读取并转换、Context超时时间比较 
diff --git a/pkg/time/time.go b/pkg/time/time.go new file mode 100644 index 000000000..cb03e9919 --- /dev/null +++ b/pkg/time/time.go @@ -0,0 +1,59 @@ +package time + +import ( + "context" + "database/sql/driver" + "strconv" + xtime "time" +) + +// Time be used to MySql timestamp converting. +type Time int64 + +// Scan scan time. +func (jt *Time) Scan(src interface{}) (err error) { + switch sc := src.(type) { + case xtime.Time: + *jt = Time(sc.Unix()) + case string: + var i int64 + i, err = strconv.ParseInt(sc, 10, 64) + *jt = Time(i) + } + return +} + +// Value get time value. +func (jt Time) Value() (driver.Value, error) { + return xtime.Unix(int64(jt), 0), nil +} + +// Time get time. +func (jt Time) Time() xtime.Time { + return xtime.Unix(int64(jt), 0) +} + +// Duration be used toml unmarshal string time, like 1s, 500ms. +type Duration xtime.Duration + +// UnmarshalText unmarshal text to duration. +func (d *Duration) UnmarshalText(text []byte) error { + tmp, err := xtime.ParseDuration(string(text)) + if err == nil { + *d = Duration(tmp) + } + return err +} + +// Shrink will decrease the duration by comparing with context's timeout duration +// and return new timeout\context\CancelFunc. 
+func (d Duration) Shrink(c context.Context) (Duration, context.Context, context.CancelFunc) { + if deadline, ok := c.Deadline(); ok { + if ctimeout := xtime.Until(deadline); ctimeout < xtime.Duration(d) { + // deliver small timeout + return Duration(ctimeout), c, func() {} + } + } + ctx, cancel := context.WithTimeout(c, xtime.Duration(d)) + return d, ctx, cancel +} diff --git a/pkg/time/time_test.go b/pkg/time/time_test.go new file mode 100644 index 000000000..f0c2d6604 --- /dev/null +++ b/pkg/time/time_test.go @@ -0,0 +1,60 @@ +package time + +import ( + "context" + "testing" + "time" +) + +func TestShrink(t *testing.T) { + var d Duration + err := d.UnmarshalText([]byte("1s")) + if err != nil { + t.Fatalf("TestShrink: d.UnmarshalText failed!err:=%v", err) + } + c := context.Background() + to, ctx, cancel := d.Shrink(c) + defer cancel() + if time.Duration(to) != time.Second { + t.Fatalf("new timeout must be equal 1 second") + } + if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > time.Second || time.Until(deadline) < time.Millisecond*500 { + t.Fatalf("ctx deadline must be less than 1s and greater than 500ms") + } +} + +func TestShrinkWithTimeout(t *testing.T) { + var d Duration + err := d.UnmarshalText([]byte("1s")) + if err != nil { + t.Fatalf("TestShrink: d.UnmarshalText failed!err:=%v", err) + } + c, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + to, ctx, cancel := d.Shrink(c) + defer cancel() + if time.Duration(to) != time.Second { + t.Fatalf("new timeout must be equal 1 second") + } + if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > time.Second || time.Until(deadline) < time.Millisecond*500 { + t.Fatalf("ctx deadline must be less than 1s and greater than 500ms") + } +} + +func TestShrinkWithDeadline(t *testing.T) { + var d Duration + err := d.UnmarshalText([]byte("1s")) + if err != nil { + t.Fatalf("TestShrink: d.UnmarshalText failed!err:=%v", err) + } + c, cancel := 
context.WithTimeout(context.Background(), time.Millisecond*500) + defer cancel() + to, ctx, cancel := d.Shrink(c) + defer cancel() + if time.Duration(to) >= time.Millisecond*500 { + t.Fatalf("new timeout must be less than 500 ms") + } + if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > time.Millisecond*500 || time.Until(deadline) < time.Millisecond*200 { + t.Fatalf("ctx deadline must be less than 500ms and greater than 200ms") + } +}