1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package socket
16
17 import (
18 "bytes"
19 "encoding/json"
20 "errors"
21 "go/parser"
22 "go/token"
23 exec "golang.org/x/sys/execabs"
24 "io"
25 "io/ioutil"
26 "log"
27 "net"
28 "net/http"
29 "net/url"
30 "os"
31 "path/filepath"
32 "runtime"
33 "strings"
34 "time"
35 "unicode/utf8"
36
37 "golang.org/x/net/websocket"
38 "golang.org/x/tools/txtar"
39 )
40
41
42
var (
	// RunScripts controls whether snippets whose first line starts with
	// "#!" are executed as scripts. When false, such snippets are rejected
	// with an error instead of being run.
	RunScripts = true

	// Environ supplies the environment for every command this package
	// starts (the Go build and the compiled program). It defaults to
	// os.Environ and may be replaced to customize the child environment.
	Environ = os.Environ
)

const (
	// msgLimit is how many output messages a process may emit before it is
	// killed; output beyond the limit is discarded.
	msgLimit = 1000

	// msgDelay is the window during which consecutive output of the same
	// kind is accumulated before being sent as one message.
	msgDelay = 10 * time.Millisecond
)
56
57
58
59
// Message is the unit of communication over the websocket in both
// directions. Clients send "run" and "kill" messages; the server answers
// with "stdout", "stderr" and "end" messages tagged with the same Id.
type Message struct {
	Id, Kind, Body string
	Options        *Options `json:",omitempty"`
}

