package pkg import ( "fmt" "go/ast" "go/format" "go/parser" "go/token" "io/ioutil" "log" "os" "regexp" "strings" ) // Source source type Source struct { Fset *token.FileSet Src string F *ast.File } // NewSource new source func NewSource(src string) *Source { s := &Source{ Fset: token.NewFileSet(), Src: src, } f, err := parser.ParseFile(s.Fset, "", src, 0) if err != nil { log.Fatal("无法解析源文件") } s.F = f return s } // ExprString expr string func (s *Source) ExprString(typ ast.Expr) string { fset := s.Fset s1 := fset.Position(typ.Pos()).Offset s2 := fset.Position(typ.End()).Offset return s.Src[s1:s2] } // pkgPath package path func (s *Source) pkgPath(name string) (res string) { for _, im := range s.F.Imports { if im.Name != nil && im.Name.Name == name { return im.Path.Value } } for _, im := range s.F.Imports { if strings.HasSuffix(im.Path.Value, name+"\"") { return im.Path.Value } } return } // GetDef get define code func (s *Source) GetDef(name string) string { c := s.F.Scope.Lookup(name).Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType) s1 := s.Fset.Position(c.Pos()).Offset s2 := s.Fset.Position(c.End()).Offset line := s.Fset.Position(c.Pos()).Line lines := []string{strings.Split(s.Src, "\n")[line-1]} for _, l := range strings.Split(s.Src[s1:s2], "\n")[1:] { lines = append(lines, "\t"+l) } return strings.Join(lines, "\n") } // RegexpReplace replace regexp func RegexpReplace(reg, src, temp string) string { result := []byte{} pattern := regexp.MustCompile(reg) for _, submatches := range pattern.FindAllStringSubmatchIndex(src, -1) { result = pattern.ExpandString(result, temp, src, submatches) } return string(result) } // formatPackage format package func formatPackage(name, path string) (res string) { if path != "" { if strings.HasSuffix(path, name+"\"") { res = path return } res = fmt.Sprintf("%s %s", name, path) } return } // SourceText get source file text func SourceText() string { file := os.Getenv("GOFILE") data, err := ioutil.ReadFile(file) if err != nil { log.Fatal("请使用go generate执行", file) } return string(data) } // FormatCode format code func FormatCode(source string) string { src, err := format.Source([]byte(source)) if err != nil { // Should never happen, but can arise when developing this code. // The user can compile the output to see the error. log.Printf("warning: 输出文件不合法: %s", err) log.Printf("warning: 详细错误请编译查看") return source } return string(src) } // Packages get import packages func (s *Source) Packages(f *ast.Field) (res []string) { fs := f.Type.(*ast.FuncType).Params.List if f.Type.(*ast.FuncType).Results != nil { fs = append(fs, f.Type.(*ast.FuncType).Results.List...) } var types []string resMap := make(map[string]bool) for _, field := range fs { if p, ok := field.Type.(*ast.MapType); ok { types = append(types, s.ExprString(p.Key)) types = append(types, s.ExprString(p.Value)) } else if p, ok := field.Type.(*ast.ArrayType); ok { types = append(types, s.ExprString(p.Elt)) } else { types = append(types, s.ExprString(field.Type)) } } for _, t := range types { name := RegexpReplace(`(?P<pkg>\w+)\.\w+`, t, "$pkg") if name == "" { continue } pkg := formatPackage(name, s.pkgPath(name)) if !resMap[pkg] { resMap[pkg] = true } } for pkg := range resMap { res = append(res, pkg) } return }