add bts-gen & mc-gen (#96)

add genbts & genmc
pull/108/head
jiankuny 6 years ago committed by GitHub
parent 651a05b72c
commit 0ffb6233bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 30
      doc/wiki-cn/kratos-genbts.md
  2. 71
      doc/wiki-cn/kratos-genmc.md
  3. 38
      doc/wiki-cn/kratos-protoc.md
  4. 11
      doc/wiki-cn/kratos-swagger.md
  5. 49
      doc/wiki-cn/kratos-tool.md
  6. 4
      doc/wiki-cn/summary.md
  7. 46
      tool/kratos-gen-bts/README.md
  8. 31
      tool/kratos-gen-bts/header_template.go
  9. 482
      tool/kratos-gen-bts/main.go
  10. 126
      tool/kratos-gen-bts/multi_template.go
  11. 65
      tool/kratos-gen-bts/none_template.go
  12. 86
      tool/kratos-gen-bts/single_template.go
  13. 202
      tool/kratos-gen-bts/testdata/dao.bts.go
  14. 35
      tool/kratos-gen-bts/testdata/dao.go
  15. 30
      tool/kratos-gen-bts/testdata/multi.go
  16. 67
      tool/kratos-gen-bts/testdata/multi_test.go
  17. 30
      tool/kratos-gen-bts/testdata/none.go
  18. 50
      tool/kratos-gen-bts/testdata/none_test.go
  19. 48
      tool/kratos-gen-bts/testdata/single.go
  20. 50
      tool/kratos-gen-bts/testdata/single_test.go
  21. 44
      tool/kratos-gen-mc/README.md
  22. 30
      tool/kratos-gen-mc/header_template.go
  23. 548
      tool/kratos-gen-mc/main.go
  24. 208
      tool/kratos-gen-mc/multi_template.go
  25. 103
      tool/kratos-gen-mc/none_template.go
  26. 102
      tool/kratos-gen-mc/single_template.go
  27. 93
      tool/kratos-gen-mc/testdata/dao.go
  28. 116
      tool/kratos-gen-mc/testdata/dao_test.go
  29. 320
      tool/kratos-gen-mc/testdata/mc.cache.go
  30. 328
      tool/kratos-gen-mc/testdata/model.pb.go
  31. 14
      tool/kratos-gen-mc/testdata/model.proto
  32. 18
      tool/kratos/tool_index.go
  33. 149
      tool/pkg/common.go

@ -0,0 +1,30 @@
### kratos tool genbts
> 缓存回源代码生成
在internal/dao/dao.go中添加mc缓存interface定义,可以指定对应的[注解参数](../tool/kratos-gen-mc/README.md);
并且在接口前面添加`go:generate kratos tool genmc`;
然后在当前目录执行`go generate`,可以看到自动生成的dao.bts.go代码。
### 回源模板
```go
//go:generate kratos tool genbts
type _bts interface {
// bts: -batch=2 -max_group=20 -batch_err=break -nullcache=&Demo{ID:-1} -check_null_code=$.ID==-1
Demos(c context.Context, keys []int64) (map[int64]*Demo, error)
// bts: -sync=true -nullcache=&Demo{ID:-1} -check_null_code=$.ID==-1
Demo(c context.Context, key int64) (*Demo, error)
// bts: -paging=true
Demo1(c context.Context, key int64, pn int, ps int) (*Demo, error)
// bts: -nullcache=&Demo{ID:-1} -check_null_code=$.ID==-1
None(c context.Context) (*Demo, error)
}
```
### 参考
也可以参考完整的testdata例子:kratos/tool/kratos-gen-bts/testdata
-------------
[文档目录树](summary.md)

@ -0,0 +1,71 @@
### kratos tool genmc
> 缓存代码生成
在internal/dao/dao.go中添加mc缓存interface定义,可以指定对应的[注解参数](../../tool/kratos-gen-mc/README.md);
并且在接口前面添加`go:generate kratos tool genmc`;
然后在当前目录执行`go generate`,可以看到自动生成的mc.cache.go代码。
### 缓存模板
```go
//go:generate kratos tool genmc
type _mc interface {
// mc: -key=demoKey
CacheDemos(c context.Context, keys []int64) (map[int64]*Demo, error)
// mc: -key=demoKey
CacheDemo(c context.Context, key int64) (*Demo, error)
// mc: -key=keyMid
CacheDemo1(c context.Context, key int64, mid int64) (*Demo, error)
// mc: -key=noneKey
CacheNone(c context.Context) (*Demo, error)
// mc: -key=demoKey
CacheString(c context.Context, key int64) (string, error)
// mc: -key=demoKey -expire=d.demoExpire -encode=json
AddCacheDemos(c context.Context, values map[int64]*Demo) error
// mc: -key=demo2Key -expire=d.demoExpire -encode=json
AddCacheDemos2(c context.Context, values map[int64]*Demo, tp int64) error
// 这里也支持自定义注释 会替换默认的注释
// mc: -key=demoKey -expire=d.demoExpire -encode=json|gzip
AddCacheDemo(c context.Context, key int64, value *Demo) error
// mc: -key=keyMid -expire=d.demoExpire -encode=gob
AddCacheDemo1(c context.Context, key int64, value *Demo, mid int64) error
// mc: -key=noneKey
AddCacheNone(c context.Context, value *Demo) error
// mc: -key=demoKey -expire=d.demoExpire
AddCacheString(c context.Context, key int64, value string) error
// mc: -key=demoKey
DelCacheDemos(c context.Context, keys []int64) error
// mc: -key=demoKey
DelCacheDemo(c context.Context, key int64) error
// mc: -key=keyMid
DelCacheDemo1(c context.Context, key int64, mid int64) error
// mc: -key=noneKey
DelCacheNone(c context.Context) error
}
func demoKey(id int64) string {
return fmt.Sprintf("art_%d", id)
}
func demo2Key(id, tp int64) string {
return fmt.Sprintf("art_%d_%d", id, tp)
}
func keyMid(id, mid int64) string {
return fmt.Sprintf("art_%d_%d", id, mid)
}
func noneKey() string {
return "none"
}
```
### 参考
也可以参考完整的testdata例子:kratos/tool/kratos-gen-mc/testdata
-------------
[文档目录树](summary.md)

@ -0,0 +1,38 @@
### kratos tool protoc
```
// generate all
kratos tool protoc api.proto
// generate gRPC
kratos tool protoc --grpc api.proto
// generate BM HTTP
kratos tool protoc --bm api.proto
// generate swagger
kratos tool protoc --swagger api.proto
```
执行对应生成 `api.pb.go/api.bm.go/api.swagger.json` 源文档。
> 该工具在Windows/Linux下运行,需提前安装好 protobuf 工具
该工具实际是一段`shell`脚本,其中自动将`protoc`命令进行了拼接,识别了需要的`*.proto`文件和当前目录下的`proto`文件,最终会拼接为如下命令进行执行:
```shell
export $KRATOS_HOME = kratos路径
export $KRATOS_DEMO = 项目路径
// 生成:api.pb.go
protoc -I$GOPATH/src:$KRATOS_HOME/tool/protobuf/pkg/extensions:$KRATOS_DEMO/api --gogofast_out=plugins=grpc:$KRATOS_DEMO/api $KRATOS_DEMO/api/api.proto
// 生成:api.bm.go
protoc -I$GOPATH/src:$KRATOS_HOME/tool/protobuf/pkg/extensions:$KRATOS_DEMO/api --bm_out=$KRATOS_DEMO/api $KRATOS_DEMO/api/api.proto
// 生成:api.swagger.json
protoc -I$GOPATH/src:$KRATOS_HOME/tool/protobuf/pkg/extensions:$KRATOS_DEMO/api --bswagger_out=$KRATOS_DEMO/api $KRATOS_DEMO/api/api.proto
```
大家也可以参考该命令进行`proto`生成,也可以参考 [protobuf](https://github.com/google/protobuf) 官方参数。
-------------
[文档目录树](summary.md)

@ -0,0 +1,11 @@
### kratos tool swagger
```shell
kratos tool swagger serve api/api.swagger.json
```
执行命令后,浏览器会自动打开swagger文档地址。
同时也可以查看更多的 [go-swagger](https://github.com/go-swagger/go-swagger) 官方参数进行使用。
-------------
[文档目录树](summary.md)

@ -92,51 +92,10 @@ kratos(已安装): Kratos工具集本体 Author(kratos) [2019/04/02]
目前已经集成的工具有: 目前已经集成的工具有:
* kratos 为本体工具,只用于安装更新使用; * kratos 为本体工具,只用于安装更新使用;
* protoc 用于快速生成gRPC、HTTP、Swagger文件,该命令Windows,Linux用户需要手动安装 protobuf 工具。 * protoc 用于快速生成gRPC、HTTP、Swagger文件,该命令Windows,Linux用户需要手动安装 protobuf 工具;
* swagger 用于显示自动生成的HTTP API接口文档,通过 `kratos tool swagger serve api/api.swagger.json` 可以查看文档。 * swagger 用于显示自动生成的HTTP API接口文档,通过 `kratos tool swagger serve api/api.swagger.json` 可以查看文档;
* genmc 用于自动生成memcached缓存代码;
### kratos tool protoc * genbts 用于生成缓存回源代码生成,如果miss则调用回源函数从数据源获取,然后塞入缓存;
```shell
# generate all
kratos tool protoc api.proto
# generate gRPC
kratos tool protoc --grpc api.proto
# generate BM HTTP
kratos tool protoc --bm api.proto
# generate swagger
kratos tool protoc --swagger api.proto
```
执行对应生成 `api.pb.go/api.bm.go/api.swagger.json` 源文档。
> 该工具在Windows/Linux下运行,需提前安装好 protobuf 工具
该工具实际是一段`shell`脚本,其中自动将`protoc`命令进行了拼接,识别了需要的`*.proto`文件和当前目录下的`proto`文件,最终会拼接为如下命令进行执行:
```shell
export $KRATOS_HOME = kratos路径
export $KRATOS_DEMO = 项目路径
# 生成:api.pb.go
protoc -I$GOPATH/src:$KRATOS_HOME/tool/protobuf/pkg/extensions:$KRATOS_DEMO/api --gogofast_out=plugins=grpc:$KRATOS_DEMO/api $KRATOS_DEMO/api/api.proto
# 生成:api.bm.go
protoc -I$GOPATH/src:$KRATOS_HOME/tool/protobuf/pkg/extensions:$KRATOS_DEMO/api --bm_out=$KRATOS_DEMO/api $KRATOS_DEMO/api/api.proto
# 生成:api.swagger.json
protoc -I$GOPATH/src:$KRATOS_HOME/tool/protobuf/pkg/extensions:$KRATOS_DEMO/api --bswagger_out=$KRATOS_DEMO/api $KRATOS_DEMO/api/api.proto
```
大家也可以参考该命令进行`proto`生成,也可以参考[protobuf](https://github.com/google/protobuf)官方参数。
### kratos tool swagger
```shell
kratos tool swagger serve api/api.swagger.json
```
执行命令后,浏览器会自动打开swagger文档地址。
同时也可以查看更多的 [go-swagger](https://github.com/go-swagger/go-swagger) 官方参数进行使用。
------------- -------------

@ -25,3 +25,7 @@
* [memcache](cache-mc.md) * [memcache](cache-mc.md)
* [redis](cache-redis.md) * [redis](cache-redis.md)
* [kratos工具](kratos-tool.md) * [kratos工具](kratos-tool.md)
* [protoc](kratos-protoc.md)
* [swagger](kratos-swagger.md)
* [genmc](kratos-genmc.md)
* [genbts](kratos-genbts.md)

@ -0,0 +1,46 @@
#### genbts
> 缓存代码生成
##### 项目简介
从缓存中获取数据 如果miss则调用回源函数从数据源获取 然后塞入缓存
支持以下功能:
- 单飞限制回源并发 防止打爆数据源
- 空缓存 防止缓存穿透
- 分批获取数据 降低延时
- 默认异步加缓存 可选同步加缓存
- prometheus回源比监控
- 多行注释生成代码
- 支持分页(限单key模板)
- 自定义注释
- 支持忽略参数
##### 使用方式:
1. 在dao package中 增加注解 //go:generate kratos tool genbts 定义bts接口 声明需要的方法
2. 在dao 文件夹中执行 go generate命令 将会生成相应的缓存代码
3. 调用生成的XXX方法
4. 示例见testdata/dao.go
要求:
dao里面需要有cache对象 代码会调用d.cache来新增缓存
需要实现代码中所需的方法 每一个缓存方法都需要实现以下方法:
从缓存中获取数据 名称为Cache+方法名 函数定义和声明一致
从数据源(db/api/...)获取数据 名称为Raw+方法 函数定义和声明一致
存入缓存方法 名称为AddCache+方法名 函数定义为 func AddCache方法名(c context.Context, ...) (error)
##### 注解参数:
| 参数名称 | 默认值 | 说明 | 示例 |
| ---------------- | ------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| -nullcache | | 空指针对象(存正常业务不会出现的内容 id的话像是-1这样的) | &Article{ID:-1} 或-1 或"null" |
| -check_null_code | | 开启空缓存并且value为指针对象时必填 用于判断是否是空缓存 $来指代对象名 | `-check_null_code=$!=nil&&$.ID==-1 或 $ == -1` |
| -batch | | (限多key模板) 批量获取数据 每组大小 | 100 |
| -max_group | | (限多key模板)批量获取数据 最大组数量 | 10 |
| -batch_err | break | (限多key模板)批量获取数据回源错误的时候 降级继续请求(continue)还是直接返回(break) | break 或 continue |
| -singleflight | false | 是否开启单飞(开启后生成函数会多一个单飞名称参数 生成的代码会调用d.cacheSFNAME方法获取单飞的key) | true |
| -sync | false | 是否同步增加缓存 | false |
| -paging | false | (限单key模板)分页 数据源应返回2个值 第一个为对外数据 第二个为全量数据 用于新增缓存 | false |
| -ignores | | 用于依赖的三个方法参数和主方法参数不一致的情况. 忽略方法的某些参数 用\|分隔方法逗号分隔参数 | pn,ps\|pn\|origin 表示"缓存获取"方法忽略pn,ps两个参数 回源方法忽略pn参数 加缓存方法忽略origin参数 |
| -custom_method | false | 自定义方法名 \|分隔 缓存获取方法名\|回源方法名\|增加缓存方法名 | d.mc.AddSubject\|d.mysql.Subject\|d.mc.AddSubject |

@ -0,0 +1,31 @@
package main
var _headerTemplate = `
// Code generated by kratos tool genbts. DO NOT EDIT.
NEWLINE
/*
Package {{.PkgName}} is a generated cache proxy package.
It is generated from:
ARGS
*/
NEWLINE
package {{.PkgName}}
import (
"context"
{{if .EnableBatch }}"sync"{{end}}
NEWLINE
"github.com/bilibili/kratos/pkg/stat/prom"
{{if .EnableBatch }}"github.com/bilibili/kratos/pkg/sync/errgroup"{{end}}
{{.ImportPackage}}
NEWLINE
{{if .EnableSingleFlight}} "golang.org/x/sync/singleflight" {{end}}
)
var _ _bts
{{if .EnableSingleFlight}}
var cacheSingleFlights = [SFCOUNT]*singleflight.Group{SFINIT}
{{end }}
`

@ -0,0 +1,482 @@
package main
import (
"bytes"
"flag"
"fmt"
"go/ast"
"io/ioutil"
"log"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"text/template"
"github.com/bilibili/kratos/tool/pkg"
)
var (
// arguments
singleFlight = flag.Bool("singleflight", false, "enable singleflight")
nullCache = flag.String("nullcache", "", "null cache")
checkNullCode = flag.String("check_null_code", "", "check null code")
batchSize = flag.Int("batch", 0, "batch size")
batchErr = flag.String("batch_err", "break", "batch err to contine or break")
maxGroup = flag.Int("max_group", 0, "max group size")
sync = flag.Bool("sync", false, "add cache in sync way.")
paging = flag.Bool("paging", false, "use paging in single template")
ignores = flag.String("ignores", "", "ignore params")
customMethod = flag.String("custom_method", "", "自定义方法名 |分隔: 缓存|回源|增加缓存")
numberTypes = []string{"int", "int8", "int16", "int32", "int64", "float32", "float64", "uint", "uint8", "uint16", "uint32", "uint64"}
simpleTypes = []string{"int", "int8", "int16", "int32", "int64", "float32", "float64", "uint", "uint8", "uint16", "uint32", "uint64", "bool", "string", "[]byte"}
optionNames = []string{"singleflight", "nullcache", "check_null_code", "batch", "max_group", "sync", "paging", "ignores", "batch_err", "custom_method"}
optionNamesMap = map[string]bool{}
)
const (
_interfaceName = "_bts"
_multiTpl = 1
_singleTpl = 2
_noneTpl = 3
)
func resetFlag() {
*singleFlight = false
*nullCache = ""
*checkNullCode = ""
*batchSize = 0
*maxGroup = 0
*sync = false
*paging = false
*batchErr = "break"
*ignores = ""
*customMethod = ""
}
// options options
type options struct {
name string
keyType string
valueType string
cacheFunc string
rawFunc string
addCacheFunc string
template int
SimpleValue bool
NumberValue bool
GoValue bool
ZeroValue string
ImportPackage string
importPackages []string
Args string
PkgName string
EnableSingleFlight bool
NullCache string
EnableNullCache bool
GroupSize int
MaxGroup int
EnableBatch bool
BatchErrBreak bool
Sync bool
CheckNullCode string
ExtraArgsType string
ExtraArgs string
ExtraCacheArgs string
ExtraRawArgs string
ExtraAddCacheArgs string
EnablePaging bool
Comment string
CustomMethod string
IDName string
}
func getOptions(opt *options, comment string) {
os.Args = []string{os.Args[0]}
if regexp.MustCompile(`\s+//\s*bts:.+`).Match([]byte(comment)) {
args := strings.Split(pkg.RegexpReplace(`//\s*bts:(?P<arg>.+)`, comment, "$arg"), " ")
for _, arg := range args {
arg = strings.TrimSpace(arg)
if arg != "" {
// validate option name
argName := pkg.RegexpReplace(`-(?P<name>[\w_-]+)=.+`, arg, "$name")
if !optionNamesMap[argName] {
log.Fatalf("选项:%s 不存在 请检查拼写\n", argName)
}
os.Args = append(os.Args, arg)
}
}
}
resetFlag()
flag.Parse()
opt.EnableSingleFlight = *singleFlight
opt.NullCache = *nullCache
opt.EnablePaging = *paging
opt.EnableNullCache = *nullCache != ""
opt.EnableBatch = (*batchSize != 0) && (*maxGroup != 0)
opt.BatchErrBreak = *batchErr == "break"
opt.Sync = *sync
opt.CheckNullCode = *checkNullCode
opt.GroupSize = *batchSize
opt.MaxGroup = *maxGroup
opt.CustomMethod = *customMethod
}
func processList(s *pkg.Source, list *ast.Field) (opt options) {
fset := s.Fset
src := s.Src
lines := strings.Split(src, "\n")
opt = options{Args: s.GetDef(_interfaceName), importPackages: s.Packages(list)}
// get comment
line := fset.Position(list.Pos()).Line - 3
if len(lines)-1 >= line {
comment := lines[line]
opt.Comment = pkg.RegexpReplace(`\s+//(?P<name>.+)`, comment, "$name")
opt.Comment = strings.TrimSpace(opt.Comment)
}
// get options
line = fset.Position(list.Pos()).Line - 2
comment := lines[line]
getOptions(&opt, comment)
// get func
opt.name = list.Names[0].Name
params := list.Type.(*ast.FuncType).Params.List
if len(params) == 0 {
log.Fatalln(opt.name + "参数不足")
}
for _, p := range params {
if len(p.Names) > 1 {
log.Fatalln(opt.name + "不支持省略类型 请写全声明中的字段类型名称")
}
}
if s.ExprString(params[0].Type) != "context.Context" {
log.Fatalln("第一个参数必须为context")
}
if len(params) == 1 {
opt.template = _noneTpl
} else {
opt.IDName = params[1].Names[0].Name
if _, ok := params[1].Type.(*ast.ArrayType); ok {
opt.template = _multiTpl
} else {
opt.template = _singleTpl
// get key
opt.keyType = s.ExprString(params[1].Type)
}
}
if len(params) > 2 {
var args []string
var allArgs []string
for _, pa := range params[2:] {
paType := s.ExprString(pa.Type)
if len(pa.Names) == 0 {
args = append(args, paType)
allArgs = append(allArgs, paType)
continue
}
var names []string
for _, name := range pa.Names {
names = append(names, name.Name)
}
allArgs = append(allArgs, strings.Join(names, ",")+" "+paType)
args = append(args, names...)
}
opt.ExtraArgs = strings.Join(args, ",")
opt.ExtraArgsType = strings.Join(allArgs, ",")
argsMap := make(map[string]bool)
for _, arg := range args {
argsMap[arg] = true
}
ignoreCache := make(map[string]bool)
ignoreRaw := make(map[string]bool)
ignoreAddCache := make(map[string]bool)
ignoreArray := [3]map[string]bool{ignoreCache, ignoreRaw, ignoreAddCache}
if *ignores != "" {
is := strings.Split(*ignores, "|")
if len(is) > 3 {
log.Fatalln("ignores参数错误")
}
for i := range is {
if len(is) > i {
for _, s := range strings.Split(is[i], ",") {
ignoreArray[i][s] = true
}
}
}
}
var as []string
for _, arg := range args {
if !ignoreCache[arg] {
as = append(as, arg)
}
}
opt.ExtraCacheArgs = strings.Join(as, ",")
as = []string{}
for _, arg := range args {
if !ignoreRaw[arg] {
as = append(as, arg)
}
}
opt.ExtraRawArgs = strings.Join(as, ",")
as = []string{}
for _, arg := range args {
if !ignoreAddCache[arg] {
as = append(as, arg)
}
}
opt.ExtraAddCacheArgs = strings.Join(as, ",")
if opt.ExtraAddCacheArgs != "" {
opt.ExtraAddCacheArgs = "," + opt.ExtraAddCacheArgs
}
if opt.ExtraRawArgs != "" {
opt.ExtraRawArgs = "," + opt.ExtraRawArgs
}
if opt.ExtraCacheArgs != "" {
opt.ExtraCacheArgs = "," + opt.ExtraCacheArgs
}
if opt.ExtraArgs != "" {
opt.ExtraArgs = "," + opt.ExtraArgs
}
if opt.ExtraArgsType != "" {
opt.ExtraArgsType = "," + opt.ExtraArgsType
}
}
// get k v from results
results := list.Type.(*ast.FuncType).Results.List
if len(results) != 2 {
log.Fatalln(opt.name + ": 参数个数不对")
}
if s.ExprString(results[1].Type) != "error" {
log.Fatalln(opt.name + ": 最后返回值参数需为error")
}
if opt.template == _multiTpl {
p, ok := results[0].Type.(*ast.MapType)
if !ok {
log.Fatalln(opt.name + ": 批量获取方法 返回值类型需为map类型")
}
opt.keyType = s.ExprString(p.Key)
opt.valueType = s.ExprString(p.Value)
} else {
opt.valueType = s.ExprString(results[0].Type)
}
for _, t := range numberTypes {
if t == opt.valueType {
opt.NumberValue = true
break
}
}
opt.ZeroValue = "nil"
for _, t := range simpleTypes {
if t == opt.valueType {
opt.SimpleValue = true
opt.ZeroValue = zeroValue(t)
break
}
}
if !opt.SimpleValue {
for _, t := range []string{"[]", "map"} {
if strings.HasPrefix(opt.valueType, t) {
opt.GoValue = true
break
}
}
}
upperName := strings.ToUpper(opt.name[0:1]) + opt.name[1:]
opt.cacheFunc = fmt.Sprintf("d.Cache%s", upperName)
opt.rawFunc = fmt.Sprintf("d.Raw%s", upperName)
opt.addCacheFunc = fmt.Sprintf("d.AddCache%s", upperName)
if opt.CustomMethod != "" {
arrs := strings.Split(opt.CustomMethod, "|")
if len(arrs) > 0 && arrs[0] != "" {
opt.cacheFunc = arrs[0]
}
if len(arrs) > 1 && arrs[1] != "" {
opt.rawFunc = arrs[1]
}
if len(arrs) > 2 && arrs[2] != "" {
opt.addCacheFunc = arrs[2]
}
}
return
}
// parse parse options
func parse(s *pkg.Source) (opts []*options) {
c := s.F.Scope.Lookup(_interfaceName)
if (c == nil) || (c.Kind != ast.Typ) {
log.Fatalln("无法找到缓存声明")
}
lists := c.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType).Methods.List
for _, list := range lists {
opt := processList(s, list)
opt.Check()
opts = append(opts, &opt)
}
return
}
func (option *options) Check() {
if !option.SimpleValue && !strings.Contains(option.valueType, "*") && !strings.Contains(option.valueType, "[]") && !strings.Contains(option.valueType, "map") {
log.Fatalf("%s: 值类型只能为基本类型/slice/map/指针类型\n", option.name)
}
if option.EnableSingleFlight && option.EnableBatch {
log.Fatalf("%s: 单飞和批量获取不能同时开启\n", option.name)
}
if option.template != _singleTpl && option.EnablePaging {
log.Fatalf("%s: 分页只能用在单key模板中\n", option.name)
}
if option.SimpleValue && !option.EnableNullCache {
if !((option.template == _multiTpl) && option.NumberValue) {
log.Fatalf("%s: 值为基本类型时需开启空缓存 防止缓存零值穿透\n", option.name)
}
}
if option.EnableNullCache {
if !option.SimpleValue && option.CheckNullCode == "" {
log.Fatalf("%s: 缺少-check_null_code参数\n", option.name)
}
if option.SimpleValue && option.NullCache == option.ZeroValue {
log.Fatalf("%s: %s 不能作为空缓存值 \n", option.name, option.NullCache)
}
if strings.Contains(option.NullCache, "{}") {
// -nullcache=[]*model.OrderMain{} 这种无效
log.Fatalf("%s: %s 不能作为空缓存值 会导致空缓存无效 \n", option.name, option.NullCache)
}
if strings.Contains(option.CheckNullCode, "len") && strings.Contains(strings.Replace(option.CheckNullCode, " ", "", -1), "==0") {
// -check_null_code=len($)==0 这种无效
log.Fatalf("%s: -check_null_code=%s 错误 会有无意义的赋值\n", option.name, option.CheckNullCode)
}
}
}
func genHeader(opts []*options) (src string) {
option := options{PkgName: os.Getenv("GOPACKAGE")}
var sfCount int
var packages, sfInit []string
packagesMap := map[string]bool{`"context"`: true}
for _, opt := range opts {
if opt.EnableSingleFlight {
option.EnableSingleFlight = true
sfCount++
}
if opt.EnableBatch {
option.EnableBatch = true
}
if len(opt.importPackages) > 0 {
for _, pkg := range opt.importPackages {
if !packagesMap[pkg] {
packages = append(packages, pkg)
packagesMap[pkg] = true
}
}
}
if opt.Args != "" {
option.Args = opt.Args
}
}
option.ImportPackage = strings.Join(packages, "\n")
for i := 0; i < sfCount; i++ {
sfInit = append(sfInit, "{}")
}
src = _headerTemplate
src = strings.Replace(src, "SFCOUNT", strconv.Itoa(sfCount), -1)
t := template.Must(template.New("header").Parse(src))
var buffer bytes.Buffer
err := t.Execute(&buffer, option)
if err != nil {
log.Fatalf("execute template: %s", err)
}
// Format the output.
src = strings.Replace(buffer.String(), "\t", "", -1)
src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
src = strings.Replace(src, "NEWLINE", "", -1)
src = strings.Replace(src, "ARGS", option.Args, -1)
src = strings.Replace(src, "SFINIT", strings.Join(sfInit, ","), -1)
return
}
func genBody(opts []*options) (res string) {
sfnum := -1
for _, option := range opts {
var nullCodeVar, src string
if option.template == _multiTpl {
src = _multiTemplate
nullCodeVar = "v"
} else if option.template == _singleTpl {
src = _singleTemplate
nullCodeVar = "res"
} else {
src = _noneTemplate
nullCodeVar = "res"
}
if option.template != _noneTpl {
src = strings.Replace(src, "KEY", option.keyType, -1)
}
if option.CheckNullCode != "" {
option.CheckNullCode = strings.Replace(option.CheckNullCode, "$", nullCodeVar, -1)
}
if option.EnableSingleFlight {
sfnum++
}
src = strings.Replace(src, "NAME", option.name, -1)
src = strings.Replace(src, "VALUE", option.valueType, -1)
src = strings.Replace(src, "ADDCACHEFUNC", option.addCacheFunc, -1)
src = strings.Replace(src, "CACHEFUNC", option.cacheFunc, -1)
src = strings.Replace(src, "RAWFUNC", option.rawFunc, -1)
src = strings.Replace(src, "GROUPSIZE", strconv.Itoa(option.GroupSize), -1)
src = strings.Replace(src, "MAXGROUP", strconv.Itoa(option.MaxGroup), -1)
src = strings.Replace(src, "SFNUM", strconv.Itoa(sfnum), -1)
t := template.Must(template.New("cache").Parse(src))
var buffer bytes.Buffer
err := t.Execute(&buffer, option)
if err != nil {
log.Fatalf("execute template: %s", err)
}
// Format the output.
src = strings.Replace(buffer.String(), "\t", "", -1)
src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
res = res + "\n" + src
}
return
}
func zeroValue(t string) string {
switch t {
case "bool":
return "false"
case "string":
return "\"\""
case "[]byte":
return "nil"
default:
return "0"
}
}
func init() {
for _, name := range optionNames {
optionNamesMap[name] = true
}
}
func main() {
log.SetFlags(0)
defer func() {
if err := recover(); err != nil {
log.Fatalf("程序解析失败, err: %+v", err)
}
}()
options := parse(pkg.NewSource(pkg.SourceText()))
header := genHeader(options)
body := genBody(options)
code := pkg.FormatCode(header + "\n" + body)
// Write to file.
dir := filepath.Dir(".")
outputName := filepath.Join(dir, "dao.bts.go")
err := ioutil.WriteFile(outputName, []byte(code), 0644)
if err != nil {
log.Fatalf("写入文件失败: %s", err)
}
log.Println("dao.bts.go: 生成成功")
}

@ -0,0 +1,126 @@
package main
var _multiTemplate = `
// NAME {{or .Comment "get data from cache if miss will call source method, then add to cache."}}
func (d *Dao) NAME(c context.Context, {{.IDName}} []KEY{{.ExtraArgsType}}) (res map[KEY]VALUE, err error) {
if len({{.IDName}}) == 0 {
return
}
addCache := true
if res, err = CACHEFUNC(c, {{.IDName}} {{.ExtraCacheArgs}});err != nil {
addCache = false
res = nil
err = nil
}
var miss []KEY
for _, key := range {{.IDName}} {
{{if .GoValue}}
if (res == nil) || (len(res[key]) == 0) {
{{else}}
{{if .NumberValue}}
if _, ok := res[key]; !ok {
{{else}}
if (res == nil) || (res[key] == {{.ZeroValue}}) {
{{end}}
{{end}}
miss = append(miss, key)
}
}
prom.CacheHit.Add("NAME", int64(len({{.IDName}}) - len(miss)))
{{if .EnableNullCache}}
for k, v := range res {
{{if .SimpleValue}} if v == {{.NullCache}} { {{else}} if {{.CheckNullCode}} { {{end}}
delete(res, k)
}
}
{{end}}
missLen := len(miss)
if missLen == 0 {
return
}
{{if .EnableBatch}}
missData := make(map[KEY]VALUE, missLen)
{{else}}
var missData map[KEY]VALUE
{{end}}
{{if .EnableSingleFlight}}
var rr interface{}
sf := d.cacheSFNAME({{.IDName}} {{.ExtraArgs}})
rr, err, _ = cacheSingleFlights[SFNUM].Do(sf, func() (r interface{}, e error) {
prom.CacheMiss.Add("NAME", int64(len(miss)))
r, e = RAWFUNC(c, miss {{.ExtraRawArgs}})
return
})
missData = rr.(map[KEY]VALUE)
{{else}}
{{if .EnableBatch}}
prom.CacheMiss.Add("NAME", int64(missLen))
var mutex sync.Mutex
{{if .BatchErrBreak}}
group := errgroup.WithCancel(c)
{{else}}
group := &errgroup.WithContext(c)
{{end}}
if missLen > MAXGROUP {
group.GOMAXPROCS(MAXGROUP)
}
var run = func(ms []KEY) {
group.Go(func(ctx context.Context) (err error) {
data, err := RAWFUNC(ctx, ms {{.ExtraRawArgs}})
mutex.Lock()
for k, v := range data {
missData[k] = v
}
mutex.Unlock()
return
})
}
var (
i int
n = missLen/GROUPSIZE
)
for i=0; i< n; i++{
run(miss[i*GROUPSIZE:(i+1)*GROUPSIZE])
}
if len(miss[i*GROUPSIZE:]) > 0 {
run(miss[i*GROUPSIZE:])
}
err = group.Wait()
{{else}}
prom.CacheMiss.Add("NAME", int64(len(miss)))
missData, err = RAWFUNC(c, miss {{.ExtraRawArgs}})
{{end}}
{{end}}
if res == nil {
res = make(map[KEY]VALUE, len({{.IDName}}))
}
for k, v := range missData {
res[k] = v
}
if err != nil {
return
}
{{if .EnableNullCache}}
for _, key := range miss {
{{if .GoValue}}
if len(res[key]) == 0 {
{{else}}
if res[key] == {{.ZeroValue}} {
{{end}}
missData[key] = {{.NullCache}}
}
}
{{end}}
if !addCache {
return
}
{{if .Sync}}
ADDCACHEFUNC(c, missData {{.ExtraAddCacheArgs}})
{{else}}
d.cache.Do(c, func(c context.Context) {
ADDCACHEFUNC(c, missData {{.ExtraAddCacheArgs}})
})
{{end}}
return
}
`

@ -0,0 +1,65 @@
package main
var _noneTemplate = `
// NAME {{or .Comment "get data from cache if miss will call source method, then add to cache."}}
func (d *Dao) NAME(c context.Context) (res VALUE, err error) {
addCache := true
res, err = CACHEFUNC(c)
if err != nil {
addCache = false
err = nil
}
{{if .EnableNullCache}}
defer func() {
{{if .SimpleValue}} if res == {{.NullCache}} { {{else}} if {{.CheckNullCode}} { {{end}}
res = {{.ZeroValue}}
}
}()
{{end}}
{{if .GoValue}}
if len(res) != 0 {
{{else}}
if res != {{.ZeroValue}} {
{{end}}
prom.CacheHit.Incr("NAME")
return
}
{{if .EnableSingleFlight}}
var rr interface{}
sf := d.cacheSFNAME()
rr, err, _ = cacheSingleFlights[SFNUM].Do(sf, func() (r interface{}, e error) {
prom.CacheMiss.Incr("NAME")
r, e = RAWFUNC(c)
return
})
res = rr.(VALUE)
{{else}}
prom.CacheMiss.Incr("NAME")
res, err = RAWFUNC(c)
{{end}}
if err != nil {
return
}
var miss = res
{{if .EnableNullCache}}
{{if .GoValue}}
if len(miss) == 0 {
{{else}}
if miss == {{.ZeroValue}} {
{{end}}
miss = {{.NullCache}}
}
{{end}}
if !addCache {
return
}
{{if .Sync}}
ADDCACHEFUNC(c, miss)
{{else}}
d.cache.Do(c, func(c context.Context) {
ADDCACHEFUNC(c, miss)
})
{{end}}
return
}
`

@ -0,0 +1,86 @@
package main
var _singleTemplate = `
// NAME {{or .Comment "get data from cache if miss will call source method, then add to cache."}}
func (d *Dao) NAME(c context.Context, {{.IDName}} KEY{{.ExtraArgsType}}) (res VALUE, err error) {
addCache := true
res, err = CACHEFUNC(c, {{.IDName}} {{.ExtraCacheArgs}})
if err != nil {
addCache = false
err = nil
}
{{if .EnableNullCache}}
defer func() {
{{if .SimpleValue}} if res == {{.NullCache}} { {{else}} if {{.CheckNullCode}} { {{end}}
res = {{.ZeroValue}}
}
}()
{{end}}
{{if .GoValue}}
if len(res) != 0 {
{{else}}
if res != {{.ZeroValue}} {
{{end}}
prom.CacheHit.Incr("NAME")
return
}
{{if .EnablePaging}}
var miss VALUE
{{end}}
{{if .EnableSingleFlight}}
var rr interface{}
sf := d.cacheSFNAME({{.IDName}} {{.ExtraArgs}})
rr, err, _ = cacheSingleFlights[SFNUM].Do(sf, func() (r interface{}, e error) {
prom.CacheMiss.Incr("NAME")
{{if .EnablePaging}}
var rrs [2]interface{}
rrs[0], rrs[1], e = RAWFUNC(c, {{.IDName}} {{.ExtraRawArgs}})
r = rrs
{{else}}
r, e = RAWFUNC(c, {{.IDName}} {{.ExtraRawArgs}})
{{end}}
return
})
{{if .EnablePaging}}
res = rr.([2]interface{})[0].(VALUE)
miss = rr.([2]interface{})[1].(VALUE)
{{else}}
res = rr.(VALUE)
{{end}}
{{else}}
prom.CacheMiss.Incr("NAME")
{{if .EnablePaging}}
res, miss, err = RAWFUNC(c, {{.IDName}} {{.ExtraRawArgs}})
{{else}}
res, err = RAWFUNC(c, {{.IDName}} {{.ExtraRawArgs}})
{{end}}
{{end}}
if err != nil {
return
}
{{if .EnablePaging}}
{{else}}
miss := res
{{end}}
{{if .EnableNullCache}}
{{if .GoValue}}
if len(miss) == 0 {
{{else}}
if miss == {{.ZeroValue}} {
{{end}}
miss = {{.NullCache}}
}
{{end}}
if !addCache {
return
}
{{if .Sync}}
ADDCACHEFUNC(c, {{.IDName}}, miss {{.ExtraAddCacheArgs}})
{{else}}
d.cache.Do(c, func(c context.Context) {
ADDCACHEFUNC(c, {{.IDName}}, miss {{.ExtraAddCacheArgs}})
})
{{end}}
return
}
`

@ -0,0 +1,202 @@
// Code generated by kratos tool btsgen. DO NOT EDIT.
/*
Package testdata is a generated cache proxy package.
It is generated from:
type _bts interface {
// bts: -batch=2 -max_group=20 -batch_err=break -nullcache=&Demo{ID:-1} -check_null_code=$.ID==-1
Demos(c context.Context, keys []int64) (map[int64]*Demo, error)
// bts: -sync=true -nullcache=&Demo{ID:-1} -check_null_code=$.ID==-1
Demo(c context.Context, key int64) (*Demo, error)
// bts: -paging=true
Demo1(c context.Context, key int64, pn int, ps int) (*Demo, error)
// bts: -nullcache=&Demo{ID:-1} -check_null_code=$.ID==-1
None(c context.Context) (*Demo, error)
}
*/
package testdata
import (
"context"
"sync"
"github.com/bilibili/kratos/pkg/stat/prom"
"github.com/bilibili/kratos/pkg/sync/errgroup"
)
var _ _bts
// Demos get data from cache if miss will call source method, then add to cache.
func (d *Dao) Demos(c context.Context, keys []int64) (res map[int64]*Demo, err error) {
if len(keys) == 0 {
return
}
addCache := true
if res, err = d.CacheDemos(c, keys); err != nil {
addCache = false
res = nil
err = nil
}
var miss []int64
for _, key := range keys {
if (res == nil) || (res[key] == nil) {
miss = append(miss, key)
}
}
prom.CacheHit.Add("Demos", int64(len(keys)-len(miss)))
for k, v := range res {
if v.ID == -1 {
delete(res, k)
}
}
missLen := len(miss)
if missLen == 0 {
return
}
missData := make(map[int64]*Demo, missLen)
prom.CacheMiss.Add("Demos", int64(missLen))
var mutex sync.Mutex
group := errgroup.WithCancel(c)
if missLen > 20 {
group.GOMAXPROCS(20)
}
var run = func(ms []int64) {
group.Go(func(ctx context.Context) (err error) {
data, err := d.RawDemos(ctx, ms)
mutex.Lock()
for k, v := range data {
missData[k] = v
}
mutex.Unlock()
return
})
}
var (
i int
n = missLen / 2
)
for i = 0; i < n; i++ {
run(miss[i*2 : (i+1)*2])
}
if len(miss[i*2:]) > 0 {
run(miss[i*2:])
}
err = group.Wait()
if res == nil {
res = make(map[int64]*Demo, len(keys))
}
for k, v := range missData {
res[k] = v
}
if err != nil {
return
}
for _, key := range miss {
if res[key] == nil {
missData[key] = &Demo{ID: -1}
}
}
if !addCache {
return
}
d.cache.Do(c, func(c context.Context) {
d.AddCacheDemos(c, missData)
})
return
}
// Demo get data from cache if miss will call source method, then add to cache.
func (d *Dao) Demo(c context.Context, key int64) (res *Demo, err error) {
addCache := true
res, err = d.CacheDemo(c, key)
if err != nil {
addCache = false
err = nil
}
defer func() {
if res.ID == -1 {
res = nil
}
}()
if res != nil {
prom.CacheHit.Incr("Demo")
return
}
prom.CacheMiss.Incr("Demo")
res, err = d.RawDemo(c, key)
if err != nil {
return
}
miss := res
if miss == nil {
miss = &Demo{ID: -1}
}
if !addCache {
return
}
d.AddCacheDemo(c, key, miss)
return
}
// Demo1 get data from cache if miss will call source method, then add to cache.
func (d *Dao) Demo1(c context.Context, key int64, pn int, ps int) (res *Demo, err error) {
addCache := true
res, err = d.CacheDemo1(c, key, pn, ps)
if err != nil {
addCache = false
err = nil
}
if res != nil {
prom.CacheHit.Incr("Demo1")
return
}
var miss *Demo
prom.CacheMiss.Incr("Demo1")
res, miss, err = d.RawDemo1(c, key, pn, ps)
if err != nil {
return
}
if !addCache {
return
}
d.cache.Do(c, func(c context.Context) {
d.AddCacheDemo1(c, key, miss, pn, ps)
})
return
}
// None get data from cache if miss will call source method, then add to cache.
func (d *Dao) None(c context.Context) (res *Demo, err error) {
addCache := true
res, err = d.CacheNone(c)
if err != nil {
addCache = false
err = nil
}
defer func() {
if res.ID == -1 {
res = nil
}
}()
if res != nil {
prom.CacheHit.Incr("None")
return
}
prom.CacheMiss.Incr("None")
res, err = d.RawNone(c)
if err != nil {
return
}
var miss = res
if miss == nil {
miss = &Demo{ID: -1}
}
if !addCache {
return
}
d.cache.Do(c, func(c context.Context) {
d.AddCacheNone(c, miss)
})
return
}

@ -0,0 +1,35 @@
package testdata
import (
"context"
"github.com/bilibili/kratos/pkg/sync/pipeline/fanout"
)
// Demo test struct
type Demo struct {
ID int64
Title string
}
// Dao .
type Dao struct {
cache *fanout.Fanout
}
// New .
func New() *Dao {
return &Dao{cache: fanout.New("cache")}
}
//go:generate kratos tool genbts
type _bts interface {
// bts: -batch=2 -max_group=20 -batch_err=break -nullcache=&Demo{ID:-1} -check_null_code=$.ID==-1
Demos(c context.Context, keys []int64) (map[int64]*Demo, error)
// bts: -sync=true -nullcache=&Demo{ID:-1} -check_null_code=$.ID==-1
Demo(c context.Context, key int64) (*Demo, error)
// bts: -paging=true
Demo1(c context.Context, key int64, pn int, ps int) (*Demo, error)
// bts: -nullcache=&Demo{ID:-1} -check_null_code=$.ID==-1
None(c context.Context) (*Demo, error)
}

@ -0,0 +1,30 @@
package testdata
import (
"context"
)
// mock test
var (
_multiCacheFunc func(c context.Context, keys []int64) (map[int64]*Demo, error)
_multiRawFunc func(c context.Context, keys []int64) (map[int64]*Demo, error)
_multiAddCacheFunc func(c context.Context, values map[int64]*Demo) error
)
// CacheDemos .
func (d *Dao) CacheDemos(c context.Context, keys []int64) (map[int64]*Demo, error) {
// get data from cache
return _multiCacheFunc(c, keys)
}
// RawDemos .
func (d *Dao) RawDemos(c context.Context, keys []int64) (map[int64]*Demo, error) {
// get data from db
return _multiRawFunc(c, keys)
}
// AddCacheDemos .
func (d *Dao) AddCacheDemos(c context.Context, values map[int64]*Demo) error {
// add to cache
return _multiAddCacheFunc(c, values)
}

@ -0,0 +1,67 @@
package testdata
import (
"context"
"errors"
"testing"
)
func TestMultiCache(t *testing.T) {
id := int64(1)
d := New()
meta := map[int64]*Demo{id: {ID: id}}
getsFromCache := func(c context.Context, keys []int64) (map[int64]*Demo, error) { return meta, nil }
notGetsFromCache := func(c context.Context, keys []int64) (map[int64]*Demo, error) { return nil, errors.New("err") }
// 缓存返回了部分数据
partFromCache := func(c context.Context, keys []int64) (map[int64]*Demo, error) { return meta, errors.New("err") }
getsFromSource := func(c context.Context, keys []int64) (map[int64]*Demo, error) { return meta, nil }
notGetsFromSource := func(c context.Context, keys []int64) (map[int64]*Demo, error) {
return meta, errors.New("err")
}
addToCache := func(c context.Context, values map[int64]*Demo) error { return nil }
// gets from cache
_multiCacheFunc = getsFromCache
_multiRawFunc = notGetsFromSource
_multiAddCacheFunc = addToCache
res, err := d.Demos(context.TODO(), []int64{id})
if err != nil {
t.Fatalf("err should be nil, get: %v", err)
}
if res[1].ID != 1 {
t.Fatalf("id should be 1")
}
// get from source
_multiCacheFunc = notGetsFromCache
_multiRawFunc = getsFromSource
res, err = d.Demos(context.TODO(), []int64{1, 2, 3, 4, 5, 6})
if err != nil {
t.Fatalf("err should be nil, get: %v", err)
}
if res[1].ID != 1 {
t.Fatalf("id should be 1")
}
// 缓存失败 返回部分数据 回源也失败的情况
_multiCacheFunc = partFromCache
_multiRawFunc = notGetsFromSource
res, err = d.Demos(context.TODO(), []int64{id})
if err == nil {
t.Fatalf("err should be nil, get: %v", err)
}
if res[1].ID != 1 {
t.Fatalf("id should be 1")
}
// with null cache
nullCache := &Demo{ID: -1}
getNullFromCache := func(c context.Context, keys []int64) (map[int64]*Demo, error) {
return map[int64]*Demo{id: nullCache}, nil
}
_multiCacheFunc = getNullFromCache
_multiRawFunc = notGetsFromSource
res, err = d.Demos(context.TODO(), []int64{id})
if err != nil {
t.Fatalf("err should be nil, get: %v", err)
}
if res[id] != nil {
t.Fatalf("res should be nil")
}
}

@ -0,0 +1,30 @@
package testdata
import (
"context"
)
// mock test
var (
_noneCacheFunc func(c context.Context) (*Demo, error)
_noneRawFunc func(c context.Context) (*Demo, error)
_noneAddCacheFunc func(c context.Context, value *Demo) error
)
// CacheNone .
func (d *Dao) CacheNone(c context.Context) (*Demo, error) {
// get data from cache
return _noneCacheFunc(c)
}
// RawNone .
func (d *Dao) RawNone(c context.Context) (*Demo, error) {
// get data from db
return _noneRawFunc(c)
}
// AddCacheNone .
func (d *Dao) AddCacheNone(c context.Context, value *Demo) error {
// add to cache
return _noneAddCacheFunc(c, value)
}

@ -0,0 +1,50 @@
package testdata
import (
"context"
"errors"
"testing"
)
func TestNoneCache(t *testing.T) {
d := New()
meta := &Demo{ID: 1}
getFromCache := func(c context.Context) (*Demo, error) { return meta, nil }
notGetFromCache := func(c context.Context) (*Demo, error) { return nil, errors.New("err") }
getFromSource := func(c context.Context) (*Demo, error) { return meta, nil }
notGetFromSource := func(c context.Context) (*Demo, error) { return meta, errors.New("err") }
addToCache := func(c context.Context, values *Demo) error { return nil }
// get from cache
_noneCacheFunc = getFromCache
_noneRawFunc = notGetFromSource
_noneAddCacheFunc = addToCache
res, err := d.None(context.TODO())
if err != nil {
t.Fatalf("err should be nil, get: %v", err)
}
if res.ID != 1 {
t.Fatalf("id should be 1")
}
// get from source
_noneCacheFunc = notGetFromCache
_noneRawFunc = getFromSource
res, err = d.None(context.TODO())
if err != nil {
t.Fatalf("err should be nil, get: %v", err)
}
if res.ID != 1 {
t.Fatalf("id should be 1")
}
// with null cache
nullCache := &Demo{ID: -1}
getNullFromCache := func(c context.Context) (*Demo, error) { return nullCache, nil }
_noneCacheFunc = getNullFromCache
_noneRawFunc = notGetFromSource
res, err = d.None(context.TODO())
if err != nil {
t.Fatalf("err should be nil, get: %v", err)
}
if res != nil {
t.Fatalf("res should be nil")
}
}

@ -0,0 +1,48 @@
package testdata
import (
"context"
)
// mock test
var (
_singleCacheFunc func(c context.Context, key int64) (*Demo, error)
_singleRawFunc func(c context.Context, key int64) (*Demo, error)
_singleAddCacheFunc func(c context.Context, key int64, value *Demo) error
)
// CacheDemo .
func (d *Dao) CacheDemo(c context.Context, key int64) (*Demo, error) {
// get data from cache
return _singleCacheFunc(c, key)
}
// RawDemo .
func (d *Dao) RawDemo(c context.Context, key int64) (*Demo, error) {
// get data from db
return _singleRawFunc(c, key)
}
// AddCacheDemo .
func (d *Dao) AddCacheDemo(c context.Context, key int64, value *Demo) error {
// add to cache
return _singleAddCacheFunc(c, key, value)
}
// CacheDemo1 .
func (d *Dao) CacheDemo1(c context.Context, key int64, pn, ps int) (*Demo, error) {
// get data from cache
return nil, nil
}
// RawDemo1 .
func (d *Dao) RawDemo1(c context.Context, key int64, pn, ps int) (*Demo, *Demo, error) {
// get data from db
return nil, nil, nil
}
// AddCacheDemo1 .
func (d *Dao) AddCacheDemo1(c context.Context, key int64, value *Demo, pn, ps int) error {
// add to cache
return nil
}

@ -0,0 +1,50 @@
package testdata
import (
"context"
"errors"
"testing"
)
func TestSingleCache(t *testing.T) {
d := New()
meta := &Demo{ID: 1}
getFromCache := func(c context.Context, id int64) (*Demo, error) { return meta, nil }
notGetFromCache := func(c context.Context, id int64) (*Demo, error) { return nil, errors.New("err") }
getFromSource := func(c context.Context, id int64) (*Demo, error) { return meta, nil }
notGetFromSource := func(c context.Context, id int64) (*Demo, error) { return meta, errors.New("err") }
addToCache := func(c context.Context, id int64, values *Demo) error { return nil }
// get from cache
_singleCacheFunc = getFromCache
_singleRawFunc = notGetFromSource
_singleAddCacheFunc = addToCache
res, err := d.Demo(context.TODO(), 1)
if err != nil {
t.Fatalf("err should be nil, get: %v", err)
}
if res.ID != 1 {
t.Fatalf("id should be 1")
}
// get from source
_singleCacheFunc = notGetFromCache
_singleRawFunc = getFromSource
res, err = d.Demo(context.TODO(), 1)
if err != nil {
t.Fatalf("err should be nil, get: %v", err)
}
if res.ID != 1 {
t.Fatalf("id should be 1")
}
// with null cache
nullCache := &Demo{ID: -1}
getNullFromCache := func(c context.Context, id int64) (*Demo, error) { return nullCache, nil }
_singleCacheFunc = getNullFromCache
_singleRawFunc = notGetFromSource
res, err = d.Demo(context.TODO(), 1)
if err != nil {
t.Fatalf("err should be nil, get: %v", err)
}
if res != nil {
t.Fatalf("res should be nil")
}
}

@ -0,0 +1,44 @@
#### genmc
> mc缓存代码生成
##### 项目简介
自动生成memcached缓存代码 和缓存回源工具kratos-gen-bts配合使用 体验更佳
支持以下功能:
- 常用mc命令(get/set/add/replace/delete)
- 多种数据存储格式(json/pb/raw/gob/gzip)
- 常用值类型自动转换(int/bool/float...)
- 自定义缓存名称和过期时间
- 记录pkg/error错误栈
- 记录日志trace id
- prometheus错误监控
- 自定义参数个数
- 自定义注释
##### 使用方式:
1. dao.go文件中新增 _mc interface
2. 在dao 文件夹中执行 go generate命令 将会生成相应的缓存代码
3. 示例见testdata/dao.go
##### 注意:
类型会根据前缀进行猜测
set / add 对应mc方法Set
replace 对应mc方法 Replace
del 对应mc方法 Delete
get / cache对应mc方法Get
mc Add方法需要用注解 -type=only_add单独指定
#### 注解参数:
| 名称 | 默认值 | 可用范围 | 说明 | 可选值 | 示例 |
| ----------- | ------------------- | ---------------- | ------------------------------------------------------------ | ---------------------------- | -------------------------- |
| encode | 根据值类型raw或json | set/add/replace | 数据存储的格式 | json/pb/raw/gob/gzip | json 或 json\|gzip 或gob等 |
| type | 前缀推断 | 全部 | mc方法 set/get/delete... | get/set/del/replace/only_add | get 或 replace 等 |
| key | 根据方法名称生成 | 全部 | 缓存key名称 | - | articleKey |
| expire | 根据方法名称生成 | 全部 | 缓存过期时间 | - | d.articleExpire |
| batch | | get(限多key模板) | 批量获取数据 每组大小 | - | 100 |
| max_group | | get(限多key模板) | 批量获取数据 最大组数量 | - | 10 |
| batch_err | break | get(限多key模板) | 批量获取数据回源错误的时候 降级继续请求(continue)还是直接返回(break) | break 或 continue | continue |
| struct_name | Dao | 全部 | 用户自定义Dao结构体名称 | | MemcacheDao |

@ -0,0 +1,30 @@
package main
var _headerTemplate = `
// Code generated by kratos tool mcgen. DO NOT EDIT.
NEWLINE
/*
Package {{.PkgName}} is a generated mc cache package.
It is generated from:
ARGS
*/
NEWLINE
package {{.PkgName}}
import (
"context"
"fmt"
{{if .UseStrConv}}"strconv"{{end}}
{{if .EnableBatch }}"sync"{{end}}
NEWLINE
"github.com/bilibili/kratos/pkg/stat/prom"
{{if .UseMemcached }}"github.com/bilibili/kratos/pkg/cache/memcache"{{end}}
{{if .EnableBatch }}"github.com/bilibili/kratos/pkg/sync/errgroup"{{end}}
"github.com/bilibili/kratos/pkg/log"
{{.ImportPackage}}
)
var _ _mc
`

@ -0,0 +1,548 @@
package main
import (
"bytes"
"flag"
"go/ast"
"io/ioutil"
"log"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"text/template"
"github.com/bilibili/kratos/tool/pkg"
)
var (
encode = flag.String("encode", "", "encode type: json/pb/raw/gob/gzip")
mcType = flag.String("type", "", "type: get/set/del/replace/only_add")
key = flag.String("key", "", "key name method")
expire = flag.String("expire", "", "expire time code")
structName = flag.String("struct_name", "Dao", "struct name")
batchSize = flag.Int("batch", 0, "batch size")
batchErr = flag.String("batch_err", "break", "batch err to contine or break")
maxGroup = flag.Int("max_group", 0, "max group size")
mcValidTypes = []string{"set", "replace", "del", "get", "only_add"}
mcValidPrefix = []string{"set", "replace", "del", "get", "cache", "add"}
optionNamesMap = map[string]bool{"batch": true, "max_group": true, "encode": true, "type": true, "key": true, "expire": true, "batch_err": true, "struct_name": true}
simpleTypes = []string{"int", "int8", "int16", "int32", "int64", "float32", "float64", "uint", "uint8", "uint16", "uint32", "uint64", "bool", "string", "[]byte"}
lenTypes = []string{"[]", "map"}
)
const (
_interfaceName = "_mc"
_multiTpl = 1
_singleTpl = 2
_noneTpl = 3
_typeGet = "get"
_typeSet = "set"
_typeDel = "del"
_typeReplace = "replace"
_typeAdd = "only_add"
)
func resetFlag() {
*encode = ""
*mcType = ""
*batchSize = 0
*maxGroup = 0
*batchErr = "break"
}
// options options
type options struct {
name string
keyType string
ValueType string
template int
SimpleValue bool
// int float 类型
GetSimpleValue bool
// string, []byte类型
GetDirectValue bool
ConvertValue2Bytes string
ConvertBytes2Value string
GoValue bool
ImportPackage string
importPackages []string
Args string
PkgName string
ExtraArgsType string
ExtraArgs string
MCType string
KeyMethod string
ExpireCode string
Encode string
UseMemcached bool
OriginValueType string
UseStrConv bool
Comment string
GroupSize int
MaxGroup int
EnableBatch bool
BatchErrBreak bool
LenType bool
PointType bool
StructName string
}
func getOptions(opt *options, comment string) {
os.Args = []string{os.Args[0]}
if regexp.MustCompile(`\s+//\s*mc:.+`).Match([]byte(comment)) {
args := strings.Split(pkg.RegexpReplace(`//\s*mc:(?P<arg>.+)`, comment, "$arg"), " ")
for _, arg := range args {
arg = strings.TrimSpace(arg)
if arg != "" {
// validate option name
argName := pkg.RegexpReplace(`-(?P<name>[\w_-]+)=.+`, arg, "$name")
if !optionNamesMap[argName] {
log.Fatalf("选项:%s 不存在 请检查拼写\n", argName)
}
os.Args = append(os.Args, arg)
}
}
}
resetFlag()
flag.Parse()
if *mcType != "" {
opt.MCType = *mcType
}
if *key != "" {
opt.KeyMethod = *key
}
if *expire != "" {
opt.ExpireCode = *expire
}
opt.EnableBatch = (*batchSize != 0) && (*maxGroup != 0)
opt.BatchErrBreak = *batchErr == "break"
opt.GroupSize = *batchSize
opt.MaxGroup = *maxGroup
opt.StructName = *structName
}
func getTypeFromPrefix(opt *options, params []*ast.Field, s *pkg.Source) {
if opt.MCType == "" {
for _, t := range mcValidPrefix {
if strings.HasPrefix(strings.ToLower(opt.name), t) {
if t == "add" {
t = _typeSet
}
opt.MCType = t
break
}
}
if opt.MCType == "" {
log.Fatalln(opt.name + "请指定方法类型(type=get/set/del...)")
}
}
if opt.MCType == "cache" {
opt.MCType = _typeGet
}
if len(params) == 0 {
log.Fatalln(opt.name + "参数不足")
}
for _, p := range params {
if len(p.Names) > 1 {
log.Fatalln(opt.name + "不支持省略类型 请写全声明中的字段类型名称")
}
}
if s.ExprString(params[0].Type) != "context.Context" {
log.Fatalln(opt.name + "第一个参数必须为context")
}
for _, param := range params {
if len(param.Names) > 1 {
log.Fatalln(opt.name + "不支持省略类型")
}
}
}
func processList(s *pkg.Source, list *ast.Field) (opt options) {
src := s.Src
fset := s.Fset
lines := strings.Split(src, "\n")
opt = options{Args: s.GetDef(_interfaceName), UseMemcached: true, importPackages: s.Packages(list)}
opt.name = list.Names[0].Name
opt.KeyMethod = "key" + opt.name
opt.ExpireCode = "d.mc" + opt.name + "Expire"
// get comment
line := fset.Position(list.Pos()).Line - 3
if len(lines)-1 >= line {
comment := lines[line]
opt.Comment = pkg.RegexpReplace(`\s+//(?P<name>.+)`, comment, "$name")
opt.Comment = strings.TrimSpace(opt.Comment)
}
// get options
line = fset.Position(list.Pos()).Line - 2
comment := lines[line]
getOptions(&opt, comment)
// get type from prefix
params := list.Type.(*ast.FuncType).Params.List
getTypeFromPrefix(&opt, params, s)
// get template
if len(params) == 1 {
opt.template = _noneTpl
} else if (len(params) == 2) && (opt.MCType == _typeSet || opt.MCType == _typeAdd || opt.MCType == _typeReplace) {
if _, ok := params[1].Type.(*ast.MapType); ok {
opt.template = _multiTpl
} else {
opt.template = _noneTpl
}
} else {
if _, ok := params[1].Type.(*ast.ArrayType); ok {
opt.template = _multiTpl
} else if _, ok := params[1].Type.(*ast.MapType); ok {
opt.template = _multiTpl
} else {
opt.template = _singleTpl
}
}
// extra args
if len(params) > 2 {
args := []string{""}
allArgs := []string{""}
var pos = 2
if (opt.MCType == _typeAdd) || (opt.MCType == _typeSet) || (opt.MCType == _typeReplace) {
pos = 3
}
if opt.template == _multiTpl && opt.MCType == _typeSet {
pos = 2
}
for _, pa := range params[pos:] {
paType := s.ExprString(pa.Type)
if len(pa.Names) == 0 {
args = append(args, paType)
allArgs = append(allArgs, paType)
continue
}
var names []string
for _, name := range pa.Names {
names = append(names, name.Name)
}
allArgs = append(allArgs, strings.Join(names, ",")+" "+paType)
args = append(args, strings.Join(names, ","))
}
if len(args) > 1 {
opt.ExtraArgs = strings.Join(args, ",")
opt.ExtraArgsType = strings.Join(allArgs, ",")
}
}
results := list.Type.(*ast.FuncType).Results.List
getKeyValueType(&opt, params, results, s)
return
}
func getKeyValueType(opt *options, params, results []*ast.Field, s *pkg.Source) {
// check
if s.ExprString(results[len(results)-1].Type) != "error" {
log.Fatalln("最后返回值参数需为error")
}
for _, res := range results {
if len(res.Names) > 1 {
log.Fatalln(opt.name + "返回值不支持省略类型")
}
}
if opt.MCType == _typeGet {
if len(results) != 2 {
log.Fatalln("参数个数不对")
}
}
// get key type and value type
if (opt.MCType == _typeAdd) || (opt.MCType == _typeSet) || (opt.MCType == _typeReplace) {
if opt.template == _multiTpl {
p, ok := params[1].Type.(*ast.MapType)
if !ok {
log.Fatalf("%s: 参数类型错误 批量设置数据时类型需为map类型\n", opt.name)
}
opt.keyType = s.ExprString(p.Key)
opt.ValueType = s.ExprString(p.Value)
} else if opt.template == _singleTpl {
opt.keyType = s.ExprString(params[1].Type)
opt.ValueType = s.ExprString(params[2].Type)
} else {
opt.ValueType = s.ExprString(params[1].Type)
}
}
if opt.MCType == _typeGet {
if opt.template == _multiTpl {
if p, ok := results[0].Type.(*ast.MapType); ok {
opt.keyType = s.ExprString(p.Key)
opt.ValueType = s.ExprString(p.Value)
} else {
log.Fatalf("%s: 返回值类型错误 批量获取数据时返回值需为map类型\n", opt.name)
}
} else if opt.template == _singleTpl {
opt.keyType = s.ExprString(params[1].Type)
opt.ValueType = s.ExprString(results[0].Type)
} else {
opt.ValueType = s.ExprString(results[0].Type)
}
}
if opt.MCType == _typeDel {
if opt.template == _multiTpl {
p, ok := params[1].Type.(*ast.ArrayType)
if !ok {
log.Fatalf("%s: 类型错误 参数需为[]类型\n", opt.name)
}
opt.keyType = s.ExprString(p.Elt)
} else if opt.template == _singleTpl {
opt.keyType = s.ExprString(params[1].Type)
}
}
for _, t := range simpleTypes {
if t == opt.ValueType {
opt.SimpleValue = true
opt.GetSimpleValue = true
opt.ConvertValue2Bytes = convertValue2Bytes(t)
opt.ConvertBytes2Value = convertBytes2Value(t)
break
}
}
if opt.ValueType == "string" {
opt.LenType = true
} else {
for _, t := range lenTypes {
if strings.HasPrefix(opt.ValueType, t) {
opt.LenType = true
break
}
}
}
if opt.SimpleValue && (opt.ValueType == "[]byte" || opt.ValueType == "string") {
opt.GetSimpleValue = false
opt.GetDirectValue = true
}
if opt.MCType == _typeGet && opt.template == _multiTpl {
opt.UseMemcached = false
}
if strings.HasPrefix(opt.ValueType, "*") {
opt.PointType = true
opt.OriginValueType = strings.Replace(opt.ValueType, "*", "", 1)
} else {
opt.OriginValueType = opt.ValueType
}
if *encode != "" {
var flags []string
for _, f := range strings.Split(*encode, "|") {
switch f {
case "gob":
flags = append(flags, "memcache.FlagGOB")
case "json":
flags = append(flags, "memcache.FlagJSON")
case "raw":
flags = append(flags, "memcache.FlagRAW")
case "pb":
flags = append(flags, "memcache.FlagProtobuf")
case "gzip":
flags = append(flags, "memcache.FlagGzip")
default:
log.Fatalf("%s: encode类型无效\n", opt.name)
}
}
opt.Encode = strings.Join(flags, " | ")
} else {
if opt.SimpleValue {
opt.Encode = "memcache.FlagRAW"
} else {
opt.Encode = "memcache.FlagJSON"
}
}
}
func parse(s *pkg.Source) (opts []*options) {
c := s.F.Scope.Lookup(_interfaceName)
if (c == nil) || (c.Kind != ast.Typ) {
log.Fatalln("无法找到缓存声明")
}
lists := c.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType).Methods.List
for _, list := range lists {
opt := processList(s, list)
opt.Check()
opts = append(opts, &opt)
}
return
}
func (option *options) Check() {
var valid bool
for _, x := range mcValidTypes {
if x == option.MCType {
valid = true
break
}
}
if !valid {
log.Fatalf("%s: 类型错误 不支持%s类型\n", option.name, option.MCType)
}
if (option.MCType != _typeDel) && !option.SimpleValue && !strings.Contains(option.ValueType, "*") && !strings.Contains(option.ValueType, "[]") && !strings.Contains(option.ValueType, "map") {
log.Fatalf("%s: 值类型只能为基本类型/slice/map/指针类型\n", option.name)
}
}
func genHeader(opts []*options) (src string) {
option := options{PkgName: os.Getenv("GOPACKAGE"), UseMemcached: false}
var packages []string
packagesMap := map[string]bool{`"context"`: true}
for _, opt := range opts {
if len(opt.importPackages) > 0 {
for _, pkg := range opt.importPackages {
if !packagesMap[pkg] {
packages = append(packages, pkg)
packagesMap[pkg] = true
}
}
}
if opt.Args != "" {
option.Args = opt.Args
}
if opt.UseMemcached {
option.UseMemcached = true
}
if opt.SimpleValue && !opt.GetDirectValue {
option.UseStrConv = true
}
if opt.EnableBatch {
option.EnableBatch = true
}
}
option.ImportPackage = strings.Join(packages, "\n")
src = _headerTemplate
t := template.Must(template.New("header").Parse(src))
var buffer bytes.Buffer
err := t.Execute(&buffer, option)
if err != nil {
log.Fatalf("execute template: %s", err)
}
// Format the output.
src = strings.Replace(buffer.String(), "\t", "", -1)
src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
src = strings.Replace(src, "NEWLINE", "", -1)
src = strings.Replace(src, "ARGS", option.Args, -1)
return
}
func getNewTemplate(option *options) (src string) {
if option.template == _multiTpl {
switch option.MCType {
case _typeGet:
src = _multiGetTemplate
case _typeSet:
src = _multiSetTemplate
case _typeReplace:
src = _multiReplaceTemplate
case _typeDel:
src = _multiDelTemplate
case _typeAdd:
src = _multiAddTemplate
}
} else if option.template == _singleTpl {
switch option.MCType {
case _typeGet:
src = _singleGetTemplate
case _typeSet:
src = _singleSetTemplate
case _typeReplace:
src = _singleReplaceTemplate
case _typeDel:
src = _singleDelTemplate
case _typeAdd:
src = _singleAddTemplate
}
} else {
switch option.MCType {
case _typeGet:
src = _noneGetTemplate
case _typeSet:
src = _noneSetTemplate
case _typeReplace:
src = _noneReplaceTemplate
case _typeDel:
src = _noneDelTemplate
case _typeAdd:
src = _noneAddTemplate
}
}
return
}
func genBody(opts []*options) (res string) {
for _, option := range opts {
src := getNewTemplate(option)
src = strings.Replace(src, "KEY", option.keyType, -1)
src = strings.Replace(src, "NAME", option.name, -1)
src = strings.Replace(src, "VALUE", option.ValueType, -1)
src = strings.Replace(src, "GROUPSIZE", strconv.Itoa(option.GroupSize), -1)
src = strings.Replace(src, "MAXGROUP", strconv.Itoa(option.MaxGroup), -1)
t := template.Must(template.New("cache").Parse(src))
var buffer bytes.Buffer
err := t.Execute(&buffer, option)
if err != nil {
log.Fatalf("execute template: %s", err)
}
// Format the output.
src = strings.Replace(buffer.String(), "\t", "", -1)
src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
res = res + "\n" + src
}
return
}
func main() {
log.SetFlags(0)
defer func() {
if err := recover(); err != nil {
log.Fatalf("程序解析失败, err: %+v", err)
}
}()
options := parse(pkg.NewSource(pkg.SourceText()))
header := genHeader(options)
body := genBody(options)
code := pkg.FormatCode(header + "\n" + body)
// Write to file.
dir := filepath.Dir(".")
outputName := filepath.Join(dir, "mc.cache.go")
err := ioutil.WriteFile(outputName, []byte(code), 0644)
if err != nil {
log.Fatalf("写入文件失败: %s", err)
}
log.Println("mc.cache.go: 生成成功")
}
func convertValue2Bytes(t string) string {
switch t {
case "int", "int8", "int16", "int32", "int64":
return "[]byte(strconv.FormatInt(int64(val), 10))"
case "uint", "uint8", "uint16", "uint32", "uint64":
return "[]byte(strconv.FormatUInt(val, 10))"
case "bool":
return "[]byte(strconv.FormatBool(val))"
case "float32":
return "[]byte(strconv.FormatFloat(val, 'E', -1, 32))"
case "float64":
return "[]byte(strconv.FormatFloat(val, 'E', -1, 64))"
case "string":
return "[]byte(val)"
case "[]byte":
return "val"
}
return ""
}
func convertBytes2Value(t string) string {
switch t {
case "int", "int8", "int16", "int32", "int64":
return "strconv.ParseInt(v, 10, 64)"
case "uint", "uint8", "uint16", "uint32", "uint64":
return "strconv.ParseUInt(v, 10, 64)"
case "bool":
return "strconv.ParseBool(v)"
case "float32":
return "float32(strconv.ParseFloat(v, 32))"
case "float64":
return "strconv.ParseFloat(v, 64)"
}
return ""
}

@ -0,0 +1,208 @@
package main
import (
"strings"
)
var _multiGetTemplate = `
// NAME {{or .Comment "get data from mc"}}
func (d *{{.StructName}}) NAME(c context.Context, ids []KEY {{.ExtraArgsType}}) (res map[KEY]VALUE, err error) {
l := len(ids)
if l == 0 {
return
}
{{if .EnableBatch}}
mutex := sync.Mutex{}
for i:=0;i < l; i+= GROUPSIZE * MAXGROUP {
var subKeys []KEY
{{if .BatchErrBreak}}
group, ctx := errgroup.WithContext(c)
{{else}}
group := &errgroup.Group{}
ctx := c
{{end}}
if (i + GROUPSIZE * MAXGROUP) > l {
subKeys = ids[i:]
} else {
subKeys = ids[i : i+GROUPSIZE * MAXGROUP]
}
subLen := len(subKeys)
for j:=0; j< subLen; j += GROUPSIZE {
var ks []KEY
if (j+GROUPSIZE) > subLen {
ks = subKeys[j:]
} else {
ks = subKeys[j:j+GROUPSIZE]
}
group.Go(func() (err error) {
keysMap := make(map[string]KEY, len(ks))
keys := make([]string, 0, len(ks))
for _, id := range ks {
key := {{.KeyMethod}}(id{{.ExtraArgs}})
keysMap[key] = id
keys = append(keys, key)
}
replies, err := d.mc.GetMulti(c, keys)
if err != nil {
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(ctx, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("keys", keys))
return
}
for _, key := range replies.Keys() {
{{if .GetSimpleValue}}
var v string
err = replies.Scan(key, &v)
{{else}}
var v VALUE
{{if .GetDirectValue}}
err = replies.Scan(key, &v)
{{else}}
{{if .PointType}}
v = &{{.OriginValueType}}{}
err = replies.Scan(key, v)
{{else}}
v = {{.OriginValueType}}{}
err = replies.Scan(key, &v)
{{end}}
{{end}}
{{end}}
if err != nil {
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(ctx, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
{{if .GetSimpleValue}}
r, err := {{.ConvertBytes2Value}}
if err != nil {
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(ctx, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return res, err
}
mutex.Lock()
if res == nil {
res = make(map[KEY]VALUE, len(keys))
}
res[keysMap[key]] = {{.ValueType}}(r)
mutex.Unlock()
{{else}}
mutex.Lock()
if res == nil {
res = make(map[KEY]VALUE, len(keys))
}
res[keysMap[key]] = v
mutex.Unlock()
{{end}}
}
return
})
}
err1 := group.Wait()
if err1 != nil {
err = err1
{{if .BatchErrBreak}}
break
{{end}}
}
}
{{else}}
keysMap := make(map[string]KEY, l)
keys := make([]string, 0, l)
for _, id := range ids {
key := {{.KeyMethod}}(id{{.ExtraArgs}})
keysMap[key] = id
keys = append(keys, key)
}
replies, err := d.mc.GetMulti(c, keys)
if err != nil {
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("keys", keys))
return
}
for _, key := range replies.Keys() {
{{if .GetSimpleValue}}
var v string
err = replies.Scan(key, &v)
{{else}}
{{if .PointType}}
v := &{{.OriginValueType}}{}
err = replies.Scan(key, v)
{{else}}
v := {{.OriginValueType}}{}
err = replies.Scan(key, &v)
{{end}}
{{end}}
if err != nil {
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
{{if .GetSimpleValue}}
r, err := {{.ConvertBytes2Value}}
if err != nil {
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return res, err
}
if res == nil {
res = make(map[KEY]VALUE, len(keys))
}
res[keysMap[key]] = {{.ValueType}}(r)
{{else}}
if res == nil {
res = make(map[KEY]VALUE, len(keys))
}
res[keysMap[key]] = v
{{end}}
}
{{end}}
return
}
`
var _multiSetTemplate = `
// NAME {{or .Comment "Set data to mc"}}
func (d *{{.StructName}}) NAME(c context.Context, values map[KEY]VALUE {{.ExtraArgsType}}) (err error) {
if len(values) == 0 {
return
}
for id, val := range values {
key := {{.KeyMethod}}(id{{.ExtraArgs}})
{{if .SimpleValue}}
bs := {{.ConvertValue2Bytes}}
item := &memcache.Item{Key: key, Value: bs, Expiration: {{.ExpireCode}}, Flags: {{.Encode}}}
{{else}}
item := &memcache.Item{Key: key, Object: val, Expiration: {{.ExpireCode}}, Flags: {{.Encode}}}
{{end}}
if err = d.mc.Set(c, item); err != nil {
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
}
return
}
`
var _multiAddTemplate = strings.Replace(_multiSetTemplate, "Set", "Add", -1)
var _multiReplaceTemplate = strings.Replace(_multiSetTemplate, "Set", "Replace", -1)
var _multiDelTemplate = `
// NAME {{or .Comment "delete data from mc"}}
func (d *{{.StructName}}) NAME(c context.Context, ids []KEY {{.ExtraArgsType}}) (err error) {
if len(ids) == 0 {
return
}
for _, id := range ids {
key := {{.KeyMethod}}(id{{.ExtraArgs}})
if err = d.mc.Delete(c, key); err != nil {
if err == memcache.ErrNotFound {
err = nil
continue
}
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
}
return
}
`

@ -0,0 +1,103 @@
package main
import (
"strings"
)
var _noneGetTemplate = `
// NAME {{or .Comment "get data from mc"}}
func (d *{{.StructName}}) NAME(c context.Context) (res VALUE, err error) {
key := {{.KeyMethod}}()
{{if .GetSimpleValue}}
var v string
err = d.mc.Get(c, key).Scan(&v)
{{else}}
{{if .GetDirectValue}}
err = d.mc.Get(c, key).Scan(&res)
{{else}}
{{if .PointType}}
res = &{{.OriginValueType}}{}
if err = d.mc.Get(c, key).Scan(res); err != nil {
res = nil
if err == memcache.ErrNotFound {
err = nil
return
}
}
{{else}}
err = d.mc.Get(c, key).Scan(&res)
{{end}}
{{end}}
{{end}}
if err != nil {
{{if .PointType}}
{{else}}
if err == memcache.ErrNotFound {
err = nil
return
}
{{end}}
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
{{if .GetSimpleValue}}
r, err := {{.ConvertBytes2Value}}
if err != nil {
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
res = {{.ValueType}}(r)
{{end}}
return
}
`
var _noneSetTemplate = `
// NAME {{or .Comment "Set data to mc"}}
func (d *{{.StructName}}) NAME(c context.Context, val VALUE) (err error) {
{{if .PointType}}
if val == nil {
return
}
{{end}}
{{if .LenType}}
if len(val) == 0 {
return
}
{{end}}
key := {{.KeyMethod}}()
{{if .SimpleValue}}
bs := {{.ConvertValue2Bytes}}
item := &memcache.Item{Key: key, Value: bs, Expiration: {{.ExpireCode}}, Flags: {{.Encode}}}
{{else}}
item := &memcache.Item{Key: key, Object: val, Expiration: {{.ExpireCode}}, Flags: {{.Encode}}}
{{end}}
if err = d.mc.Set(c, item); err != nil {
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
`
var _noneAddTemplate = strings.Replace(_noneSetTemplate, "Set", "Add", -1)
var _noneReplaceTemplate = strings.Replace(_noneSetTemplate, "Set", "Replace", -1)
var _noneDelTemplate = `
// NAME {{or .Comment "delete data from mc"}}
func (d *{{.StructName}}) NAME(c context.Context) (err error) {
key := {{.KeyMethod}}()
if err = d.mc.Delete(c, key); err != nil {
if err == memcache.ErrNotFound {
err = nil
return
}
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
`

@ -0,0 +1,102 @@
package main
import (
"strings"
)
var _singleGetTemplate = `
// NAME {{or .Comment "get data from mc"}}
func (d *{{.StructName}}) NAME(c context.Context, id KEY {{.ExtraArgsType}}) (res VALUE, err error) {
key := {{.KeyMethod}}(id{{.ExtraArgs}})
{{if .GetSimpleValue}}
var v string
err = d.mc.Get(c, key).Scan(&v)
{{else}}
{{if .GetDirectValue}}
err = d.mc.Get(c, key).Scan(&res)
{{else}}
{{if .PointType}}
res = &{{.OriginValueType}}{}
if err = d.mc.Get(c, key).Scan(res); err != nil {
res = nil
if err == memcache.ErrNotFound {
err = nil
}
}
{{else}}
err = d.mc.Get(c, key).Scan(&res)
{{end}}
{{end}}
{{end}}
if err != nil {
{{if .PointType}}
{{else}}
if err == memcache.ErrNotFound {
err = nil
return
}
{{end}}
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
{{if .GetSimpleValue}}
r, err := {{.ConvertBytes2Value}}
if err != nil {
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
res = {{.ValueType}}(r)
{{end}}
return
}
`
var _singleSetTemplate = `
// NAME {{or .Comment "Set data to mc"}}
func (d *{{.StructName}}) NAME(c context.Context, id KEY, val VALUE {{.ExtraArgsType}}) (err error) {
{{if .PointType}}
if val == nil {
return
}
{{end}}
{{if .LenType}}
if len(val) == 0 {
return
}
{{end}}
key := {{.KeyMethod}}(id{{.ExtraArgs}})
{{if .SimpleValue}}
bs := {{.ConvertValue2Bytes}}
item := &memcache.Item{Key: key, Value: bs, Expiration: {{.ExpireCode}}, Flags: {{.Encode}}}
{{else}}
item := &memcache.Item{Key: key, Object: val, Expiration: {{.ExpireCode}}, Flags: {{.Encode}}}
{{end}}
if err = d.mc.Set(c, item); err != nil {
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
`
var _singleAddTemplate = strings.Replace(_singleSetTemplate, "Set", "Add", -1)
var _singleReplaceTemplate = strings.Replace(_singleSetTemplate, "Set", "Replace", -1)
var _singleDelTemplate = `
// NAME {{or .Comment "delete data from mc"}}
func (d *{{.StructName}}) NAME(c context.Context, id KEY {{.ExtraArgsType}}) (err error) {
key := {{.KeyMethod}}(id{{.ExtraArgs}})
if err = d.mc.Delete(c, key); err != nil {
if err == memcache.ErrNotFound {
err = nil
return
}
prom.BusinessErrCount.Incr("mc:NAME")
log.Errorv(c, log.KV("NAME", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
`

@ -0,0 +1,93 @@
package testdata
import (
"context"
"fmt"
"time"
"github.com/bilibili/kratos/pkg/cache/memcache"
"github.com/bilibili/kratos/pkg/container/pool"
xtime "github.com/bilibili/kratos/pkg/time"
)
// Dao .
type Dao struct {
mc *memcache.Memcache
demoExpire int32
}
// New new dao
func New() (d *Dao) {
cfg := &memcache.Config{
Config: &pool.Config{
Active: 10,
Idle: 5,
IdleTimeout: xtime.Duration(time.Second),
},
Name: "test",
Proto: "tcp",
// Addr: "172.16.33.54:11214",
Addr: "127.0.0.1:11211",
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}
d = &Dao{
mc: memcache.New(cfg),
demoExpire: int32(5),
}
return
}
//go:generate kratos tool genmc
type _mc interface {
// mc: -key=demoKey
CacheDemos(c context.Context, keys []int64) (map[int64]*Demo, error)
// mc: -key=demoKey
CacheDemo(c context.Context, key int64) (*Demo, error)
// mc: -key=keyMid
CacheDemo1(c context.Context, key int64, mid int64) (*Demo, error)
// mc: -key=noneKey
CacheNone(c context.Context) (*Demo, error)
// mc: -key=demoKey
CacheString(c context.Context, key int64) (string, error)
// mc: -key=demoKey -expire=d.demoExpire -encode=json
AddCacheDemos(c context.Context, values map[int64]*Demo) error
// mc: -key=demo2Key -expire=d.demoExpire -encode=json
AddCacheDemos2(c context.Context, values map[int64]*Demo, tp int64) error
// 这里也支持自定义注释 会替换默认的注释
// mc: -key=demoKey -expire=d.demoExpire -encode=json|gzip
AddCacheDemo(c context.Context, key int64, value *Demo) error
// mc: -key=keyMid -expire=d.demoExpire -encode=gob
AddCacheDemo1(c context.Context, key int64, value *Demo, mid int64) error
// mc: -key=noneKey
AddCacheNone(c context.Context, value *Demo) error
// mc: -key=demoKey -expire=d.demoExpire
AddCacheString(c context.Context, key int64, value string) error
// mc: -key=demoKey
DelCacheDemos(c context.Context, keys []int64) error
// mc: -key=demoKey
DelCacheDemo(c context.Context, key int64) error
// mc: -key=keyMid
DelCacheDemo1(c context.Context, key int64, mid int64) error
// mc: -key=noneKey
DelCacheNone(c context.Context) error
}
func demoKey(id int64) string {
return fmt.Sprintf("art_%d", id)
}
func demo2Key(id, tp int64) string {
return fmt.Sprintf("art_%d_%d", id, tp)
}
func keyMid(id, mid int64) string {
return fmt.Sprintf("art_%d_%d", id, mid)
}
func noneKey() string {
return "none"
}

@ -0,0 +1,116 @@
package testdata
import (
"context"
"testing"
)
func TestDemo(t *testing.T) {
d := New()
c := context.TODO()
art := &Demo{ID: 1, Title: "title"}
err := d.AddCacheDemo(c, art.ID, art)
if err != nil {
t.Errorf("err should be nil, get: %v", err)
t.FailNow()
}
art1, err := d.CacheDemo(c, art.ID)
if err != nil {
t.Errorf("err should be nil, get: %v", err)
t.FailNow()
}
if (art1.ID != art.ID) || (art.Title != art1.Title) {
t.Error("art not equal")
t.FailNow()
}
err = d.DelCacheDemo(c, art.ID)
if err != nil {
t.Errorf("err should be nil, get: %v", err)
t.FailNow()
}
art1, err = d.CacheDemo(c, art.ID)
if (art1 != nil) || (err != nil) {
t.Errorf("art %v, err: %v", art1, err)
t.FailNow()
}
}
func TestNone(t *testing.T) {
d := New()
c := context.TODO()
art := &Demo{ID: 1, Title: "title"}
err := d.AddCacheNone(c, art)
if err != nil {
t.Errorf("err should be nil, get: %v", err)
t.FailNow()
}
art1, err := d.CacheNone(c)
if err != nil {
t.Errorf("err should be nil, get: %v", err)
t.FailNow()
}
if (art1.ID != art.ID) || (art.Title != art1.Title) {
t.Error("art not equal")
t.FailNow()
}
err = d.DelCacheNone(c)
if err != nil {
t.Errorf("err should be nil, get: %v", err)
t.FailNow()
}
art1, err = d.CacheNone(c)
if (art1 != nil) || (err != nil) {
t.Errorf("art %v, err: %v", art1, err)
t.FailNow()
}
}
func TestDemos(t *testing.T) {
d := New()
c := context.TODO()
art1 := &Demo{ID: 1, Title: "title"}
art2 := &Demo{ID: 2, Title: "title"}
err := d.AddCacheDemos(c, map[int64]*Demo{1: art1, 2: art2})
if err != nil {
t.Errorf("err should be nil, get: %v", err)
t.FailNow()
}
arts, err := d.CacheDemos(c, []int64{art1.ID, art2.ID})
if err != nil {
t.Errorf("err should be nil, get: %v", err)
t.FailNow()
}
if (arts[1].Title != art1.Title) || (arts[2].Title != art2.Title) {
t.Error("art not equal")
t.FailNow()
}
err = d.DelCacheDemos(c, []int64{art1.ID, art2.ID})
if err != nil {
t.Errorf("err should be nil, get: %v", err)
t.FailNow()
}
arts, err = d.CacheDemos(c, []int64{art1.ID, art2.ID})
if (arts != nil) || (err != nil) {
t.Errorf("art %v, err: %v", art1, err)
t.FailNow()
}
}
func TestString(t *testing.T) {
d := New()
c := context.TODO()
err := d.AddCacheString(c, 1, "abc")
if err != nil {
t.Errorf("err should be nil, get: %v", err)
t.FailNow()
}
res, err := d.CacheString(c, 1)
if err != nil {
t.Errorf("err should be nil, get: %v", err)
t.FailNow()
}
if res != "abc" {
t.Error("res wrong")
t.FailNow()
}
}

@ -0,0 +1,320 @@
// Code generated by kratos tool mcgen. DO NOT EDIT.
/*
Package testdata is a generated mc cache package.
It is generated from:
type _mc interface {
// mc: -key=demoKey
CacheDemos(c context.Context, keys []int64) (map[int64]*Demo, error)
// mc: -key=demoKey
CacheDemo(c context.Context, key int64) (*Demo, error)
// mc: -key=keyMid
CacheDemo1(c context.Context, key int64, mid int64) (*Demo, error)
// mc: -key=noneKey
CacheNone(c context.Context) (*Demo, error)
// mc: -key=demoKey
CacheString(c context.Context, key int64) (string, error)
// mc: -key=demoKey -expire=d.demoExpire -encode=json
AddCacheDemos(c context.Context, values map[int64]*Demo) error
// mc: -key=demo2Key -expire=d.demoExpire -encode=json
AddCacheDemos2(c context.Context, values map[int64]*Demo, tp int64) error
// 这里也支持自定义注释 会替换默认的注释
// mc: -key=demoKey -expire=d.demoExpire -encode=json|gzip
AddCacheDemo(c context.Context, key int64, value *Demo) error
// mc: -key=keyMid -expire=d.demoExpire -encode=gob
AddCacheDemo1(c context.Context, key int64, value *Demo, mid int64) error
// mc: -key=noneKey
AddCacheNone(c context.Context, value *Demo) error
// mc: -key=demoKey -expire=d.demoExpire
AddCacheString(c context.Context, key int64, value string) error
// mc: -key=demoKey
DelCacheDemos(c context.Context, keys []int64) error
// mc: -key=demoKey
DelCacheDemo(c context.Context, key int64) error
// mc: -key=keyMid
DelCacheDemo1(c context.Context, key int64, mid int64) error
// mc: -key=noneKey
DelCacheNone(c context.Context) error
}
*/
package testdata
import (
"context"
"fmt"
"github.com/bilibili/kratos/pkg/cache/memcache"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/stat/prom"
)
var _ _mc
// CacheDemos get data from mc
func (d *Dao) CacheDemos(c context.Context, ids []int64) (res map[int64]*Demo, err error) {
l := len(ids)
if l == 0 {
return
}
keysMap := make(map[string]int64, l)
keys := make([]string, 0, l)
for _, id := range ids {
key := demoKey(id)
keysMap[key] = id
keys = append(keys, key)
}
replies, err := d.mc.GetMulti(c, keys)
if err != nil {
prom.BusinessErrCount.Incr("mc:CacheDemos")
log.Errorv(c, log.KV("CacheDemos", fmt.Sprintf("%+v", err)), log.KV("keys", keys))
return
}
for _, key := range replies.Keys() {
v := &Demo{}
err = replies.Scan(key, v)
if err != nil {
prom.BusinessErrCount.Incr("mc:CacheDemos")
log.Errorv(c, log.KV("CacheDemos", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
if res == nil {
res = make(map[int64]*Demo, len(keys))
}
res[keysMap[key]] = v
}
return
}
// CacheDemo get data from mc
func (d *Dao) CacheDemo(c context.Context, id int64) (res *Demo, err error) {
key := demoKey(id)
res = &Demo{}
if err = d.mc.Get(c, key).Scan(res); err != nil {
res = nil
if err == memcache.ErrNotFound {
err = nil
}
}
if err != nil {
prom.BusinessErrCount.Incr("mc:CacheDemo")
log.Errorv(c, log.KV("CacheDemo", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
// CacheDemo1 get data from mc
func (d *Dao) CacheDemo1(c context.Context, id int64, mid int64) (res *Demo, err error) {
key := keyMid(id, mid)
res = &Demo{}
if err = d.mc.Get(c, key).Scan(res); err != nil {
res = nil
if err == memcache.ErrNotFound {
err = nil
}
}
if err != nil {
prom.BusinessErrCount.Incr("mc:CacheDemo1")
log.Errorv(c, log.KV("CacheDemo1", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
// CacheNone get data from mc
func (d *Dao) CacheNone(c context.Context) (res *Demo, err error) {
key := noneKey()
res = &Demo{}
if err = d.mc.Get(c, key).Scan(res); err != nil {
res = nil
if err == memcache.ErrNotFound {
err = nil
return
}
}
if err != nil {
prom.BusinessErrCount.Incr("mc:CacheNone")
log.Errorv(c, log.KV("CacheNone", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
// CacheString get data from mc
func (d *Dao) CacheString(c context.Context, id int64) (res string, err error) {
key := demoKey(id)
err = d.mc.Get(c, key).Scan(&res)
if err != nil {
if err == memcache.ErrNotFound {
err = nil
return
}
prom.BusinessErrCount.Incr("mc:CacheString")
log.Errorv(c, log.KV("CacheString", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
// AddCacheDemos Set data to mc
func (d *Dao) AddCacheDemos(c context.Context, values map[int64]*Demo) (err error) {
if len(values) == 0 {
return
}
for id, val := range values {
key := demoKey(id)
item := &memcache.Item{Key: key, Object: val, Expiration: d.demoExpire, Flags: memcache.FlagJSON}
if err = d.mc.Set(c, item); err != nil {
prom.BusinessErrCount.Incr("mc:AddCacheDemos")
log.Errorv(c, log.KV("AddCacheDemos", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
}
return
}
// AddCacheDemos2 Set data to mc
func (d *Dao) AddCacheDemos2(c context.Context, values map[int64]*Demo, tp int64) (err error) {
if len(values) == 0 {
return
}
for id, val := range values {
key := demo2Key(id, tp)
item := &memcache.Item{Key: key, Object: val, Expiration: d.demoExpire, Flags: memcache.FlagJSON}
if err = d.mc.Set(c, item); err != nil {
prom.BusinessErrCount.Incr("mc:AddCacheDemos2")
log.Errorv(c, log.KV("AddCacheDemos2", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
}
return
}
// AddCacheDemo 这里也支持自定义注释 会替换默认的注释
func (d *Dao) AddCacheDemo(c context.Context, id int64, val *Demo) (err error) {
if val == nil {
return
}
key := demoKey(id)
item := &memcache.Item{Key: key, Object: val, Expiration: d.demoExpire, Flags: memcache.FlagJSON | memcache.FlagGzip}
if err = d.mc.Set(c, item); err != nil {
prom.BusinessErrCount.Incr("mc:AddCacheDemo")
log.Errorv(c, log.KV("AddCacheDemo", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
// AddCacheDemo1 Set data to mc
func (d *Dao) AddCacheDemo1(c context.Context, id int64, val *Demo, mid int64) (err error) {
if val == nil {
return
}
key := keyMid(id, mid)
item := &memcache.Item{Key: key, Object: val, Expiration: d.demoExpire, Flags: memcache.FlagGOB}
if err = d.mc.Set(c, item); err != nil {
prom.BusinessErrCount.Incr("mc:AddCacheDemo1")
log.Errorv(c, log.KV("AddCacheDemo1", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
// AddCacheNone Set data to mc
func (d *Dao) AddCacheNone(c context.Context, val *Demo) (err error) {
if val == nil {
return
}
key := noneKey()
item := &memcache.Item{Key: key, Object: val, Expiration: d.demoExpire, Flags: memcache.FlagJSON}
if err = d.mc.Set(c, item); err != nil {
prom.BusinessErrCount.Incr("mc:AddCacheNone")
log.Errorv(c, log.KV("AddCacheNone", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
// AddCacheString Set data to mc
func (d *Dao) AddCacheString(c context.Context, id int64, val string) (err error) {
if len(val) == 0 {
return
}
key := demoKey(id)
bs := []byte(val)
item := &memcache.Item{Key: key, Value: bs, Expiration: d.demoExpire, Flags: memcache.FlagRAW}
if err = d.mc.Set(c, item); err != nil {
prom.BusinessErrCount.Incr("mc:AddCacheString")
log.Errorv(c, log.KV("AddCacheString", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
// DelCacheDemos delete data from mc
func (d *Dao) DelCacheDemos(c context.Context, ids []int64) (err error) {
if len(ids) == 0 {
return
}
for _, id := range ids {
key := demoKey(id)
if err = d.mc.Delete(c, key); err != nil {
if err == memcache.ErrNotFound {
err = nil
continue
}
prom.BusinessErrCount.Incr("mc:DelCacheDemos")
log.Errorv(c, log.KV("DelCacheDemos", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
}
return
}
// DelCacheDemo delete data from mc
func (d *Dao) DelCacheDemo(c context.Context, id int64) (err error) {
key := demoKey(id)
if err = d.mc.Delete(c, key); err != nil {
if err == memcache.ErrNotFound {
err = nil
return
}
prom.BusinessErrCount.Incr("mc:DelCacheDemo")
log.Errorv(c, log.KV("DelCacheDemo", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
// DelCacheDemo1 delete data from mc
func (d *Dao) DelCacheDemo1(c context.Context, id int64, mid int64) (err error) {
key := keyMid(id, mid)
if err = d.mc.Delete(c, key); err != nil {
if err == memcache.ErrNotFound {
err = nil
return
}
prom.BusinessErrCount.Incr("mc:DelCacheDemo1")
log.Errorv(c, log.KV("DelCacheDemo1", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
// DelCacheNone delete data from mc
func (d *Dao) DelCacheNone(c context.Context) (err error) {
key := noneKey()
if err = d.mc.Delete(c, key); err != nil {
if err == memcache.ErrNotFound {
err = nil
return
}
prom.BusinessErrCount.Incr("mc:DelCacheNone")
log.Errorv(c, log.KV("DelCacheNone", fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}

@ -0,0 +1,328 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: model.proto
/*
Package model is a generated protocol buffer package.
It is generated from these files:
model.proto
It has these top-level messages:
Demo
*/
package testdata
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import _ "github.com/gogo/protobuf/gogoproto"
import io "io"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type Demo struct {
ID int64 `protobuf:"varint,1,opt,name=ID,proto3" json:"id"`
Title string `protobuf:"bytes,3,opt,name=Title,proto3" json:"title"`
}
func (m *Demo) Reset() { *m = Demo{} }
func (m *Demo) String() string { return proto.CompactTextString(m) }
func (*Demo) ProtoMessage() {}
func (*Demo) Descriptor() ([]byte, []int) { return fileDescriptorModel, []int{0} }
func init() {
proto.RegisterType((*Demo)(nil), "model.Demo")
}
func (m *Demo) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalTo(dAtA)
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Demo) MarshalTo(dAtA []byte) (int, error) {
var i int
_ = i
var l int
_ = l
if m.ID != 0 {
dAtA[i] = 0x8
i++
i = encodeVarintModel(dAtA, i, uint64(m.ID))
}
if len(m.Title) > 0 {
dAtA[i] = 0x1a
i++
i = encodeVarintModel(dAtA, i, uint64(len(m.Title)))
i += copy(dAtA[i:], m.Title)
}
return i, nil
}
func encodeVarintModel(dAtA []byte, offset int, v uint64) int {
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return offset + 1
}
func (m *Demo) Size() (n int) {
var l int
_ = l
if m.ID != 0 {
n += 1 + sovModel(uint64(m.ID))
}
l = len(m.Title)
if l > 0 {
n += 1 + l + sovModel(uint64(l))
}
return n
}
func sovModel(x uint64) (n int) {
for {
n++
x >>= 7
if x == 0 {
break
}
}
return n
}
func sozModel(x uint64) (n int) {
return sovModel(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *Demo) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowModel
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Demo: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Demo: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field ID", wireType)
}
m.ID = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowModel
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.ID |= (int64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
case 3:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Title", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowModel
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthModel
}
postIndex := iNdEx + intStringLen
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Title = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipModel(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthModel
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipModel(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowModel
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowModel
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
return iNdEx, nil
case 1:
iNdEx += 8
return iNdEx, nil
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowModel
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
iNdEx += length
if length < 0 {
return 0, ErrInvalidLengthModel
}
return iNdEx, nil
case 3:
for {
var innerWire uint64
var start int = iNdEx
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowModel
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
innerWire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
innerWireType := int(innerWire & 0x7)
if innerWireType == 4 {
break
}
next, err := skipModel(dAtA[start:])
if err != nil {
return 0, err
}
iNdEx = start + next
}
return iNdEx, nil
case 4:
return iNdEx, nil
case 5:
iNdEx += 4
return iNdEx, nil
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
}
panic("unreachable")
}
var (
ErrInvalidLengthModel = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowModel = fmt.Errorf("proto: integer overflow")
)
func init() { proto.RegisterFile("model.proto", fileDescriptorModel) }
var fileDescriptorModel = []byte{
// 166 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xce, 0xcd, 0x4f, 0x49,
0xcd, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x73, 0xa4, 0x74, 0xd3, 0x33, 0x4b,
0x32, 0x4a, 0x93, 0xf4, 0x92, 0xf3, 0x73, 0xf5, 0xd3, 0xf3, 0xd3, 0xf3, 0xf5, 0xc1, 0xb2, 0x49,
0xa5, 0x69, 0x60, 0x1e, 0x98, 0x03, 0x66, 0x41, 0x74, 0x29, 0x39, 0x71, 0xb1, 0x3b, 0x16, 0x95,
0x64, 0x26, 0xe7, 0xa4, 0x0a, 0x89, 0x71, 0x31, 0x79, 0xba, 0x48, 0x30, 0x2a, 0x30, 0x6a, 0x30,
0x3b, 0xb1, 0xbd, 0xba, 0x27, 0xcf, 0x94, 0x99, 0x12, 0xc4, 0xe4, 0xe9, 0x22, 0x24, 0xcf, 0xc5,
0x1a, 0x92, 0x59, 0x92, 0x93, 0x2a, 0xc1, 0xac, 0xc0, 0xa8, 0xc1, 0xe9, 0xc4, 0xf9, 0xea, 0x9e,
0x3c, 0x6b, 0x09, 0x48, 0x20, 0x08, 0x22, 0xee, 0x24, 0x71, 0xe2, 0xa1, 0x1c, 0xc3, 0x85, 0x87,
0x72, 0x0c, 0x27, 0x1e, 0xc9, 0x31, 0x5e, 0x78, 0x24, 0xc7, 0xf8, 0xe0, 0x91, 0x1c, 0xe3, 0x8c,
0xc7, 0x72, 0x0c, 0x49, 0x6c, 0x60, 0x4b, 0x8c, 0x01, 0x01, 0x00, 0x00, 0xff, 0xff, 0x11, 0xa6,
0xfa, 0x1c, 0xa9, 0x00, 0x00, 0x00,
}

@ -0,0 +1,14 @@
syntax = "proto3";
package testdata;
import "github.com/gogo/protobuf/gogoproto/gogo.proto";
option (gogoproto.goproto_enum_prefix_all) = false;
option (gogoproto.goproto_getters_all) = false;
option (gogoproto.unmarshaler_all) = true;
option (gogoproto.marshaler_all) = true;
option (gogoproto.sizer_all) = true;
message Demo {
int64 ID = 1 [(gogoproto.jsontag) = "id"];
string Title = 3 [(gogoproto.jsontag) = "title"];
}

@ -30,4 +30,22 @@ var toolIndexs = []*Tool{
Platform: []string{"darwin", "linux", "windows"}, Platform: []string{"darwin", "linux", "windows"},
Author: "goswagger.io", Author: "goswagger.io",
}, },
&Tool{
Name: "genbts",
Alias: "kratos-gen-bts",
BuildTime: time.Date(2019, 5, 5, 0, 0, 0, 0, time.Local),
Install: "go get -u github.com/bilibili/kratos/tool/kratos-gen-bts",
Summary: "缓存回源逻辑代码生成器",
Platform: []string{"darwin", "linux", "windows"},
Author: "kratos",
},
&Tool{
Name: "genmc",
Alias: "kratos-gen-mc",
BuildTime: time.Date(2019, 5, 5, 0, 0, 0, 0, time.Local),
Install: "go get -u github.com/bilibili/kratos/tool/kratos-gen-mc",
Summary: "mc缓存代码生成",
Platform: []string{"darwin", "linux", "windows"},
Author: "kratos",
},
} }

@ -0,0 +1,149 @@
package pkg
import (
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io/ioutil"
"log"
"os"
"regexp"
"strings"
)
// Source source
type Source struct {
Fset *token.FileSet
Src string
F *ast.File
}
// NewSource new source
func NewSource(src string) *Source {
s := &Source{
Fset: token.NewFileSet(),
Src: src,
}
f, err := parser.ParseFile(s.Fset, "", src, 0)
if err != nil {
log.Fatal("无法解析源文件")
}
s.F = f
return s
}
// ExprString expr string
func (s *Source) ExprString(typ ast.Expr) string {
fset := s.Fset
s1 := fset.Position(typ.Pos()).Offset
s2 := fset.Position(typ.End()).Offset
return s.Src[s1:s2]
}
// pkgPath package path
func (s *Source) pkgPath(name string) (res string) {
for _, im := range s.F.Imports {
if im.Name != nil && im.Name.Name == name {
return im.Path.Value
}
}
for _, im := range s.F.Imports {
if strings.HasSuffix(im.Path.Value, name+"\"") {
return im.Path.Value
}
}
return
}
// GetDef get define code
func (s *Source) GetDef(name string) string {
c := s.F.Scope.Lookup(name).Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType)
s1 := s.Fset.Position(c.Pos()).Offset
s2 := s.Fset.Position(c.End()).Offset
line := s.Fset.Position(c.Pos()).Line
lines := []string{strings.Split(s.Src, "\n")[line-1]}
for _, l := range strings.Split(s.Src[s1:s2], "\n")[1:] {
lines = append(lines, "\t"+l)
}
return strings.Join(lines, "\n")
}
// RegexpReplace replace regexp
func RegexpReplace(reg, src, temp string) string {
result := []byte{}
pattern := regexp.MustCompile(reg)
for _, submatches := range pattern.FindAllStringSubmatchIndex(src, -1) {
result = pattern.ExpandString(result, temp, src, submatches)
}
return string(result)
}
// formatPackage format package
func formatPackage(name, path string) (res string) {
if path != "" {
if strings.HasSuffix(path, name+"\"") {
res = path
return
}
res = fmt.Sprintf("%s %s", name, path)
}
return
}
// SourceText get source file text
func SourceText() string {
file := os.Getenv("GOFILE")
data, err := ioutil.ReadFile(file)
if err != nil {
log.Fatal("请使用go generate执行", file)
}
return string(data)
}
// FormatCode format code
func FormatCode(source string) string {
src, err := format.Source([]byte(source))
if err != nil {
// Should never happen, but can arise when developing this code.
// The user can compile the output to see the error.
log.Printf("warning: 输出文件不合法: %s", err)
log.Printf("warning: 详细错误请编译查看")
return source
}
return string(src)
}
// Packages get import packages
func (s *Source) Packages(f *ast.Field) (res []string) {
fs := f.Type.(*ast.FuncType).Params.List
fs = append(fs, f.Type.(*ast.FuncType).Results.List...)
var types []string
resMap := make(map[string]bool)
for _, field := range fs {
if p, ok := field.Type.(*ast.MapType); ok {
types = append(types, s.ExprString(p.Key))
types = append(types, s.ExprString(p.Value))
} else if p, ok := field.Type.(*ast.ArrayType); ok {
types = append(types, s.ExprString(p.Elt))
} else {
types = append(types, s.ExprString(field.Type))
}
}
for _, t := range types {
name := RegexpReplace(`(?P<pkg>\w+)\.\w+`, t, "$pkg")
if name == "" {
continue
}
pkg := formatPackage(name, s.pkgPath(name))
if !resMap[pkg] {
resMap[pkg] = true
}
}
for pkg := range resMap {
res = append(res, pkg)
}
return
}
Loading…
Cancel
Save