You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
570 lines
16 KiB
570 lines
16 KiB
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"flag"
|
|
"go/ast"
|
|
"io/ioutil"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"text/template"
|
|
|
|
common "github.com/bilibili/kratos/tool/pkg"
|
|
)
|
|
|
|
var (
	// Per-method options, parsed by getOptions from the `// mc:` line above
	// each interface method. Every flag name here must also be listed in
	// optionNamesMap, otherwise getOptions rejects it as unknown.
	encode        = flag.String("encode", "", "encode type: json/pb/raw/gob/gzip")
	mcType        = flag.String("type", "", "type: get/set/del/replace/only_add")
	key           = flag.String("key", "", "key name method")
	expire        = flag.String("expire", "", "expire time code")
	structName    = flag.String("struct_name", "dao", "struct name")
	batchSize     = flag.Int("batch", 0, "batch size")
	batchErr      = flag.String("batch_err", "break", "batch err to continue or break")
	maxGroup      = flag.Int("max_group", 0, "max group size")
	checkNullCode = flag.String("check_null_code", "", "check null code")
	nullExpire    = flag.String("null_expire", "", "null cache expire time code")

	// mcValidTypes lists the method types accepted by (*options).Check.
	mcValidTypes = []string{"set", "replace", "del", "get", "only_add"}

	// mcValidPrefix lists the method-name prefixes used to infer the type when
	// no -type option is given. getTypeFromPrefix maps "add" to set and
	// "cache" to get.
	mcValidPrefix = []string{"set", "replace", "del", "get", "cache", "add"}

	// optionNamesMap is the set of option names allowed in `// mc:` comments.
	optionNamesMap = map[string]bool{"batch": true, "max_group": true, "encode": true, "type": true, "key": true, "expire": true, "batch_err": true, "struct_name": true, "check_null_code": true, "null_expire": true}

	// simpleTypes are value types converted with convertValue2Bytes and
	// convertBytes2Value instead of being marshalled.
	simpleTypes = []string{"int", "int8", "int16", "int32", "int64", "float32", "float64", "uint", "uint8", "uint16", "uint32", "uint64", "bool", "string", "[]byte"}

	// lenTypes are value-type prefixes that set options.LenType.
	lenTypes = []string{"[]", "map"}
)
|
|
|
|
// _interfaceName is the name of the interface whose methods are generated.
const _interfaceName = "_mc"

// Template kinds, chosen in processList from the shape of the signature.
const (
	_multiTpl  = iota + 1 // batch: second parameter (or get result) is a map/slice
	_singleTpl            // one key parameter
	_noneTpl              // no key parameter
)

