You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
317 lines
9.3 KiB
317 lines
9.3 KiB
package generator
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"fmt"
|
|
"go/parser"
|
|
"go/printer"
|
|
"go/token"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/go-kratos/kratos/tool/protobuf/pkg/gen"
|
|
"github.com/go-kratos/kratos/tool/protobuf/pkg/naming"
|
|
"github.com/go-kratos/kratos/tool/protobuf/pkg/typemap"
|
|
"github.com/go-kratos/kratos/tool/protobuf/pkg/utils"
|
|
"github.com/golang/protobuf/protoc-gen-go/descriptor"
|
|
plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
const Version = "v0.1"
|
|
|
|
var GoModuleImportPath = "github.com/go-kratos/kratos"
|
|
var GoModuleDirName = "github.com/go-kratos/kratos"
|
|
|
|
type Base struct {
|
|
Reg *typemap.Registry
|
|
|
|
// Map to record whether we've built each package
|
|
// pkgName => alias name
|
|
pkgs map[string]string
|
|
pkgNamesInUse map[string]bool
|
|
|
|
ImportPrefix string // String to prefix to imported package file names.
|
|
importMap map[string]string // Mapping from .proto file name to import path.
|
|
|
|
// Package naming:
|
|
GenPkgName string // Name of the package that we're generating
|
|
PackageName string // Name of the proto file package
|
|
fileToGoPackageName map[*descriptor.FileDescriptorProto]string
|
|
|
|
// List of files that were inputs to the generator. We need to hold this in
|
|
// the struct so we can write a header for the file that lists its inputs.
|
|
GenFiles []*descriptor.FileDescriptorProto
|
|
|
|
// Output buffer that holds the bytes we want to write out for a single file.
|
|
// Gets reset after working on a file.
|
|
Output *bytes.Buffer
|
|
|
|
// key: pkgName
|
|
// value: importPath
|
|
Deps map[string]string
|
|
|
|
Params *ParamsBase
|
|
|
|
httpInfoCache map[string]*HTTPInfo
|
|
}
|
|
|
|
// RegisterPackageName name is the go package name or proto pkg name
|
|
// return go pkg alias
|
|
func (t *Base) RegisterPackageName(name string) (alias string) {
|
|
alias = name
|
|
i := 1
|
|
for t.pkgNamesInUse[alias] {
|
|
alias = name + strconv.Itoa(i)
|
|
i++
|
|
}
|
|
t.pkgNamesInUse[alias] = true
|
|
t.pkgs[name] = alias
|
|
return alias
|
|
}
|
|
|
|
func (t *Base) Setup(in *plugin.CodeGeneratorRequest, paramsOpt ...GeneratorParamsInterface) {
|
|
t.httpInfoCache = make(map[string]*HTTPInfo)
|
|
t.pkgs = make(map[string]string)
|
|
t.pkgNamesInUse = make(map[string]bool)
|
|
t.importMap = make(map[string]string)
|
|
t.Deps = make(map[string]string)
|
|
t.fileToGoPackageName = make(map[*descriptor.FileDescriptorProto]string)
|
|
t.Output = bytes.NewBuffer(nil)
|
|
|
|
var params GeneratorParamsInterface
|
|
if len(paramsOpt) > 0 {
|
|
params = paramsOpt[0]
|
|
} else {
|
|
params = &BasicParam{}
|
|
}
|
|
err := ParseGeneratorParams(in.GetParameter(), params)
|
|
if err != nil {
|
|
gen.Fail("could not parse parameters", err.Error())
|
|
}
|
|
t.Params = params.GetBase()
|
|
t.ImportPrefix = params.GetBase().ImportPrefix
|
|
t.importMap = params.GetBase().ImportMap
|
|
|
|
t.GenFiles = gen.FilesToGenerate(in)
|
|
|
|
// Collect information on types.
|
|
t.Reg = typemap.New(in.ProtoFile)
|
|
t.RegisterPackageName("context")
|
|
t.RegisterPackageName("ioutil")
|
|
t.RegisterPackageName("proto")
|
|
// Time to figure out package names of objects defined in protobuf. First,
|
|
// we'll figure out the name for the package we're generating.
|
|
genPkgName, err := DeduceGenPkgName(t.GenFiles)
|
|
if err != nil {
|
|
gen.Fail(err.Error())
|
|
}
|
|
t.GenPkgName = genPkgName
|
|
// Next, we need to pick names for all the files that are dependencies.
|
|
if len(in.ProtoFile) > 0 {
|
|
t.PackageName = t.GenFiles[0].GetPackage()
|
|
}
|
|
|
|
for _, f := range in.ProtoFile {
|
|
if fileDescSliceContains(t.GenFiles, f) {
|
|
// This is a file we are generating. It gets the shared package name.
|
|
t.fileToGoPackageName[f] = t.GenPkgName
|
|
} else {
|
|
// This is a dependency. Use its package name.
|
|
name := f.GetPackage()
|
|
if name == "" {
|
|
name = utils.BaseName(f.GetName())
|
|
}
|
|
name = utils.CleanIdentifier(name)
|
|
alias := t.RegisterPackageName(name)
|
|
t.fileToGoPackageName[f] = alias
|
|
}
|
|
}
|
|
|
|
for _, f := range t.GenFiles {
|
|
deps := t.DeduceDeps(f)
|
|
for k, v := range deps {
|
|
t.Deps[k] = v
|
|
}
|
|
}
|
|
}
|
|
|
|
func (t *Base) DeduceDeps(file *descriptor.FileDescriptorProto) map[string]string {
|
|
deps := make(map[string]string) // Map of package name to quoted import path.
|
|
ourImportPath := path.Dir(naming.GoFileName(file, ""))
|
|
for _, s := range file.Service {
|
|
for _, m := range s.Method {
|
|
defs := []*typemap.MessageDefinition{
|
|
t.Reg.MethodInputDefinition(m),
|
|
t.Reg.MethodOutputDefinition(m),
|
|
}
|
|
for _, def := range defs {
|
|
if def.File.GetPackage() == t.PackageName {
|
|
continue
|
|
}
|
|
// By default, import path is the dirname of the Go filename.
|
|
importPath := path.Dir(naming.GoFileName(def.File, ""))
|
|
if importPath == ourImportPath {
|
|
continue
|
|
}
|
|
importPath = t.SubstituteImportPath(importPath, def.File.GetName())
|
|
importPath = t.ImportPrefix + importPath
|
|
pkg := t.GoPackageNameForProtoFile(def.File)
|
|
deps[pkg] = strconv.Quote(importPath)
|
|
}
|
|
}
|
|
}
|
|
return deps
|
|
}
|
|
|
|
// DeduceGenPkgName figures out the go package name to use for generated code.
|
|
// Will try to use the explicit go_package setting in a file (if set, must be
|
|
// consistent in all files). If no files have go_package set, then use the
|
|
// protobuf package name (must be consistent in all files)
|
|
func DeduceGenPkgName(genFiles []*descriptor.FileDescriptorProto) (string, error) {
|
|
var genPkgName string
|
|
for _, f := range genFiles {
|
|
name, explicit := naming.GoPackageName(f)
|
|
if explicit {
|
|
name = utils.CleanIdentifier(name)
|
|
if genPkgName != "" && genPkgName != name {
|
|
// Make sure they're all set consistently.
|
|
return "", errors.Errorf("files have conflicting go_package settings, must be the same: %q and %q", genPkgName, name)
|
|
}
|
|
genPkgName = name
|
|
}
|
|
}
|
|
if genPkgName != "" {
|
|
return genPkgName, nil
|
|
}
|
|
|
|
// If there is no explicit setting, then check the implicit package name
|
|
// (derived from the protobuf package name) of the files and make sure it's
|
|
// consistent.
|
|
for _, f := range genFiles {
|
|
name, _ := naming.GoPackageName(f)
|
|
name = utils.CleanIdentifier(name)
|
|
if genPkgName != "" && genPkgName != name {
|
|
return "", errors.Errorf("files have conflicting package names, must be the same or overridden with go_package: %q and %q", genPkgName, name)
|
|
}
|
|
genPkgName = name
|
|
}
|
|
|
|
// All the files have the same name, so we're good.
|
|
return genPkgName, nil
|
|
}
|
|
|
|
func (t *Base) GoPackageNameForProtoFile(file *descriptor.FileDescriptorProto) string {
|
|
return t.fileToGoPackageName[file]
|
|
}
|
|
|
|
func fileDescSliceContains(slice []*descriptor.FileDescriptorProto, f *descriptor.FileDescriptorProto) bool {
|
|
for _, sf := range slice {
|
|
if f == sf {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// P forwards to g.gen.P, which prints output.
|
|
func (t *Base) P(args ...string) {
|
|
for _, v := range args {
|
|
t.Output.WriteString(v)
|
|
}
|
|
t.Output.WriteByte('\n')
|
|
}
|
|
|
|
func (t *Base) FormattedOutput() string {
|
|
// Reformat generated code.
|
|
fset := token.NewFileSet()
|
|
raw := t.Output.Bytes()
|
|
ast, err := parser.ParseFile(fset, "", raw, parser.ParseComments)
|
|
if err != nil {
|
|
// Print out the bad code with line numbers.
|
|
// This should never happen in practice, but it can while changing generated code,
|
|
// so consider this a debugging aid.
|
|
var src bytes.Buffer
|
|
s := bufio.NewScanner(bytes.NewReader(raw))
|
|
for line := 1; s.Scan(); line++ {
|
|
fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
|
|
}
|
|
gen.Fail("bad Go source code was generated:", err.Error(), "\n"+src.String())
|
|
}
|
|
|
|
out := bytes.NewBuffer(nil)
|
|
err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(out, fset, ast)
|
|
if err != nil {
|
|
gen.Fail("generated Go source code could not be reformatted:", err.Error())
|
|
}
|
|
|
|
return out.String()
|
|
}
|
|
|
|
func (t *Base) PrintComments(comments typemap.DefinitionComments) bool {
|
|
text := strings.TrimSuffix(comments.Leading, "\n")
|
|
if len(strings.TrimSpace(text)) == 0 {
|
|
return false
|
|
}
|
|
split := strings.Split(text, "\n")
|
|
for _, line := range split {
|
|
t.P("// ", strings.TrimPrefix(line, " "))
|
|
}
|
|
return len(split) > 0
|
|
}
|
|
|
|
// IsOwnPackage ...
|
|
// protoName is fully qualified name of a type
|
|
func (t *Base) IsOwnPackage(protoName string) bool {
|
|
def := t.Reg.MessageDefinition(protoName)
|
|
if def == nil {
|
|
gen.Fail("could not find message for", protoName)
|
|
}
|
|
return def.File.GetPackage() == t.PackageName
|
|
}
|
|
|
|
// Given a protobuf name for a Message, return the Go name we will use for that
|
|
// type, including its package prefix.
|
|
func (t *Base) GoTypeName(protoName string) string {
|
|
def := t.Reg.MessageDefinition(protoName)
|
|
if def == nil {
|
|
gen.Fail("could not find message for", protoName)
|
|
}
|
|
|
|
var prefix string
|
|
if def.File.GetPackage() != t.PackageName {
|
|
prefix = t.GoPackageNameForProtoFile(def.File) + "."
|
|
}
|
|
|
|
var name string
|
|
for _, parent := range def.Lineage() {
|
|
name += parent.Descriptor.GetName() + "_"
|
|
}
|
|
name += def.Descriptor.GetName()
|
|
return prefix + name
|
|
}
|
|
|
|
func streamingMethod(method *descriptor.MethodDescriptorProto) bool {
|
|
return (method.ServerStreaming != nil && *method.ServerStreaming) || (method.ClientStreaming != nil && *method.ClientStreaming)
|
|
}
|
|
|
|
func (t *Base) ShouldGenForMethod(file *descriptor.FileDescriptorProto,
|
|
service *descriptor.ServiceDescriptorProto,
|
|
method *descriptor.MethodDescriptorProto) bool {
|
|
if streamingMethod(method) {
|
|
return false
|
|
}
|
|
if !t.Params.ExplicitHTTP {
|
|
return true
|
|
}
|
|
httpInfo := t.GetHttpInfoCached(file, service, method)
|
|
return httpInfo.HasExplicitHTTPPath
|
|
}
|
|
func (t *Base) SubstituteImportPath(importPath string, importFile string) string {
|
|
if substitution, ok := t.importMap[importFile]; ok {
|
|
importPath = substitution
|
|
}
|
|
return importPath
|
|
}
|
|
|