1
2
3
4
5 package expect
6
import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"os"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
	"text/scanner"

	"golang.org/x/mod/modfile"
)
20
// commentStart is the marker that introduces a note inside a comment, and
// commentStartLen is its length in bytes.
const (
	commentStart    = "@"
	commentStartLen = len(commentStart)
)

// Identifier is the value produced for a bare identifier argument of a note,
// such as the x in //@note(x). The identifiers true, false, nil and re are
// interpreted specially by the argument parser and are not returned as an
// Identifier.
type Identifier string
26
27
28
29
30
31
32
33
34 func Parse(fset *token.FileSet, filename string, content []byte) ([]*Note, error) {
35 var src interface{}
36 if content != nil {
37 src = content
38 }
39 switch filepath.Ext(filename) {
40 case ".go":
41
42
43
44
45 file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.AllErrors)
46 if file == nil {
47 return nil, err
48 }
49 return ExtractGo(fset, file)
50 case ".mod":
51 file, err := modfile.Parse(filename, content, nil)
52 if err != nil {
53 return nil, err
54 }
55 f := fset.AddFile(filename, -1, len(content))
56 f.SetLinesForContent(content)
57 notes, err := extractMod(fset, file)
58 if err != nil {
59 return nil, err
60 }
61
62
63 for _, note := range notes {
64 note.Pos += token.Pos(f.Base())
65 }
66 return notes, nil
67 }
68 return nil, nil
69 }
70
71
72
73
74
75
76
77 func extractMod(fset *token.FileSet, file *modfile.File) ([]*Note, error) {
78 var notes []*Note
79 for _, stmt := range file.Syntax.Stmt {
80 comment := stmt.Comment()
81 if comment == nil {
82 continue
83 }
84
85
86
87 for _, cmt := range comment.Before {
88 text, adjust := getAdjustedNote(cmt.Token)
89 if text == "" {
90 continue
91 }
92 parsed, err := parse(fset, token.Pos(int(cmt.Start.Byte)+adjust), text)
93 if err != nil {
94 return nil, err
95 }
96 notes = append(notes, parsed...)
97 }
98
99 for _, cmt := range comment.Suffix {
100 text, adjust := getAdjustedNote(cmt.Token)
101 if text == "" {
102 continue
103 }
104 parsed, err := parse(fset, token.Pos(int(cmt.Start.Byte)+adjust), text)
105 if err != nil {
106 return nil, err
107 }
108 notes = append(notes, parsed...)
109 }
110 }
111 return notes, nil
112 }
113
114
115
116
117
118
119 func ExtractGo(fset *token.FileSet, file *ast.File) ([]*Note, error) {
120 var notes []*Note
121 for _, g := range file.Comments {
122 for _, c := range g.List {
123 text, adjust := getAdjustedNote(c.Text)
124 if text == "" {
125 continue
126 }
127 parsed, err := parse(fset, token.Pos(int(c.Pos())+adjust), text)
128 if err != nil {
129 return nil, err
130 }
131 notes = append(notes, parsed...)
132 }
133 }
134 return notes, nil
135 }
136
137 func getAdjustedNote(text string) (string, int) {
138 if strings.HasPrefix(text, "/*") {
139 text = strings.TrimSuffix(text, "*/")
140 }
141 text = text[2:]
142
143
144
145
146
147
148 var adjust int
149 if i := strings.Index(text, commentStart); i > 2 {
150
151 pre := text[i-2 : i]
152 if pre != "//" {
153 return "", 0
154 }
155 text = text[i:]
156 adjust = i
157 }
158 if !strings.HasPrefix(text, commentStart) {
159 return "", 0
160 }
161 text = text[commentStartLen:]
162 return text, commentStartLen + adjust + 1
163 }
164
// invalidToken is stored in tokens.current when no token is buffered. The
// next call to Token then scans a fresh one.
const invalidToken rune = 0

// tokens is a one-token lookahead wrapper around text/scanner used by the
// note parser. It keeps only the first error reported and converts scanner
// byte offsets into token.Pos values relative to base.
type tokens struct {
	scanner scanner.Scanner
	current rune      // buffered token, or invalidToken if none is buffered
	err     error     // first recorded error; once set, Token reports EOF
	base    token.Pos // position corresponding to byte offset 0 of the text
}

// Init prepares t to tokenize text, whose byte offset 0 corresponds to base,
// and returns t. Any previously buffered token or recorded error is
// discarded. Go tokens are recognized and newlines are returned as '\n'
// tokens rather than skipped, because they separate notes. Scanner errors
// are recorded through Errorf.
func (t *tokens) Init(base token.Pos, text string) *tokens {
	t.base = base
	t.current = invalidToken
	t.err = nil
	// scanner.Init resets Mode, Whitespace and Error, so configure them after it.
	t.scanner.Init(strings.NewReader(text))
	t.scanner.Mode = scanner.GoTokens
	// Clear (rather than toggle) the newline bit so '\n' is reported as a token.
	t.scanner.Whitespace &^= 1 << '\n'
	t.scanner.Error = func(_ *scanner.Scanner, msg string) {
		t.Errorf("%v", msg)
	}
	return t
}