// Method types, spelled as in the -type option.
const (
	_typeGet     = "get"
	_typeSet     = "set"
	_typeDel     = "del"
	_typeReplace = "replace"
	_typeAdd     = "only_add"
)
|
|
|
|
func resetFlag() {
|
|
*encode = ""
|
|
*mcType = ""
|
|
*batchSize = 0
|
|
*maxGroup = 0
|
|
*batchErr = "break"
|
|
*checkNullCode = ""
|
|
*nullExpire = ""
|
|
*structName = "dao"
|
|
}
|
|
|
|
// options describes one method of the _mc interface and everything needed to
// render its implementation. Exported fields are the ones text/template can
// read when the header and method templates are executed. Unexported fields
// are used only by the generator itself.
type options struct {
	// name is the method name. keyType and ValueType are the Go source
	// spellings of the cache key type and the cached value type.
	name, keyType, ValueType string
	// template is the template kind: _multiTpl, _singleTpl or _noneTpl.
	template int

	// SimpleValue is true when ValueType is listed in simpleTypes.
	SimpleValue bool
	// GetSimpleValue is true for simple types other than string and []byte.
	GetSimpleValue bool
	// GetDirectValue is true when the value is a string or []byte.
	GetDirectValue bool
	// ConvertValue2Bytes and ConvertBytes2Value hold the Go expressions
	// returned by convertValue2Bytes / convertBytes2Value for ValueType.
	ConvertValue2Bytes, ConvertBytes2Value string
	// GoValue is not assigned anywhere in this file.
	GoValue bool

	// ImportPackage is the newline-joined import list used by the header.
	// importPackages are this method's imports, from common.Source.Packages.
	ImportPackage  string
	importPackages []string
	// Args is the _mc declaration returned by common.Source.GetDef. PkgName is
	// the package name, read from $GOPACKAGE by genHeader.
	Args, PkgName string
	// ExtraArgsType and ExtraArgs are the comma-prefixed "name type" and
	// "name" lists of the parameters that follow the key/value parameters.
	ExtraArgsType, ExtraArgs string

	// MCType is one of the _type* constants. KeyMethod and ExpireCode are the
	// Go expressions for the cache key and expiration. Encode is the memcache
	// flag expression, for example "memcache.FlagJSON".
	MCType, KeyMethod, ExpireCode, Encode string
	// UseMemcached is true except for batch gets (see getKeyValueType).
	UseMemcached bool
	// OriginValueType is ValueType without its leading "*", if any.
	OriginValueType string
	// UseStrConv is set on the header options when some method has a simple
	// value that is neither string nor []byte.
	UseStrConv bool
	// Comment is the doc-comment text found two lines above the method.
	Comment string

	// GroupSize (-batch) and MaxGroup (-max_group) configure batching.
	// EnableBatch is set when both are non-zero. BatchErrBreak is true when
	// -batch_err is "break".
	GroupSize, MaxGroup        int
	EnableBatch, BatchErrBreak bool
	// LenType is true for string values and for value types starting with
	// "[]" or "map". PointType is true when ValueType starts with "*".
	LenType, PointType bool

	// StructName comes from -struct_name (default "dao"). It is presumably
	// the receiver type of the generated methods.
	StructName string
	// CheckNullCode (-check_null_code) and ExpireNullCode (-null_expire,
	// default "300") configure null-value caching. EnableNullCode is set
	// when CheckNullCode is non-empty.
	CheckNullCode, ExpireNullCode string
	EnableNullCode                bool
}
|
|
|
|
func getOptions(opt *options, comment string) {
|
|
os.Args = []string{os.Args[0]}
|
|
if regexp.MustCompile(`\s+//\s*mc:.+`).Match([]byte(comment)) {
|
|
args := strings.Split(common.RegexpReplace(`//\s*mc:(?P<arg>.+)`, comment, "$arg"), " ")
|
|
for _, arg := range args {
|
|
arg = strings.TrimSpace(arg)
|
|
if arg != "" {
|
|
// validate option name
|
|
argName := common.RegexpReplace(`-(?P<name>[\w_-]+)=.+`, arg, "$name")
|
|
if !optionNamesMap[argName] {
|
|
log.Fatalf("选项:%s 不存在 请检查拼写\n", argName)
|
|
}
|
|
os.Args = append(os.Args, arg)
|
|
}
|
|
}
|
|
}
|
|
resetFlag()
|
|
flag.Parse()
|
|
if *mcType != "" {
|
|
opt.MCType = *mcType
|
|
}
|
|
if *key != "" {
|
|
opt.KeyMethod = *key
|
|
}
|
|
if *expire != "" {
|
|
opt.ExpireCode = *expire
|
|
}
|
|
opt.EnableBatch = (*batchSize != 0) && (*maxGroup != 0)
|
|
opt.BatchErrBreak = *batchErr == "break"
|
|
opt.GroupSize = *batchSize
|
|
opt.MaxGroup = *maxGroup
|
|
opt.StructName = *structName
|
|
opt.CheckNullCode = *checkNullCode
|
|
if *nullExpire != "" {
|
|
opt.ExpireNullCode = *nullExpire
|
|
}
|
|
if opt.CheckNullCode != "" {
|
|
opt.EnableNullCode = true
|
|
}
|
|
}
|
|
|
|
func getTypeFromPrefix(opt *options, params []*ast.Field, s *common.Source) {
|
|
if opt.MCType == "" {
|
|
for _, t := range mcValidPrefix {
|
|
if strings.HasPrefix(strings.ToLower(opt.name), t) {
|
|
if t == "add" {
|
|
t = _typeSet
|
|
}
|
|
opt.MCType = t
|
|
break
|
|
}
|
|
}
|
|
if opt.MCType == "" {
|
|
log.Fatalln(opt.name + "请指定方法类型(type=get/set/del...)")
|
|
}
|
|
}
|
|
if opt.MCType == "cache" {
|
|
opt.MCType = _typeGet
|
|
}
|
|
if len(params) == 0 {
|
|
log.Fatalln(opt.name + "参数不足")
|
|
}
|
|
for _, p := range params {
|
|
if len(p.Names) > 1 {
|
|
log.Fatalln(opt.name + "不支持省略类型 请写全声明中的字段类型名称")
|
|
}
|
|
}
|
|
if s.ExprString(params[0].Type) != "context.Context" {
|
|
log.Fatalln(opt.name + "第一个参数必须为context")
|
|
}
|
|
for _, param := range params {
|
|
if len(param.Names) > 1 {
|
|
log.Fatalln(opt.name + "不支持省略类型")
|
|
}
|
|
}
|
|
}
|
|
|
|
func processList(s *common.Source, list *ast.Field) (opt options) {
|
|
src := s.Src
|
|
fset := s.Fset
|
|
lines := strings.Split(src, "\n")
|
|
opt = options{Args: s.GetDef(_interfaceName), UseMemcached: true, importPackages: s.Packages(list)}
|
|
opt.name = list.Names[0].Name
|
|
opt.KeyMethod = "key" + opt.name
|
|
opt.ExpireCode = "d.mc" + opt.name + "Expire"
|
|
opt.ExpireNullCode = "300" // 默认5分钟
|
|
// get comment
|
|
line := fset.Position(list.Pos()).Line - 3
|
|
if len(lines)-1 >= line {
|
|
comment := lines[line]
|
|
opt.Comment = common.RegexpReplace(`\s+//(?P<name>.+)`, comment, "$name")
|
|
opt.Comment = strings.TrimSpace(opt.Comment)
|
|
}
|
|
// get options
|
|
line = fset.Position(list.Pos()).Line - 2
|
|
comment := lines[line]
|
|
getOptions(&opt, comment)
|
|
// get type from prefix
|
|
params := list.Type.(*ast.FuncType).Params.List
|
|
getTypeFromPrefix(&opt, params, s)
|
|
// get template
|
|
if len(params) == 1 {
|
|
opt.template = _noneTpl
|
|
} else if (len(params) == 2) && (opt.MCType == _typeSet || opt.MCType == _typeAdd || opt.MCType == _typeReplace) {
|
|
if _, ok := params[1].Type.(*ast.MapType); ok {
|
|
opt.template = _multiTpl
|
|
} else {
|
|
opt.template = _noneTpl
|
|
}
|
|
} else {
|
|
if _, ok := params[1].Type.(*ast.ArrayType); ok {
|
|
opt.template = _multiTpl
|
|
} else if _, ok := params[1].Type.(*ast.MapType); ok {
|
|
opt.template = _multiTpl
|
|
} else {
|
|
opt.template = _singleTpl
|
|
}
|
|
}
|
|
// extra args
|
|
if len(params) > 2 {
|
|
args := []string{""}
|
|
allArgs := []string{""}
|
|
var pos = 2
|
|
if (opt.MCType == _typeAdd) || (opt.MCType == _typeSet) || (opt.MCType == _typeReplace) {
|
|
pos = 3
|
|
}
|
|
if opt.template == _multiTpl && opt.MCType == _typeSet {
|
|
pos = 2
|
|
}
|
|
for _, pa := range params[pos:] {
|
|
paType := s.ExprString(pa.Type)
|
|
if len(pa.Names) == 0 {
|
|
args = append(args, paType)
|
|
allArgs = append(allArgs, paType)
|
|
continue
|
|
}
|
|
var names []string
|
|
for _, name := range pa.Names {
|
|
names = append(names, name.Name)
|
|
}
|
|
allArgs = append(allArgs, strings.Join(names, ",")+" "+paType)
|
|
args = append(args, strings.Join(names, ","))
|
|
}
|
|
if len(args) > 1 {
|
|
opt.ExtraArgs = strings.Join(args, ",")
|
|
opt.ExtraArgsType = strings.Join(allArgs, ",")
|
|
}
|
|
}
|
|
results := list.Type.(*ast.FuncType).Results.List
|
|
getKeyValueType(&opt, params, results, s)
|
|
return
|
|
}
|
|
|
|
func getKeyValueType(opt *options, params, results []*ast.Field, s *common.Source) {
|
|
// check
|
|
if s.ExprString(results[len(results)-1].Type) != "error" {
|
|
log.Fatalln("最后返回值参数需为error")
|
|
}
|
|
for _, res := range results {
|
|
if len(res.Names) > 1 {
|
|
log.Fatalln(opt.name + "返回值不支持省略类型")
|
|
}
|
|
}
|
|
if opt.MCType == _typeGet {
|
|
if len(results) != 2 {
|
|
log.Fatalln("参数个数不对")
|
|
}
|
|
}
|
|
// get key type and value type
|
|
if (opt.MCType == _typeAdd) || (opt.MCType == _typeSet) || (opt.MCType == _typeReplace) {
|
|
if opt.template == _multiTpl {
|
|
p, ok := params[1].Type.(*ast.MapType)
|
|
if !ok {
|
|
log.Fatalf("%s: 参数类型错误 批量设置数据时类型需为map类型\n", opt.name)
|
|
}
|
|
opt.keyType = s.ExprString(p.Key)
|
|
opt.ValueType = s.ExprString(p.Value)
|
|
} else if opt.template == _singleTpl {
|
|
opt.keyType = s.ExprString(params[1].Type)
|
|
opt.ValueType = s.ExprString(params[2].Type)
|
|
} else {
|
|
opt.ValueType = s.ExprString(params[1].Type)
|
|
}
|
|
}
|
|
if opt.MCType == _typeGet {
|
|
if opt.template == _multiTpl {
|
|
if p, ok := results[0].Type.(*ast.MapType); ok {
|
|
opt.keyType = s.ExprString(p.Key)
|
|
opt.ValueType = s.ExprString(p.Value)
|
|
} else {
|
|
log.Fatalf("%s: 返回值类型错误 批量获取数据时返回值需为map类型\n", opt.name)
|
|
}
|
|
} else if opt.template == _singleTpl {
|
|
opt.keyType = s.ExprString(params[1].Type)
|
|
opt.ValueType = s.ExprString(results[0].Type)
|
|
} else {
|
|
opt.ValueType = s.ExprString(results[0].Type)
|
|
}
|
|
}
|
|
if opt.MCType == _typeDel {
|
|
if opt.template == _multiTpl {
|
|
p, ok := params[1].Type.(*ast.ArrayType)
|
|
if !ok {
|
|
log.Fatalf("%s: 类型错误 参数需为[]类型\n", opt.name)
|
|
}
|
|
opt.keyType = s.ExprString(p.Elt)
|
|
} else if opt.template == _singleTpl {
|
|
opt.keyType = s.ExprString(params[1].Type)
|
|
}
|
|
}
|
|
for _, t := range simpleTypes {
|
|
if t == opt.ValueType {
|
|
opt.SimpleValue = true
|
|
opt.GetSimpleValue = true
|
|
opt.ConvertValue2Bytes = convertValue2Bytes(t)
|
|
opt.ConvertBytes2Value = convertBytes2Value(t)
|
|
break
|
|
}
|
|
}
|
|
if opt.ValueType == "string" {
|
|
opt.LenType = true
|
|
} else {
|
|
for _, t := range lenTypes {
|
|
if strings.HasPrefix(opt.ValueType, t) {
|
|
opt.LenType = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if opt.SimpleValue && (opt.ValueType == "[]byte" || opt.ValueType == "string") {
|
|
opt.GetSimpleValue = false
|
|
opt.GetDirectValue = true
|
|
}
|
|
if opt.MCType == _typeGet && opt.template == _multiTpl {
|
|
opt.UseMemcached = false
|
|
}
|
|
if strings.HasPrefix(opt.ValueType, "*") {
|
|
opt.PointType = true
|
|
opt.OriginValueType = strings.Replace(opt.ValueType, "*", "", 1)
|
|
} else {
|
|
opt.OriginValueType = opt.ValueType
|
|
}
|
|
if *encode != "" {
|
|
var flags []string
|
|
for _, f := range strings.Split(*encode, "|") {
|
|
switch f {
|
|
case "gob":
|
|
flags = append(flags, "memcache.FlagGOB")
|
|
case "json":
|
|
flags = append(flags, "memcache.FlagJSON")
|
|
case "raw":
|
|
flags = append(flags, "memcache.FlagRAW")
|
|
case "pb":
|
|
flags = append(flags, "memcache.FlagProtobuf")
|
|
case "gzip":
|
|
flags = append(flags, "memcache.FlagGzip")
|
|
default:
|
|
log.Fatalf("%s: encode类型无效\n", opt.name)
|
|
}
|
|
}
|
|
opt.Encode = strings.Join(flags, " | ")
|
|
} else {
|
|
if opt.SimpleValue {
|
|
opt.Encode = "memcache.FlagRAW"
|
|
} else {
|
|
opt.Encode = "memcache.FlagJSON"
|
|
}
|
|
}
|
|
}
|
|
|
|
func parse(s *common.Source) (opts []*options) {
|
|
c := s.F.Scope.Lookup(_interfaceName)
|
|
if (c == nil) || (c.Kind != ast.Typ) {
|
|
log.Fatalln("无法找到缓存声明")
|
|
}
|
|
lists := c.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType).Methods.List
|
|
for _, list := range lists {
|
|
opt := processList(s, list)
|
|
opt.Check()
|
|
opts = append(opts, &opt)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (option *options) Check() {
|
|
var valid bool
|
|
for _, x := range mcValidTypes {
|
|
if x == option.MCType {
|
|
valid = true
|
|
break
|
|
}
|
|
}
|
|
if !valid {
|
|
log.Fatalf("%s: 类型错误 不支持%s类型\n", option.name, option.MCType)
|
|
}
|
|
if (option.MCType != _typeDel) && !option.SimpleValue && !strings.Contains(option.ValueType, "*") && !strings.Contains(option.ValueType, "[]") && !strings.Contains(option.ValueType, "map") {
|
|
log.Fatalf("%s: 值类型只能为基本类型/slice/map/指针类型\n", option.name)
|
|
}
|
|
}
|
|
|
|
func genHeader(opts []*options) (src string) {
|
|
option := options{PkgName: os.Getenv("GOPACKAGE"), UseMemcached: false}
|
|
var packages []string
|
|
packagesMap := map[string]bool{`"context"`: true}
|
|
for _, opt := range opts {
|
|
if len(opt.importPackages) > 0 {
|
|
for _, pkg := range opt.importPackages {
|
|
if !packagesMap[pkg] {
|
|
packages = append(packages, pkg)
|
|
packagesMap[pkg] = true
|
|
}
|
|
}
|
|
}
|
|
if opt.Args != "" {
|
|
option.Args = opt.Args
|
|
}
|
|
if opt.UseMemcached {
|
|
option.UseMemcached = true
|
|
}
|
|
if opt.SimpleValue && !opt.GetDirectValue {
|
|
option.UseStrConv = true
|
|
}
|
|
if opt.EnableBatch {
|
|
option.EnableBatch = true
|
|
}
|
|
}
|
|
option.ImportPackage = strings.Join(packages, "\n")
|
|
src = _headerTemplate
|
|
t := template.Must(template.New("header").Parse(src))
|
|
var buffer bytes.Buffer
|
|
err := t.Execute(&buffer, option)
|
|
if err != nil {
|
|
log.Fatalf("execute template: %s", err)
|
|
}
|
|
// Format the output.
|
|
src = strings.Replace(buffer.String(), "\t", "", -1)
|
|
src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
|
|
src = strings.Replace(src, "NEWLINE", "", -1)
|
|
src = strings.Replace(src, "ARGS", option.Args, -1)
|
|
return
|
|
}
|
|
|
|
func getNewTemplate(option *options) (src string) {
|
|
if option.template == _multiTpl {
|
|
switch option.MCType {
|
|
case _typeGet:
|
|
src = _multiGetTemplate
|
|
case _typeSet:
|
|
src = _multiSetTemplate
|
|
case _typeReplace:
|
|
src = _multiReplaceTemplate
|
|
case _typeDel:
|
|
src = _multiDelTemplate
|
|
case _typeAdd:
|
|
src = _multiAddTemplate
|
|
}
|
|
} else if option.template == _singleTpl {
|
|
switch option.MCType {
|
|
case _typeGet:
|
|
src = _singleGetTemplate
|
|
case _typeSet:
|
|
src = _singleSetTemplate
|
|
case _typeReplace:
|
|
src = _singleReplaceTemplate
|
|
case _typeDel:
|
|
src = _singleDelTemplate
|
|
case _typeAdd:
|
|
src = _singleAddTemplate
|
|
}
|
|
} else {
|
|
switch option.MCType {
|
|
case _typeGet:
|
|
src = _noneGetTemplate
|
|
case _typeSet:
|
|
src = _noneSetTemplate
|
|
case _typeReplace:
|
|
src = _noneReplaceTemplate
|
|
case _typeDel:
|
|
src = _noneDelTemplate
|
|
case _typeAdd:
|
|
src = _noneAddTemplate
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func genBody(opts []*options) (res string) {
|
|
for _, option := range opts {
|
|
src := getNewTemplate(option)
|
|
src = strings.Replace(src, "KEY", option.keyType, -1)
|
|
src = strings.Replace(src, "NAME", option.name, -1)
|
|
src = strings.Replace(src, "VALUE", option.ValueType, -1)
|
|
src = strings.Replace(src, "GROUPSIZE", strconv.Itoa(option.GroupSize), -1)
|
|
src = strings.Replace(src, "MAXGROUP", strconv.Itoa(option.MaxGroup), -1)
|
|
if option.EnableNullCode {
|
|
option.CheckNullCode = strings.Replace(option.CheckNullCode, "$", "val", -1)
|
|
}
|
|
t := template.Must(template.New("cache").Parse(src))
|
|
var buffer bytes.Buffer
|
|
err := t.Execute(&buffer, option)
|
|
if err != nil {
|
|
log.Fatalf("execute template: %s", err)
|
|
}
|
|
// Format the output.
|
|
src = strings.Replace(buffer.String(), "\t", "", -1)
|
|
src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
|
|
res = res + "\n" + src
|
|
}
|
|
return
|
|
}
|
|
|
|
func main() {
|
|
log.SetFlags(0)
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
buf := make([]byte, 64*1024)
|
|
buf = buf[:runtime.Stack(buf, false)]
|
|
log.Fatalf("程序解析失败, err: %+v stack: %s 请企业微信联系 @wangxu01", err, buf)
|
|
}
|
|
}()
|
|
options := parse(common.NewSource(common.SourceText()))
|
|
header := genHeader(options)
|
|
body := genBody(options)
|
|
code := common.FormatCode(header + "\n" + body)
|
|
// Write to file.
|
|
dir := filepath.Dir(".")
|
|
outputName := filepath.Join(dir, "mc.cache.go")
|
|
err := ioutil.WriteFile(outputName, []byte(code), 0644)
|
|
if err != nil {
|
|
log.Fatalf("写入文件失败: %s", err)
|
|
}
|
|
log.Println("mc.cache.go: 生成成功")
|
|
}
|
|
|
|
// convertValue2Bytes returns a Go expression that encodes a variable named
// val of the simple type t as []byte. It returns "" when t is not a
// supported simple type.
//
// strconv's formatters take 64-bit arguments, so narrower values are widened
// explicitly. The previous uint expression called the non-existent
// strconv.FormatUInt without a conversion. The previous float32 expression
// passed a float32 to FormatFloat. Both produced generated code that did
// not compile.
func convertValue2Bytes(t string) string {
	switch t {
	case "int", "int8", "int16", "int32", "int64":
		return "[]byte(strconv.FormatInt(int64(val), 10))"
	case "uint", "uint8", "uint16", "uint32", "uint64":
		return "[]byte(strconv.FormatUint(uint64(val), 10))"
	case "bool":
		return "[]byte(strconv.FormatBool(val))"
	case "float32":
		// bitSize 32 keeps the shortest representation that round-trips float32.
		return "[]byte(strconv.FormatFloat(float64(val), 'E', -1, 32))"
	case "float64":
		return "[]byte(strconv.FormatFloat(val, 'E', -1, 64))"
	case "string":
		return "[]byte(val)"
	case "[]byte":
		return "val"
	}
	return ""
}
|
|
|
|
// convertBytes2Value returns a Go call expression that parses a string
// variable named v back into a value of the simple type t. The call yields
// (result, error), where result is int64, uint64, bool or float64. The
// generated code presumably converts it to the declared value type (the
// templates are not in this file; confirm there). It returns "" for types
// without a parser, including string and []byte.
//
// The previous uint expression called the non-existent strconv.ParseUInt.
// The previous float32 expression wrapped the two-valued ParseFloat call in
// a float32(...) conversion. Both were compile errors in the generated code.
// With bitSize 32, ParseFloat's float64 result converts to float32 without
// changing its value.
func convertBytes2Value(t string) string {
	switch t {
	case "int", "int8", "int16", "int32", "int64":
		return "strconv.ParseInt(v, 10, 64)"
	case "uint", "uint8", "uint16", "uint32", "uint64":
		return "strconv.ParseUint(v, 10, 64)"
	case "bool":
		return "strconv.ParseBool(v)"
	case "float32":
		return "strconv.ParseFloat(v, 32)"
	case "float64":
		return "strconv.ParseFloat(v, 64)"
	}
	return ""
}
|
|
|