1
2
3
4
5 package packagestest
6
7 import (
8 "fmt"
9 "go/token"
10 "io/ioutil"
11 "os"
12 "path/filepath"
13 "reflect"
14 "regexp"
15 "strings"
16
17 "golang.org/x/tools/go/expect"
18 "golang.org/x/tools/go/packages"
19 )
20
const (
	// eofIdentifier is the identifier that, used where a position or range
	// is expected, denotes the end of the file containing the note.
	eofIdentifier = "EOF"
	// markMethod is the name of the expectation method that records a
	// named marker; notes written as a bare identifier are rewritten into
	// calls to it.
	markMethod = "mark"
)
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76 func (e *Exported) Expect(methods map[string]interface{}) error {
77 if err := e.getNotes(); err != nil {
78 return err
79 }
80 if err := e.getMarkers(); err != nil {
81 return err
82 }
83 var err error
84 ms := make(map[string]method, len(methods))
85 for name, f := range methods {
86 mi := method{f: reflect.ValueOf(f)}
87 mi.converters = make([]converter, mi.f.Type().NumIn())
88 for i := 0; i < len(mi.converters); i++ {
89 mi.converters[i], err = e.buildConverter(mi.f.Type().In(i))
90 if err != nil {
91 return fmt.Errorf("invalid method %v: %v", name, err)
92 }
93 }
94 ms[name] = mi
95 }
96 for _, n := range e.notes {
97 if n.Args == nil {
98
99 n = &expect.Note{
100 Pos: n.Pos,
101 Name: markMethod,
102 Args: []interface{}{n.Name, n.Name},
103 }
104 }
105 mi, ok := ms[n.Name]
106 if !ok {
107 continue
108 }
109 params := make([]reflect.Value, len(mi.converters))
110 args := n.Args
111 for i, convert := range mi.converters {
112 params[i], args, err = convert(n, args)
113 if err != nil {
114 return fmt.Errorf("%v: %v", e.ExpectFileSet.Position(n.Pos), err)
115 }
116 }
117 if len(args) > 0 {
118 return fmt.Errorf("%v: unwanted args got %+v extra", e.ExpectFileSet.Position(n.Pos), args)
119 }
120
121 mi.f.Call(params)
122 }
123 return nil
124 }
125
126
// Range describes a span of source text: the token.File holding it and the
// start and end positions of the span. Ranges built by newRange always have
// both positions inside TokFile with Start <= End.
type Range struct {
	TokFile *token.File // file that contains the span
	Start   token.Pos   // start position of the span
	End     token.Pos   // end position of the span
}
131
132
133
134
135
136
137
138
139
140
141
142
// rangeSetter lets an expectation method take a custom type T in place of
// Range. When *T implements this interface, buildConverter allocates a new
// T, hands the resolved range to SetRange through the pointer, and passes
// the resulting T value to the method.
type rangeSetter interface {
	SetRange(file *token.File, start token.Pos, end token.Pos)
}
146
147
148 func (e *Exported) Mark(name string, r Range) {
149 if e.markers == nil {
150 e.markers = make(map[string]Range)
151 }
152 e.markers[name] = r
153 }
154
155 func (e *Exported) getNotes() error {
156 if e.notes != nil {
157 return nil
158 }
159 notes := []*expect.Note{}
160 var dirs []string
161 for _, module := range e.written {
162 for _, filename := range module {
163 dirs = append(dirs, filepath.Dir(filename))
164 }
165 }
166 for filename := range e.Config.Overlay {
167 dirs = append(dirs, filepath.Dir(filename))
168 }
169 pkgs, err := packages.Load(e.Config, dirs...)
170 if err != nil {
171 return fmt.Errorf("unable to load packages for directories %s: %v", dirs, err)
172 }
173 seen := make(map[token.Position]struct{})
174 for _, pkg := range pkgs {
175 for _, filename := range pkg.GoFiles {
176 content, err := e.FileContents(filename)
177 if err != nil {
178 return err
179 }
180 l, err := expect.Parse(e.ExpectFileSet, filename, content)
181 if err != nil {
182 return fmt.Errorf("failed to extract expectations: %v", err)
183 }
184 for _, note := range l {
185 pos := e.ExpectFileSet.Position(note.Pos)
186 if _, ok := seen[pos]; ok {
187 continue
188 }
189 notes = append(notes, note)
190 seen[pos] = struct{}{}
191 }
192 }
193 }
194 if _, ok := e.written[e.primary]; !ok {
195 e.notes = notes
196 return nil
197 }
198
199
200 if gomod, found := e.written[e.primary]["go.mod"]; found {
201
202 if e.Exporter == Modules {
203 gomod += ".temp"
204 }
205 l, err := goModMarkers(e, gomod)
206 if err != nil {
207 return fmt.Errorf("failed to extract expectations for go.mod: %v", err)
208 }
209 notes = append(notes, l...)
210 }
211 e.notes = notes
212 return nil
213 }
214
215 func goModMarkers(e *Exported, gomod string) ([]*expect.Note, error) {
216 if _, err := os.Stat(gomod); os.IsNotExist(err) {
217
218 return nil, nil
219 }
220 content, err := e.FileContents(gomod)
221 if err != nil {
222 return nil, err
223 }
224 if e.Exporter == GOPATH {
225 return expect.Parse(e.ExpectFileSet, gomod, content)
226 }
227 gomod = strings.TrimSuffix(gomod, ".temp")
228
229 if err := ioutil.WriteFile(gomod, content, 0644); err != nil {
230 return nil, nil
231 }
232 return expect.Parse(e.ExpectFileSet, gomod, content)
233 }
234
235 func (e *Exported) getMarkers() error {
236 if e.markers != nil {
237 return nil
238 }
239
240 e.markers = make(map[string]Range)
241 return e.Expect(map[string]interface{}{
242 markMethod: e.Mark,
243 })
244 }
245
246 var (
247 noteType = reflect.TypeOf((*expect.Note)(nil))
248 identifierType = reflect.TypeOf(expect.Identifier(""))
249 posType = reflect.TypeOf(token.Pos(0))
250 positionType = reflect.TypeOf(token.Position{})
251 rangeType = reflect.TypeOf(Range{})
252 rangeSetterType = reflect.TypeOf((*rangeSetter)(nil)).Elem()
253 fsetType = reflect.TypeOf((*token.FileSet)(nil))
254 regexType = reflect.TypeOf((*regexp.Regexp)(nil))
255 exportedType = reflect.TypeOf((*Exported)(nil))
256 )
257
258
259
260
261
262
263 type converter func(*expect.Note, []interface{}) (reflect.Value, []interface{}, error)
264
265
266
267 type method struct {
268 f reflect.Value
269 converters []converter
270 }
271
272
273
274
275
276 func (e *Exported) buildConverter(pt reflect.Type) (converter, error) {
277 switch {
278 case pt == noteType:
279 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
280 return reflect.ValueOf(n), args, nil
281 }, nil
282 case pt == fsetType:
283 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
284 return reflect.ValueOf(e.ExpectFileSet), args, nil
285 }, nil
286 case pt == exportedType:
287 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
288 return reflect.ValueOf(e), args, nil
289 }, nil
290 case pt == posType:
291 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
292 r, remains, err := e.rangeConverter(n, args)
293 if err != nil {
294 return reflect.Value{}, nil, err
295 }
296 return reflect.ValueOf(r.Start), remains, nil
297 }, nil
298 case pt == positionType:
299 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
300 r, remains, err := e.rangeConverter(n, args)
301 if err != nil {
302 return reflect.Value{}, nil, err
303 }
304 return reflect.ValueOf(e.ExpectFileSet.Position(r.Start)), remains, nil
305 }, nil
306 case pt == rangeType:
307 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
308 r, remains, err := e.rangeConverter(n, args)
309 if err != nil {
310 return reflect.Value{}, nil, err
311 }
312 return reflect.ValueOf(r), remains, nil
313 }, nil
314 case reflect.PtrTo(pt).AssignableTo(rangeSetterType):
315
316 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
317 r, remains, err := e.rangeConverter(n, args)
318 if err != nil {
319 return reflect.Value{}, nil, err
320 }
321 v := reflect.New(pt)
322 v.Interface().(rangeSetter).SetRange(r.TokFile, r.Start, r.End)
323 return v.Elem(), remains, nil
324 }, nil
325 case pt == identifierType:
326 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
327 if len(args) < 1 {
328 return reflect.Value{}, nil, fmt.Errorf("missing argument")
329 }
330 arg := args[0]
331 args = args[1:]
332 switch arg := arg.(type) {
333 case expect.Identifier:
334 return reflect.ValueOf(arg), args, nil
335 default:
336 return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to string", arg)
337 }
338 }, nil
339
340 case pt == regexType:
341 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
342 if len(args) < 1 {
343 return reflect.Value{}, nil, fmt.Errorf("missing argument")
344 }
345 arg := args[0]
346 args = args[1:]
347 if _, ok := arg.(*regexp.Regexp); !ok {
348 return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to *regexp.Regexp", arg)
349 }
350 return reflect.ValueOf(arg), args, nil
351 }, nil
352
353 case pt.Kind() == reflect.String:
354 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
355 if len(args) < 1 {
356 return reflect.Value{}, nil, fmt.Errorf("missing argument")
357 }
358 arg := args[0]
359 args = args[1:]
360 switch arg := arg.(type) {
361 case expect.Identifier:
362 return reflect.ValueOf(string(arg)), args, nil
363 case string:
364 return reflect.ValueOf(arg), args, nil
365 default:
366 return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to string", arg)
367 }
368 }, nil
369 case pt.Kind() == reflect.Int64:
370 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
371 if len(args) < 1 {
372 return reflect.Value{}, nil, fmt.Errorf("missing argument")
373 }
374 arg := args[0]
375 args = args[1:]
376 switch arg := arg.(type) {
377 case int64:
378 return reflect.ValueOf(arg), args, nil
379 default:
380 return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to int", arg)
381 }
382 }, nil
383 case pt.Kind() == reflect.Bool:
384 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
385 if len(args) < 1 {
386 return reflect.Value{}, nil, fmt.Errorf("missing argument")
387 }
388 arg := args[0]
389 args = args[1:]
390 b, ok := arg.(bool)
391 if !ok {
392 return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to bool", arg)
393 }
394 return reflect.ValueOf(b), args, nil
395 }, nil
396 case pt.Kind() == reflect.Slice:
397 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
398 converter, err := e.buildConverter(pt.Elem())
399 if err != nil {
400 return reflect.Value{}, nil, err
401 }
402 result := reflect.MakeSlice(reflect.SliceOf(pt.Elem()), 0, len(args))
403 for range args {
404 value, remains, err := converter(n, args)
405 if err != nil {
406 return reflect.Value{}, nil, err
407 }
408 result = reflect.Append(result, value)
409 args = remains
410 }
411 return result, args, nil
412 }, nil
413 default:
414 if pt.Kind() == reflect.Interface && pt.NumMethod() == 0 {
415 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
416 if len(args) < 1 {
417 return reflect.Value{}, nil, fmt.Errorf("missing argument")
418 }
419 return reflect.ValueOf(args[0]), args[1:], nil
420 }, nil
421 }
422 return nil, fmt.Errorf("param has unexpected type %v (kind %v)", pt, pt.Kind())
423 }
424 }
425
426 func (e *Exported) rangeConverter(n *expect.Note, args []interface{}) (Range, []interface{}, error) {
427 tokFile := e.ExpectFileSet.File(n.Pos)
428 if len(args) < 1 {
429 return Range{}, nil, fmt.Errorf("missing argument")
430 }
431 arg := args[0]
432 args = args[1:]
433 switch arg := arg.(type) {
434 case expect.Identifier:
435
436 switch arg {
437 case eofIdentifier:
438
439 eof := tokFile.Pos(tokFile.Size())
440 return newRange(tokFile, eof, eof), args, nil
441 default:
442
443 mark, ok := e.markers[string(arg)]
444 if !ok {
445 return Range{}, nil, fmt.Errorf("cannot find marker %v", arg)
446 }
447 return mark, args, nil
448 }
449 case string:
450 start, end, err := expect.MatchBefore(e.ExpectFileSet, e.FileContents, n.Pos, arg)
451 if err != nil {
452 return Range{}, nil, err
453 }
454 if !start.IsValid() {
455 return Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.ExpectFileSet.Position(n.Pos), arg)
456 }
457 return newRange(tokFile, start, end), args, nil
458 case *regexp.Regexp:
459 start, end, err := expect.MatchBefore(e.ExpectFileSet, e.FileContents, n.Pos, arg)
460 if err != nil {
461 return Range{}, nil, err
462 }
463 if !start.IsValid() {
464 return Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.ExpectFileSet.Position(n.Pos), arg)
465 }
466 return newRange(tokFile, start, end), args, nil
467 default:
468 return Range{}, nil, fmt.Errorf("cannot convert %v to pos", arg)
469 }
470 }
471
472
473 func newRange(file *token.File, start, end token.Pos) Range {
474 fileBase := file.Base()
475 fileEnd := fileBase + file.Size()
476 if !start.IsValid() {
477 panic("invalid start token.Pos")
478 }
479 if !end.IsValid() {
480 panic("invalid end token.Pos")
481 }
482 if int(start) < fileBase || int(start) > fileEnd {
483 panic(fmt.Sprintf("invalid start: %d not in [%d, %d]", start, fileBase, fileEnd))
484 }
485 if int(end) < fileBase || int(end) > fileEnd {
486 panic(fmt.Sprintf("invalid end: %d not in [%d, %d]", end, fileBase, fileEnd))
487 }
488 if start > end {
489 panic("invalid start: greater than end")
490 }
491 return Range{
492 TokFile: file,
493 Start: start,
494 End: end,
495 }
496 }
497
View as plain text