// Consume returns the text of the current token and clears the buffer, so the
// next call to Token scans a new one.
func (t *tokens) Consume() string {
	t.current = invalidToken
	return t.scanner.TokenText()
}

// Token returns the current token, scanning the next one if none is buffered.
// Once an error has been recorded it always returns scanner.EOF, which lets
// every parsing loop terminate.
func (t *tokens) Token() rune {
	if t.err != nil {
		return scanner.EOF
	}
	if t.current == invalidToken {
		t.current = t.scanner.Scan()
	}
	return t.current
}

// Skip consumes a run of consecutive r tokens and reports how many it
// consumed. Skipping scanner.EOF is a no-op that returns 0. EOF is sticky, so
// the loop would otherwise never terminate.
func (t *tokens) Skip(r rune) int {
	if r == scanner.EOF {
		return 0
	}
	n := 0
	for t.Token() == r {
		t.Consume()
		n++
	}
	return n
}

// TokenString returns a printable form of the current token for error messages.
func (t *tokens) TokenString() string {
	return scanner.TokenString(t.Token())
}

// Pos returns the position of the most recently scanned token, computed as
// base plus the token's byte offset within the text.
func (t *tokens) Pos() token.Pos {
	return t.base + token.Pos(t.scanner.Position.Offset)
}

// Errorf records a formatted error unless one is already recorded. Only the
// first error is kept.
func (t *tokens) Errorf(msg string, args ...interface{}) {
	if t.err != nil {
		return
	}
	t.err = fmt.Errorf(msg, args...)
}
223
224 func parse(fset *token.FileSet, base token.Pos, text string) ([]*Note, error) {
225 t := new(tokens).Init(base, text)
226 notes := parseComment(t)
227 if t.err != nil {
228 return nil, fmt.Errorf("%v:%s", fset.Position(t.Pos()), t.err)
229 }
230 return notes, nil
231 }
232
233 func parseComment(t *tokens) []*Note {
234 var notes []*Note
235 for {
236 t.Skip('\n')
237 switch t.Token() {
238 case scanner.EOF:
239 return notes
240 case scanner.Ident:
241 notes = append(notes, parseNote(t))
242 default:
243 t.Errorf("unexpected %s parsing comment, expect identifier", t.TokenString())
244 return nil
245 }
246 switch t.Token() {
247 case scanner.EOF:
248 return notes
249 case ',', '\n':
250 t.Consume()
251 default:
252 t.Errorf("unexpected %s parsing comment, expect separator", t.TokenString())
253 return nil
254 }
255 }
256 }
257
258 func parseNote(t *tokens) *Note {
259 n := &Note{
260 Pos: t.Pos(),
261 Name: t.Consume(),
262 }
263
264 switch t.Token() {
265 case ',', '\n', scanner.EOF:
266
267 return n
268 case '(':
269 n.Args = parseArgumentList(t)
270 return n
271 default:
272 t.Errorf("unexpected %s parsing note", t.TokenString())
273 return nil
274 }
275 }
276
277 func parseArgumentList(t *tokens) []interface{} {
278 args := []interface{}{}
279 t.Consume()
280 t.Skip('\n')
281 for t.Token() != ')' {
282 args = append(args, parseArgument(t))
283 if t.Token() != ',' {
284 break
285 }
286 t.Consume()
287 t.Skip('\n')
288 }
289 if t.Token() != ')' {
290 t.Errorf("unexpected %s parsing argument list", t.TokenString())
291 return nil
292 }
293 t.Consume()
294 return args
295 }
296
297 func parseArgument(t *tokens) interface{} {
298 switch t.Token() {
299 case scanner.Ident:
300 v := t.Consume()
301 switch v {
302 case "true":
303 return true
304 case "false":
305 return false
306 case "nil":
307 return nil
308 case "re":
309 if t.Token() != scanner.String && t.Token() != scanner.RawString {
310 t.Errorf("re must be followed by string, got %s", t.TokenString())
311 return nil
312 }
313 pattern, _ := strconv.Unquote(t.Consume())
314 re, err := regexp.Compile(pattern)
315 if err != nil {
316 t.Errorf("invalid regular expression %s: %v", pattern, err)
317 return nil
318 }
319 return re
320 default:
321 return Identifier(v)
322 }
323
324 case scanner.String, scanner.RawString:
325 v, _ := strconv.Unquote(t.Consume())
326 return v
327
328 case scanner.Int:
329 s := t.Consume()
330 v, err := strconv.ParseInt(s, 0, 0)
331 if err != nil {
332 t.Errorf("cannot convert %v to int: %v", s, err)
333 }
334 return v
335
336 case scanner.Float:
337 s := t.Consume()
338 v, err := strconv.ParseFloat(s, 64)
339 if err != nil {
340 t.Errorf("cannot convert %v to float: %v", s, err)
341 }
342 return v
343
344 case scanner.Char:
345 t.Errorf("unexpected char literal %s", t.Consume())
346 return nil
347
348 default:
349 t.Errorf("unexpected %s parsing argument", t.TokenString())
350 return nil
351 }
352 }
353
View as plain text