You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
kratos/tool/kratos-gen-bts/main.go

483 lines
14 KiB

package main
import (
"bytes"
"flag"
"fmt"
"go/ast"
"io/ioutil"
"log"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"text/template"
"github.com/bilibili/kratos/tool/pkg"
)
var (
// arguments
singleFlight = flag.Bool("singleflight", false, "enable singleflight")
nullCache = flag.String("nullcache", "", "null cache")
checkNullCode = flag.String("check_null_code", "", "check null code")
batchSize = flag.Int("batch", 0, "batch size")
batchErr = flag.String("batch_err", "break", "batch err to contine or break")
maxGroup = flag.Int("max_group", 0, "max group size")
sync = flag.Bool("sync", false, "add cache in sync way.")
paging = flag.Bool("paging", false, "use paging in single template")
ignores = flag.String("ignores", "", "ignore params")
customMethod = flag.String("custom_method", "", "自定义方法名 |分隔: 缓存|回源|增加缓存")
numberTypes = []string{"int", "int8", "int16", "int32", "int64", "float32", "float64", "uint", "uint8", "uint16", "uint32", "uint64"}
simpleTypes = []string{"int", "int8", "int16", "int32", "int64", "float32", "float64", "uint", "uint8", "uint16", "uint32", "uint64", "bool", "string", "[]byte"}
optionNames = []string{"singleflight", "nullcache", "check_null_code", "batch", "max_group", "sync", "paging", "ignores", "batch_err", "custom_method"}
optionNamesMap = map[string]bool{}
)
const (
_interfaceName = "_bts"
_multiTpl = 1
_singleTpl = 2
_noneTpl = 3
)
func resetFlag() {
*singleFlight = false
*nullCache = ""
*checkNullCode = ""
*batchSize = 0
*maxGroup = 0
*sync = false
*paging = false
*batchErr = "break"
*ignores = ""
*customMethod = ""
}
// options options
type options struct {
name string
keyType string
valueType string
cacheFunc string
rawFunc string
addCacheFunc string
template int
SimpleValue bool
NumberValue bool
GoValue bool
ZeroValue string
ImportPackage string
importPackages []string
Args string
PkgName string
EnableSingleFlight bool
NullCache string
EnableNullCache bool
GroupSize int
MaxGroup int
EnableBatch bool
BatchErrBreak bool
Sync bool
CheckNullCode string
ExtraArgsType string
ExtraArgs string
ExtraCacheArgs string
ExtraRawArgs string
ExtraAddCacheArgs string
EnablePaging bool
Comment string
CustomMethod string
IDName string
}
func getOptions(opt *options, comment string) {
os.Args = []string{os.Args[0]}
if regexp.MustCompile(`\s+//\s*bts:.+`).Match([]byte(comment)) {
args := strings.Split(pkg.RegexpReplace(`//\s*bts:(?P<arg>.+)`, comment, "$arg"), " ")
for _, arg := range args {
arg = strings.TrimSpace(arg)
if arg != "" {
// validate option name
argName := pkg.RegexpReplace(`-(?P<name>[\w_-]+)=.+`, arg, "$name")
if !optionNamesMap[argName] {
log.Fatalf("选项:%s 不存在 请检查拼写\n", argName)
}
os.Args = append(os.Args, arg)
}
}
}
resetFlag()
flag.Parse()
opt.EnableSingleFlight = *singleFlight
opt.NullCache = *nullCache
opt.EnablePaging = *paging
opt.EnableNullCache = *nullCache != ""
opt.EnableBatch = (*batchSize != 0) && (*maxGroup != 0)
opt.BatchErrBreak = *batchErr == "break"
opt.Sync = *sync
opt.CheckNullCode = *checkNullCode
opt.GroupSize = *batchSize
opt.MaxGroup = *maxGroup
opt.CustomMethod = *customMethod
}
func processList(s *pkg.Source, list *ast.Field) (opt options) {
fset := s.Fset
src := s.Src
lines := strings.Split(src, "\n")
opt = options{Args: s.GetDef(_interfaceName), importPackages: s.Packages(list)}
// get comment
line := fset.Position(list.Pos()).Line - 3
if len(lines)-1 >= line {
comment := lines[line]
opt.Comment = pkg.RegexpReplace(`\s+//(?P<name>.+)`, comment, "$name")
opt.Comment = strings.TrimSpace(opt.Comment)
}
// get options
line = fset.Position(list.Pos()).Line - 2
comment := lines[line]
getOptions(&opt, comment)
// get func
opt.name = list.Names[0].Name
params := list.Type.(*ast.FuncType).Params.List
if len(params) == 0 {
log.Fatalln(opt.name + "参数不足")
}
for _, p := range params {
if len(p.Names) > 1 {
log.Fatalln(opt.name + "不支持省略类型 请写全声明中的字段类型名称")
}
}
if s.ExprString(params[0].Type) != "context.Context" {
log.Fatalln("第一个参数必须为context")
}
if len(params) == 1 {
opt.template = _noneTpl
} else {
opt.IDName = params[1].Names[0].Name
if _, ok := params[1].Type.(*ast.ArrayType); ok {
opt.template = _multiTpl
} else {
opt.template = _singleTpl
// get key
opt.keyType = s.ExprString(params[1].Type)
}
}
if len(params) > 2 {
var args []string
var allArgs []string
for _, pa := range params[2:] {
paType := s.ExprString(pa.Type)
if len(pa.Names) == 0 {
args = append(args, paType)
allArgs = append(allArgs, paType)
continue
}
var names []string
for _, name := range pa.Names {
names = append(names, name.Name)
}
allArgs = append(allArgs, strings.Join(names, ",")+" "+paType)
args = append(args, names...)
}
opt.ExtraArgs = strings.Join(args, ",")
opt.ExtraArgsType = strings.Join(allArgs, ",")
argsMap := make(map[string]bool)
for _, arg := range args {
argsMap[arg] = true
}
ignoreCache := make(map[string]bool)
ignoreRaw := make(map[string]bool)
ignoreAddCache := make(map[string]bool)
ignoreArray := [3]map[string]bool{ignoreCache, ignoreRaw, ignoreAddCache}
if *ignores != "" {
is := strings.Split(*ignores, "|")
if len(is) > 3 {
log.Fatalln("ignores参数错误")
}
for i := range is {
if len(is) > i {
for _, s := range strings.Split(is[i], ",") {
ignoreArray[i][s] = true
}
}
}
}
var as []string
for _, arg := range args {
if !ignoreCache[arg] {
as = append(as, arg)
}
}
opt.ExtraCacheArgs = strings.Join(as, ",")
as = []string{}
for _, arg := range args {
if !ignoreRaw[arg] {
as = append(as, arg)
}
}
opt.ExtraRawArgs = strings.Join(as, ",")
as = []string{}
for _, arg := range args {
if !ignoreAddCache[arg] {
as = append(as, arg)
}
}
opt.ExtraAddCacheArgs = strings.Join(as, ",")
if opt.ExtraAddCacheArgs != "" {
opt.ExtraAddCacheArgs = "," + opt.ExtraAddCacheArgs
}
if opt.ExtraRawArgs != "" {
opt.ExtraRawArgs = "," + opt.ExtraRawArgs
}
if opt.ExtraCacheArgs != "" {
opt.ExtraCacheArgs = "," + opt.ExtraCacheArgs
}
if opt.ExtraArgs != "" {
opt.ExtraArgs = "," + opt.ExtraArgs
}
if opt.ExtraArgsType != "" {
opt.ExtraArgsType = "," + opt.ExtraArgsType
}
}
// get k v from results
results := list.Type.(*ast.FuncType).Results.List
if len(results) != 2 {
log.Fatalln(opt.name + ": 参数个数不对")
}
if s.ExprString(results[1].Type) != "error" {
log.Fatalln(opt.name + ": 最后返回值参数需为error")
}
if opt.template == _multiTpl {
p, ok := results[0].Type.(*ast.MapType)
if !ok {
log.Fatalln(opt.name + ": 批量获取方法 返回值类型需为map类型")
}
opt.keyType = s.ExprString(p.Key)
opt.valueType = s.ExprString(p.Value)
} else {
opt.valueType = s.ExprString(results[0].Type)
}
for _, t := range numberTypes {
if t == opt.valueType {
opt.NumberValue = true
break
}
}
opt.ZeroValue = "nil"
for _, t := range simpleTypes {
if t == opt.valueType {
opt.SimpleValue = true
opt.ZeroValue = zeroValue(t)
break
}
}
if !opt.SimpleValue {
for _, t := range []string{"[]", "map"} {
if strings.HasPrefix(opt.valueType, t) {
opt.GoValue = true
break
}
}
}
upperName := strings.ToUpper(opt.name[0:1]) + opt.name[1:]
opt.cacheFunc = fmt.Sprintf("d.Cache%s", upperName)
opt.rawFunc = fmt.Sprintf("d.Raw%s", upperName)
opt.addCacheFunc = fmt.Sprintf("d.AddCache%s", upperName)
if opt.CustomMethod != "" {
arrs := strings.Split(opt.CustomMethod, "|")
if len(arrs) > 0 && arrs[0] != "" {
opt.cacheFunc = arrs[0]
}
if len(arrs) > 1 && arrs[1] != "" {
opt.rawFunc = arrs[1]
}
if len(arrs) > 2 && arrs[2] != "" {
opt.addCacheFunc = arrs[2]
}
}
return
}
// parse parse options
func parse(s *pkg.Source) (opts []*options) {
c := s.F.Scope.Lookup(_interfaceName)
if (c == nil) || (c.Kind != ast.Typ) {
log.Fatalln("无法找到缓存声明")
}
lists := c.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType).Methods.List
for _, list := range lists {
opt := processList(s, list)
opt.Check()
opts = append(opts, &opt)
}
return
}
func (option *options) Check() {
if !option.SimpleValue && !strings.Contains(option.valueType, "*") && !strings.Contains(option.valueType, "[]") && !strings.Contains(option.valueType, "map") {
log.Fatalf("%s: 值类型只能为基本类型/slice/map/指针类型\n", option.name)
}
if option.EnableSingleFlight && option.EnableBatch {
log.Fatalf("%s: 单飞和批量获取不能同时开启\n", option.name)
}
if option.template != _singleTpl && option.EnablePaging {
log.Fatalf("%s: 分页只能用在单key模板中\n", option.name)
}
if option.SimpleValue && !option.EnableNullCache {
if !((option.template == _multiTpl) && option.NumberValue) {
log.Fatalf("%s: 值为基本类型时需开启空缓存 防止缓存零值穿透\n", option.name)
}
}
if option.EnableNullCache {
if !option.SimpleValue && option.CheckNullCode == "" {
log.Fatalf("%s: 缺少-check_null_code参数\n", option.name)
}
if option.SimpleValue && option.NullCache == option.ZeroValue {
log.Fatalf("%s: %s 不能作为空缓存值 \n", option.name, option.NullCache)
}
if strings.Contains(option.NullCache, "{}") {
// -nullcache=[]*model.OrderMain{} 这种无效
log.Fatalf("%s: %s 不能作为空缓存值 会导致空缓存无效 \n", option.name, option.NullCache)
}
if strings.Contains(option.CheckNullCode, "len") && strings.Contains(strings.Replace(option.CheckNullCode, " ", "", -1), "==0") {
// -check_null_code=len($)==0 这种无效
log.Fatalf("%s: -check_null_code=%s 错误 会有无意义的赋值\n", option.name, option.CheckNullCode)
}
}
}
func genHeader(opts []*options) (src string) {
option := options{PkgName: os.Getenv("GOPACKAGE")}
var sfCount int
var packages, sfInit []string
packagesMap := map[string]bool{`"context"`: true}
for _, opt := range opts {
if opt.EnableSingleFlight {
option.EnableSingleFlight = true
sfCount++
}
if opt.EnableBatch {
option.EnableBatch = true
}
if len(opt.importPackages) > 0 {
for _, pkg := range opt.importPackages {
if !packagesMap[pkg] {
packages = append(packages, pkg)
packagesMap[pkg] = true
}
}
}
if opt.Args != "" {
option.Args = opt.Args
}
}
option.ImportPackage = strings.Join(packages, "\n")
for i := 0; i < sfCount; i++ {
sfInit = append(sfInit, "{}")
}
src = _headerTemplate
src = strings.Replace(src, "SFCOUNT", strconv.Itoa(sfCount), -1)
t := template.Must(template.New("header").Parse(src))
var buffer bytes.Buffer
err := t.Execute(&buffer, option)
if err != nil {
log.Fatalf("execute template: %s", err)
}
// Format the output.
src = strings.Replace(buffer.String(), "\t", "", -1)
src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
src = strings.Replace(src, "NEWLINE", "", -1)
src = strings.Replace(src, "ARGS", option.Args, -1)
src = strings.Replace(src, "SFINIT", strings.Join(sfInit, ","), -1)
return
}
func genBody(opts []*options) (res string) {
sfnum := -1
for _, option := range opts {
var nullCodeVar, src string
if option.template == _multiTpl {
src = _multiTemplate
nullCodeVar = "v"
} else if option.template == _singleTpl {
src = _singleTemplate
nullCodeVar = "res"
} else {
src = _noneTemplate
nullCodeVar = "res"
}
if option.template != _noneTpl {
src = strings.Replace(src, "KEY", option.keyType, -1)
}
if option.CheckNullCode != "" {
option.CheckNullCode = strings.Replace(option.CheckNullCode, "$", nullCodeVar, -1)
}
if option.EnableSingleFlight {
sfnum++
}
src = strings.Replace(src, "NAME", option.name, -1)
src = strings.Replace(src, "VALUE", option.valueType, -1)
src = strings.Replace(src, "ADDCACHEFUNC", option.addCacheFunc, -1)
src = strings.Replace(src, "CACHEFUNC", option.cacheFunc, -1)
src = strings.Replace(src, "RAWFUNC", option.rawFunc, -1)
src = strings.Replace(src, "GROUPSIZE", strconv.Itoa(option.GroupSize), -1)
src = strings.Replace(src, "MAXGROUP", strconv.Itoa(option.MaxGroup), -1)
src = strings.Replace(src, "SFNUM", strconv.Itoa(sfnum), -1)
t := template.Must(template.New("cache").Parse(src))
var buffer bytes.Buffer
err := t.Execute(&buffer, option)
if err != nil {
log.Fatalf("execute template: %s", err)
}
// Format the output.
src = strings.Replace(buffer.String(), "\t", "", -1)
src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
res = res + "\n" + src
}
return
}
func zeroValue(t string) string {
switch t {
case "bool":
return "false"
case "string":
return "\"\""
case "[]byte":
return "nil"
default:
return "0"
}
}
func init() {
for _, name := range optionNames {
optionNamesMap[name] = true
}
}
func main() {
log.SetFlags(0)
defer func() {
if err := recover(); err != nil {
log.Fatalf("程序解析失败, err: %+v", err)
}
}()
options := parse(pkg.NewSource(pkg.SourceText()))
header := genHeader(options)
body := genBody(options)
code := pkg.FormatCode(header + "\n" + body)
// Write to file.
dir := filepath.Dir(".")
outputName := filepath.Join(dir, "dao.bts.go")
err := ioutil.WriteFile(outputName, []byte(code), 0644)
if err != nil {
log.Fatalf("写入文件失败: %s", err)
}
log.Println("dao.bts.go: 生成成功")
}