1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71 package main
72
73 import (
74 "bytes"
75 "flag"
76 "fmt"
77 "go/ast"
78 "go/format"
79 "go/printer"
80 "go/token"
81 "go/types"
82 "io/ioutil"
83 "log"
84 "os"
85 "strconv"
86 "strings"
87 "unicode"
88
89 "golang.org/x/tools/go/packages"
90 )
91
// Command-line flags; see usage for the synopsis.
var (
	// outputFile is -o: the file to write; empty means standard output.
	outputFile = flag.String("o", "", "write output to `file` (default standard output)")
	// dstPath is -dst: the destination package pattern handed to packages.Load.
	dstPath = flag.String("dst", ".", "set destination import `path`")
	// pkgName is -pkg: the destination package name; main fills it in from
	// the loaded destination package when left empty.
	pkgName = flag.String("pkg", "", "set destination package `name`")
	// prefix is -prefix: prepended to every renamed identifier.
	prefix = flag.String("prefix", "&_", "set bundled identifier prefix to `p` (default is \"&_\", where & stands for the original name)")
	// buildTags is -tags: emitted as build-constraint lines at the top of the output.
	buildTags = flag.String("tags", "", "the build constraints to be inserted into the generated file")
)

// importMap holds the -import rewrites (old import path -> new import path);
// entries are added by addImportMap.
var importMap = make(map[string]string)
101
102 func init() {
103 flag.Var(flagFunc(addImportMap), "import", "rewrite import using `map`, of form old=new (can be repeated)")
104 }
105
106 func addImportMap(s string) {
107 if strings.Count(s, "=") != 1 {
108 log.Fatal("-import argument must be of the form old=new")
109 }
110 i := strings.Index(s, "=")
111 old, new := s[:i], s[i+1:]
112 if old == "" || new == "" {
113 log.Fatal("-import argument must be of the form old=new; old and new must be non-empty")
114 }
115 importMap[old] = new
116 }
117
// usage writes the command synopsis followed by the flag defaults to
// standard error.
func usage() {
	const synopsis = "Usage: bundle [options] <src>\n"
	fmt.Fprint(os.Stderr, synopsis)
	flag.PrintDefaults()
}
122
123 func main() {
124 log.SetPrefix("bundle: ")
125 log.SetFlags(0)
126
127 flag.Usage = usage
128 flag.Parse()
129 args := flag.Args()
130 if len(args) != 1 {
131 usage()
132 os.Exit(2)
133 }
134
135 cfg := &packages.Config{Mode: packages.NeedName}
136 pkgs, err := packages.Load(cfg, *dstPath)
137 if err != nil {
138 log.Fatalf("cannot load destination package: %v", err)
139 }
140 if packages.PrintErrors(pkgs) > 0 || len(pkgs) != 1 {
141 log.Fatalf("failed to load destination package")
142 }
143 if *pkgName == "" {
144 *pkgName = pkgs[0].Name
145 }
146
147 code, err := bundle(args[0], pkgs[0].PkgPath, *pkgName, *prefix, *buildTags)
148 if err != nil {
149 log.Fatal(err)
150 }
151 if *outputFile != "" {
152 err := ioutil.WriteFile(*outputFile, code, 0666)
153 if err != nil {
154 log.Fatal(err)
155 }
156 } else {
157 _, err := os.Stdout.Write(code)
158 if err != nil {
159 log.Fatal(err)
160 }
161 }
162 }
163
164
// isStandardImportPath reports whether path looks like a standard-library
// import path: its first slash-separated element contains no dot.
func isStandardImportPath(path string) bool {
	first := strings.SplitN(path, "/", 2)[0]
	return strings.IndexByte(first, '.') < 0
}
173
// testingOnlyPackagesConfig, when non-nil, is copied by bundle in place of its
// default packages.Config (the default sets GOFLAGS=-mod=mod in the environment).
// bundle still overwrites the copy's Mode. As the name indicates, it is meant to
// be set only by tests.
var testingOnlyPackagesConfig *packages.Config
175
176 func bundle(src, dst, dstpkg, prefix, buildTags string) ([]byte, error) {
177
178 cfg := &packages.Config{}
179 if testingOnlyPackagesConfig != nil {
180 *cfg = *testingOnlyPackagesConfig
181 } else {
182
183
184 cfg.Env = append(os.Environ(), "GOFLAGS=-mod=mod")
185 }
186 cfg.Mode = packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo
187 pkgs, err := packages.Load(cfg, src)
188 if err != nil {
189 return nil, err
190 }
191 if packages.PrintErrors(pkgs) > 0 || len(pkgs) != 1 {
192 return nil, fmt.Errorf("failed to load source package")
193 }
194 pkg := pkgs[0]
195
196 if strings.Contains(prefix, "&") {
197 prefix = strings.Replace(prefix, "&", pkg.Syntax[0].Name.Name, -1)
198 }
199
200 objsToUpdate := make(map[types.Object]bool)
201 var rename func(from types.Object)
202 rename = func(from types.Object) {
203 if !objsToUpdate[from] {
204 objsToUpdate[from] = true
205
206
207
208
209
210
211 if _, ok := from.(*types.TypeName); ok {
212 for id, obj := range pkg.TypesInfo.Uses {
213 if obj == from {
214 if field := pkg.TypesInfo.Defs[id]; field != nil {
215 rename(field)
216 }
217 }
218 }
219 }
220 }
221 }
222
223
224 scope := pkg.Types.Scope()
225 for _, name := range scope.Names() {
226 rename(scope.Lookup(name))
227 }
228
229 var out bytes.Buffer
230 if buildTags != "" {
231 fmt.Fprintf(&out, "//go:build %s\n", buildTags)
232 fmt.Fprintf(&out, "// +build %s\n\n", buildTags)
233 }
234
235 fmt.Fprintf(&out, "// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.\n")
236 if *outputFile != "" && buildTags == "" {
237 fmt.Fprintf(&out, "//go:generate bundle %s\n", strings.Join(quoteArgs(os.Args[1:]), " "))
238 } else {
239 fmt.Fprintf(&out, "// $ bundle %s\n", strings.Join(os.Args[1:], " "))
240 }
241 fmt.Fprintf(&out, "\n")
242
243
244 for _, f := range pkg.Syntax {
245 if doc := f.Doc.Text(); strings.TrimSpace(doc) != "" {
246 for _, line := range strings.Split(doc, "\n") {
247 fmt.Fprintf(&out, "// %s\n", line)
248 }
249 }
250 }
251
252 fmt.Fprintln(&out)
253
254 fmt.Fprintf(&out, "package %s\n\n", dstpkg)
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271 var pkgStd = make(map[string]bool)
272 var pkgExt = make(map[string]bool)
273 for _, f := range pkg.Syntax {
274 for _, imp := range f.Imports {
275 path, err := strconv.Unquote(imp.Path.Value)
276 if err != nil {
277 log.Fatalf("invalid import path string: %v", err)
278 }
279 if path == dst {
280 continue
281 }
282 if newPath, ok := importMap[path]; ok {
283 path = newPath
284 }
285
286 var name string
287 if imp.Name != nil {
288 name = imp.Name.Name
289 }
290 spec := fmt.Sprintf("%s %q", name, path)
291 if isStandardImportPath(path) {
292 pkgStd[spec] = true
293 } else {
294 pkgExt[spec] = true
295 }
296 }
297 }
298
299
300 fmt.Fprintln(&out, "import (")
301 for p := range pkgStd {
302 fmt.Fprintf(&out, "\t%s\n", p)
303 }
304 if len(pkgExt) > 0 {
305 fmt.Fprintln(&out)
306 }
307 for p := range pkgExt {
308 fmt.Fprintf(&out, "\t%s\n", p)
309 }
310 fmt.Fprint(&out, ")\n\n")
311
312
313 for _, f := range pkg.Syntax {
314
315 for id, obj := range pkg.TypesInfo.Defs {
316 if objsToUpdate[obj] {
317 id.Name = prefix + obj.Name()
318 }
319 }
320 for id, obj := range pkg.TypesInfo.Uses {
321 if objsToUpdate[obj] {
322 id.Name = prefix + obj.Name()
323 }
324 }
325
326
327
328
329 ast.Inspect(f, func(n ast.Node) bool {
330 if sel, ok := n.(*ast.SelectorExpr); ok {
331 if id, ok := sel.X.(*ast.Ident); ok {
332 if obj, ok := pkg.TypesInfo.Uses[id].(*types.PkgName); ok {
333 if obj.Imported().Path() == dst {
334 id.Name = "@@@"
335 }
336 }
337 }
338 }
339 return true
340 })
341
342 last := f.Package
343 if len(f.Imports) > 0 {
344 imp := f.Imports[len(f.Imports)-1]
345 last = imp.End()
346 if imp.Comment != nil {
347 if e := imp.Comment.End(); e > last {
348 last = e
349 }
350 }
351 }
352
353
354
355 var buf bytes.Buffer
356 for _, decl := range f.Decls {
357 if decl, ok := decl.(*ast.GenDecl); ok && decl.Tok == token.IMPORT {
358 continue
359 }
360
361 beg, end := sourceRange(decl)
362
363 printComments(&out, f.Comments, last, beg)
364
365 buf.Reset()
366 format.Node(&buf, pkg.Fset, &printer.CommentedNode{Node: decl, Comments: f.Comments})
367
368
369 out.Write(bytes.Replace(buf.Bytes(), []byte("@@@."), nil, -1))
370
371 last = printSameLineComment(&out, f.Comments, pkg.Fset, end)
372
373 out.WriteString("\n\n")
374 }
375
376 printLastComments(&out, f.Comments, last)
377 }
378
379
380 result, err := format.Source(out.Bytes())
381 if err != nil {
382 log.Fatalf("formatting failed: %v", err)
383 }
384
385 return result, nil
386 }
387
388
389
// sourceRange returns the source span of decl. The span starts at decl's doc
// comment, if any. For a GenDecl, it ends at the line comment of the last spec
// when that comment ends after decl itself.
func sourceRange(decl ast.Decl) (beg, end token.Pos) {
	var leading, trailing *ast.CommentGroup
	if fn, ok := decl.(*ast.FuncDecl); ok {
		leading = fn.Doc
	} else if gen, ok := decl.(*ast.GenDecl); ok {
		leading = gen.Doc
		if n := len(gen.Specs); n > 0 {
			switch last := gen.Specs[n-1].(type) {
			case *ast.ValueSpec:
				trailing = last.Comment
			case *ast.TypeSpec:
				trailing = last.Comment
			}
		}
	}

	beg, end = decl.Pos(), decl.End()
	if leading != nil {
		beg = leading.Pos()
	}
	if trailing != nil && trailing.End() > end {
		end = trailing.End()
	}
	return beg, end
}
420
// printComments writes to out every comment group whose start lies in the
// half-open range [pos, end). Each comment goes on its own line, and each
// group is followed by a blank line.
func printComments(out *bytes.Buffer, comments []*ast.CommentGroup, pos, end token.Pos) {
	for _, group := range comments {
		start := group.Pos()
		if start < pos || start >= end {
			continue
		}
		for _, c := range group.List {
			fmt.Fprintf(out, "%s\n", c.Text)
		}
		out.WriteString("\n")
	}
}
431
432 const infinity = 1 << 30
433
434 func printLastComments(out *bytes.Buffer, comments []*ast.CommentGroup, pos token.Pos) {
435 printComments(out, comments, pos, infinity)
436 }
437
// printSameLineComment finds the first comment group in comments that starts
// at or after pos on the same line as pos. It writes that group's comments to
// out, one per line, and returns the group's end position. If there is no such
// group, it writes nothing and returns pos unchanged.
func printSameLineComment(out *bytes.Buffer, comments []*ast.CommentGroup, fset *token.FileSet, pos token.Pos) token.Pos {
	file := fset.File(pos)
	for _, group := range comments {
		if group.Pos() < pos {
			continue
		}
		if file.Line(group.Pos()) != file.Line(pos) {
			continue
		}
		for _, c := range group.List {
			fmt.Fprintf(out, "%s\n", c.Text)
		}
		return group.End()
	}
	return pos
}
450
// quoteArgs returns ss with each argument made safe for a //go:generate
// directive. Per "go help generate", directive arguments are space-separated
// tokens or double-quoted strings in Go syntax. An argument is therefore quoted
// with strconv.Quote when it is empty or contains white space. It is also
// quoted when it begins with a double quote, because go generate would
// otherwise start parsing it as a quoted string. Other arguments are returned
// unchanged. A nil or empty ss yields nil.
func quoteArgs(ss []string) []string {
	var qs []string
	for _, s := range ss {
		if s == "" || containsSpace(s) || strings.HasPrefix(s, `"`) {
			s = strconv.Quote(s)
		}
		qs = append(qs, s)
	}
	return qs
}

// containsSpace reports whether s contains any Unicode white-space rune.
func containsSpace(s string) bool {
	for _, r := range s {
		if unicode.IsSpace(r) {
			return true
		}
	}
	return false
}
479
// flagFunc adapts a plain func(string) to the flag.Value interface. Each time
// the flag package sets the value, the function is called with the argument.
type flagFunc func(string)

// Set passes value to the wrapped function and always reports success.
func (fn flagFunc) Set(value string) error {
	fn(value)
	return nil
}

// String returns the empty string; a flagFunc holds no printable state.
func (fn flagFunc) String() string {
	return ""
}
488
View as plain text