// Options holds per-run settings supplied with a "run" message.
type Options struct {
	Race bool // build with -race (and run with GOMAXPROCS=2)
}
71
72
73 func NewHandler(origin *url.URL) websocket.Server {
74 return websocket.Server{
75 Config: websocket.Config{Origin: origin},
76 Handshake: handshake,
77 Handler: websocket.Handler(socketHandler),
78 }
79 }
80
81
82 func handshake(c *websocket.Config, req *http.Request) error {
83 o, err := websocket.Origin(c, req)
84 if err != nil {
85 log.Println("bad websocket origin:", err)
86 return websocket.ErrBadWebSocketOrigin
87 }
88 _, port, err := net.SplitHostPort(c.Origin.Host)
89 if err != nil {
90 log.Println("bad websocket origin:", err)
91 return websocket.ErrBadWebSocketOrigin
92 }
93 ok := c.Origin.Scheme == o.Scheme && (c.Origin.Host == o.Host || c.Origin.Host == net.JoinHostPort(o.Host, port))
94 if !ok {
95 log.Println("bad websocket origin:", o)
96 return websocket.ErrBadWebSocketOrigin
97 }
98 log.Println("accepting connection from:", req.RemoteAddr)
99 return nil
100 }
101
102
103
104
105 func socketHandler(c *websocket.Conn) {
106 in, out := make(chan *Message), make(chan *Message)
107 errc := make(chan error, 1)
108
109
110 go func() {
111 dec := json.NewDecoder(c)
112 for {
113 var m Message
114 if err := dec.Decode(&m); err != nil {
115 errc <- err
116 return
117 }
118 in <- &m
119 }
120 }()
121
122
123 go func() {
124 enc := json.NewEncoder(c)
125 for m := range out {
126 if err := enc.Encode(m); err != nil {
127 errc <- err
128 return
129 }
130 }
131 }()
132 defer close(out)
133
134
135 proc := make(map[string]*process)
136 for {
137 select {
138 case m := <-in:
139 switch m.Kind {
140 case "run":
141 log.Println("running snippet from:", c.Request().RemoteAddr)
142 proc[m.Id].Kill()
143 proc[m.Id] = startProcess(m.Id, m.Body, out, m.Options)
144 case "kill":
145 proc[m.Id].Kill()
146 }
147 case err := <-errc:
148 if err != io.EOF {
149
150 log.Println(err)
151 }
152
153 for _, p := range proc {
154 p.Kill()
155 }
156 return
157 }
158 }
159 }
160
161
162 type process struct {
163 out chan<- *Message
164 done chan struct{}
165 run *exec.Cmd
166 path string
167 }
168
169
170
171 func startProcess(id, body string, dest chan<- *Message, opt *Options) *process {
172 var (
173 done = make(chan struct{})
174 out = make(chan *Message)
175 p = &process{out: out, done: done}
176 )
177 go func() {
178 defer close(done)
179 for m := range buffer(limiter(out, p), time.After) {
180 m.Id = id
181 dest <- m
182 }
183 }()
184 var err error
185 if path, args := shebang(body); path != "" {
186 if RunScripts {
187 err = p.startProcess(path, args, body)
188 } else {
189 err = errors.New("script execution is not allowed")
190 }
191 } else {
192 err = p.start(body, opt)
193 }
194 if err != nil {
195 p.end(err)
196 return nil
197 }
198 go func() {
199 p.end(p.run.Wait())
200 }()
201 return p
202 }
203
204
205
206 func (p *process) end(err error) {
207 if p.path != "" {
208 defer os.RemoveAll(p.path)
209 }
210 m := &Message{Kind: "end"}
211 if err != nil {
212 m.Body = err.Error()
213 }
214 p.out <- m
215 close(p.out)
216 }
217
218
219
// killer is anything that can be stopped. limiter uses it to terminate a
// process that produces too much output.
type killer interface {
	Kill()
}
223
224
225
226
227
228
229
230 func limiter(in <-chan *Message, p killer) <-chan *Message {
231 out := make(chan *Message)
232 go func() {
233 defer close(out)
234 n := 0
235 for m := range in {
236 switch {
237 case n < msgLimit || m.Kind == "end":
238 out <- m
239 if m.Kind == "end" {
240 return
241 }
242 case n == msgLimit:
243
244
245
246 go p.Kill()
247 default:
248 continue
249 }
250 n++
251 }
252 }()
253 return out
254 }
255
256
257
258
259
260
261
262
263
264 func buffer(in <-chan *Message, timeAfter func(time.Duration) <-chan time.Time) <-chan *Message {
265 out := make(chan *Message)
266 go func() {
267 defer close(out)
268 var (
269 tc <-chan time.Time
270 buf []byte
271 kind string
272 flush = func() {
273 if len(buf) == 0 {
274 return
275 }
276 out <- &Message{Kind: kind, Body: safeString(buf)}
277 buf = buf[:0]
278 kind = ""
279 }
280 )
281 for {
282 select {
283 case m, ok := <-in:
284 if !ok {
285 flush()
286 return
287 }
288 if m.Kind == "end" {
289 flush()
290 out <- m
291 return
292 }
293 if kind != m.Kind {
294 flush()
295 kind = m.Kind
296 if tc == nil {
297 tc = timeAfter(msgDelay)
298 }
299 }
300 buf = append(buf, m.Body...)
301 case <-tc:
302 flush()
303 tc = nil
304 }
305 }
306 }()
307 return out
308 }
309
310
311 func (p *process) Kill() {
312 if p == nil || p.run == nil {
313 return
314 }
315 p.run.Process.Kill()
316 <-p.done
317 }
318
319
320
321
// shebang inspects the first line of body (after trimming surrounding
// whitespace). If that line starts with "#!" and names an interpreter, it
// returns the interpreter path and the full argument list, with the path
// itself as args[0], as exec.Cmd.Args expects. Otherwise it returns "", nil.
func shebang(body string) (path string, args []string) {
	body = strings.TrimSpace(body)
	if !strings.HasPrefix(body, "#!") {
		return "", nil
	}
	line := body[2:]
	if i := strings.IndexByte(line, '\n'); i >= 0 {
		line = line[:i]
	}
	fs := strings.Fields(line)
	if len(fs) == 0 {
		// A bare "#!" line names no interpreter; indexing fs[0] would panic
		// on this client-supplied input.
		return "", nil
	}
	return fs[0], fs
}
333
334
335
336 func (p *process) startProcess(path string, args []string, body string) error {
337 cmd := &exec.Cmd{
338 Path: path,
339 Args: args,
340 Stdin: strings.NewReader(body),
341 Stdout: &messageWriter{kind: "stdout", out: p.out},
342 Stderr: &messageWriter{kind: "stderr", out: p.out},
343 }
344 if err := cmd.Start(); err != nil {
345 return err
346 }
347 p.run = cmd
348 return nil
349 }
350
351
352
353 func (p *process) start(body string, opt *Options) error {
354
355
356
357
358
359 path, err := ioutil.TempDir("", "present-")
360 if err != nil {
361 return err
362 }
363 p.path = path
364
365 out := "prog"
366 if runtime.GOOS == "windows" {
367 out = "prog.exe"
368 }
369 bin := filepath.Join(path, out)
370
371
372 a := txtar.Parse([]byte(body))
373 if len(a.Comment) != 0 {
374 a.Files = append(a.Files, txtar.File{Name: "prog.go", Data: a.Comment})
375 a.Comment = nil
376 }
377 hasModfile := false
378 for _, f := range a.Files {
379 err = ioutil.WriteFile(filepath.Join(path, f.Name), f.Data, 0666)
380 if err != nil {
381 return err
382 }
383 if f.Name == "go.mod" {
384 hasModfile = true
385 }
386 }
387
388
389 args := []string{"go", "build", "-tags", "OMIT"}
390 if opt != nil && opt.Race {
391 p.out <- &Message{
392 Kind: "stderr",
393 Body: "Running with race detector.\n",
394 }
395 args = append(args, "-race")
396 }
397 args = append(args, "-o", bin)
398 cmd := p.cmd(path, args...)
399 if !hasModfile {
400 cmd.Env = append(cmd.Env, "GO111MODULE=off")
401 }
402 cmd.Stdout = cmd.Stderr
403 if err := cmd.Run(); err != nil {
404 return err
405 }
406
407
408 if isNacl() {
409 cmd, err = p.naclCmd(bin)
410 if err != nil {
411 return err
412 }
413 } else {
414 cmd = p.cmd("", bin)
415 }
416 if opt != nil && opt.Race {
417 cmd.Env = append(cmd.Env, "GOMAXPROCS=2")
418 }
419 if err := cmd.Start(); err != nil {
420
421
422
423 if name, err := packageName(body); err == nil && name != "main" {
424 return errors.New(`executable programs must use "package main"`)
425 }
426 return err
427 }
428 p.run = cmd
429 return nil
430 }
431
432
433
434 func (p *process) cmd(dir string, args ...string) *exec.Cmd {
435 cmd := exec.Command(args[0], args[1:]...)
436 cmd.Dir = dir
437 cmd.Env = Environ()
438 cmd.Stdout = &messageWriter{kind: "stdout", out: p.out}
439 cmd.Stderr = &messageWriter{kind: "stderr", out: p.out}
440 return cmd
441 }
442
443 func isNacl() bool {
444 for _, v := range append(Environ(), os.Environ()...) {
445 if v == "GOOS=nacl" {
446 return true
447 }
448 }
449 return false
450 }
451
452
453 func (p *process) naclCmd(bin string) (*exec.Cmd, error) {
454 pwd, err := os.Getwd()
455 if err != nil {
456 return nil, err
457 }
458 var args []string
459 env := []string{
460 "NACLENV_GOOS=" + runtime.GOOS,
461 "NACLENV_GOROOT=/go",
462 "NACLENV_NACLPWD=" + strings.Replace(pwd, runtime.GOROOT(), "/go", 1),
463 }
464 switch runtime.GOARCH {
465 case "amd64":
466 env = append(env, "NACLENV_GOARCH=amd64p32")
467 args = []string{"sel_ldr_x86_64"}
468 case "386":
469 env = append(env, "NACLENV_GOARCH=386")
470 args = []string{"sel_ldr_x86_32"}
471 case "arm":
472 env = append(env, "NACLENV_GOARCH=arm")
473 selLdr, err := exec.LookPath("sel_ldr_arm")
474 if err != nil {
475 return nil, err
476 }
477 args = []string{"nacl_helper_bootstrap_arm", selLdr, "--reserved_at_zero=0xXXXXXXXXXXXXXXXX"}
478 default:
479 return nil, errors.New("native client does not support GOARCH=" + runtime.GOARCH)
480 }
481
482 cmd := p.cmd("", append(args, "-l", "/dev/null", "-S", "-e", bin)...)
483 cmd.Env = append(cmd.Env, env...)
484
485 return cmd, nil
486 }
487
488 func packageName(body string) (string, error) {
489 f, err := parser.ParseFile(token.NewFileSet(), "prog.go",
490 strings.NewReader(body), parser.PackageClauseOnly)
491 if err != nil {
492 return "", err
493 }
494 return f.Name.String(), nil
495 }
496
497
498
499 type messageWriter struct {
500 kind string
501 out chan<- *Message
502 }
503
504 func (w *messageWriter) Write(b []byte) (n int, err error) {
505 w.out <- &Message{Kind: w.kind, Body: safeString(b)}
506 return len(b), nil
507 }
508
509
// safeString converts b to a string that is always valid UTF-8. Each byte
// that is not part of a valid encoding becomes U+FFFD (the replacement
// character).
func safeString(b []byte) string {
	if utf8.Valid(b) {
		return string(b)
	}
	var sb strings.Builder
	sb.Grow(len(b))
	// Ranging over a string yields utf8.RuneError for each invalid byte and
	// advances by one byte, matching a utf8.DecodeRune loop.
	for _, r := range string(b) {
		sb.WriteRune(r)
	}
	return sb.String()
}
522
View as plain text