1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "fmt"
12 "io"
13 "log"
14 "mime"
15 "net"
16 "net/http"
17 "net/http/internal/ascii"
18 "net/textproto"
19 "net/url"
20 "strings"
21 "sync"
22 "time"
23
24 "golang.org/x/net/http/httpguts"
25 )
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
// ReverseProxy is an HTTP handler (see its ServeHTTP method) that forwards
// each incoming request to a backend chosen by Director and copies the
// backend's response back to the client.
type ReverseProxy struct {
	// Director is called by ServeHTTP on a clone of the incoming request
	// before the clone is sent to the backend; it is expected to rewrite the
	// clone (typically its URL) so that it targets the backend. ServeHTTP
	// calls it unconditionally, so it must be non-nil.
	Director func(*http.Request)

	// Transport is used to send the proxied request. If nil,
	// http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval controls how often buffered response body data is
	// flushed to the client while it is being copied, when the
	// ResponseWriter implements http.Flusher. Zero disables periodic
	// flushing; a negative value flushes after every write. For
	// text/event-stream responses and for responses with ContentLength -1
	// the proxy uses -1 (flush after every write) regardless of this value.
	FlushInterval time.Duration

	// ErrorLog receives the proxy's log messages. If nil, messages are
	// written with the log package's standard logger.
	ErrorLog *log.Logger

	// BufferPool, if non-nil, supplies the byte slices used to copy
	// response bodies; each slice obtained with Get is handed back with Put
	// once the copy finishes. If nil, or if Get returns an empty slice, a
	// new 32 KiB buffer is allocated for each copy.
	BufferPool BufferPool

	// ModifyResponse, if non-nil, is called with the backend's response
	// before anything is written to the client. For responses other than
	// 101 Switching Protocols it runs after hop-by-hop headers have been
	// removed. If it returns an error, the response body is closed and
	// ErrorHandler (or the default handler) is called instead of copying
	// the response.
	ModifyResponse func(*http.Response) error

	// ErrorHandler, if non-nil, is called to report errors such as a failed
	// round trip to the backend, an invalid or mismatched upgrade protocol,
	// or an error returned by ModifyResponse. If nil, the default handler
	// logs the error and responds with 502 Bad Gateway.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)
}
96
97
98
// BufferPool lets a ReverseProxy obtain and recycle the byte slices it uses
// when copying response bodies.
type BufferPool interface {
	// Put hands back a slice previously returned by Get.
	Put([]byte)
	// Get returns a slice to copy through; an empty slice makes the proxy
	// allocate its own buffer instead.
	Get() []byte
}
103
// singleJoiningSlash concatenates a and b so that exactly one "/" separates
// them: a doubled slash is collapsed and a missing one is inserted.
func singleJoiningSlash(a, b string) string {
	endsWithSlash := strings.HasSuffix(a, "/")
	startsWithSlash := strings.HasPrefix(b, "/")
	if endsWithSlash && startsWithSlash {
		return a + strings.TrimPrefix(b, "/")
	}
	if !endsWithSlash && !startsWithSlash {
		return a + "/" + b
	}
	return a + b
}

// joinURLPath joins the paths of a and b with a single separating slash and
// returns both the decoded path and the raw (escaped) path. When neither URL
// carries a RawPath, only the decoded path is joined and rawpath is "".
// Otherwise the slash decision is made on the escaped forms, so an encoded
// "%2F" is not mistaken for a separator.
func joinURLPath(a, b *url.URL) (path, rawpath string) {
	if a.RawPath == "" && b.RawPath == "" {
		return singleJoiningSlash(a.Path, b.Path), ""
	}

	escA, escB := a.EscapedPath(), b.EscapedPath()
	trailing := strings.HasSuffix(escA, "/")
	leading := strings.HasPrefix(escB, "/")

	if trailing && leading {
		return a.Path + b.Path[1:], escA + escB[1:]
	}
	if !trailing && !leading {
		return a.Path + "/" + b.Path, escA + "/" + escB
	}
	return a.Path + b.Path, escA + escB
}
136
137
138
139
140
141
142
143
144 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
145 targetQuery := target.RawQuery
146 director := func(req *http.Request) {
147 req.URL.Scheme = target.Scheme
148 req.URL.Host = target.Host
149 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
150 if targetQuery == "" || req.URL.RawQuery == "" {
151 req.URL.RawQuery = targetQuery + req.URL.RawQuery
152 } else {
153 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
154 }
155 if _, ok := req.Header["User-Agent"]; !ok {
156
157 req.Header.Set("User-Agent", "")
158 }
159 }
160 return &ReverseProxy{Director: director}
161 }
162
// copyHeader appends every value of every field in src to dst, keeping the
// values dst already has. Keys go through http.Header.Add, so they are stored
// in canonical MIME form.
func copyHeader(dst, src http.Header) {
	for name, values := range src {
		for i := range values {
			dst.Add(name, values[i])
		}
	}
}
170
171
172
173
174
175
// hopHeaders lists hop-by-hop header fields. They describe a single
// connection rather than the end-to-end message, so ServeHTTP deletes them
// from both the outgoing request and the backend response.
var hopHeaders = []string{
	"Connection", "Proxy-Connection", "Keep-Alive",
	"Proxy-Authenticate", "Proxy-Authorization",
	"Te", "Trailer", "Transfer-Encoding", "Upgrade",
}
187
188 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
189 p.logf("http: proxy error: %v", err)
190 rw.WriteHeader(http.StatusBadGateway)
191 }
192
193 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
194 if p.ErrorHandler != nil {
195 return p.ErrorHandler
196 }
197 return p.defaultErrorHandler
198 }
199
200
201
202 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
203 if p.ModifyResponse == nil {
204 return true
205 }
206 if err := p.ModifyResponse(res); err != nil {
207 res.Body.Close()
208 p.getErrorHandler()(rw, req, err)
209 return false
210 }
211 return true
212 }
213
// ServeHTTP forwards req to the backend selected by p.Director and copies the
// backend's response to rw.
//
// In order, it:
//   - derives the context for the outgoing request, falling back to
//     http.CloseNotifier when req's context can never be canceled;
//   - clones req, lets Director rewrite the clone, strips hop-by-hop headers
//     and appends the client address to X-Forwarded-For;
//   - sends the clone with p.Transport (http.DefaultTransport when nil);
//   - hands 101 Switching Protocols responses to handleUpgradeResponse;
//   - otherwise strips hop-by-hop response headers, runs ModifyResponse and
//     writes headers, status, body and trailers to rw.
//
// Failures before the status is written go to the error handler
// (p.ErrorHandler, or defaultErrorHandler when nil). A failure while copying
// the body panics with http.ErrAbortHandler when shouldPanicOnCopyError
// reports true, and is only logged otherwise.
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	transport := p.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	ctx := req.Context()
	if ctx.Done() != nil {
		// The request context can already be canceled, so it is used as is
		// and no CloseNotifier watcher is needed.
	} else if cn, ok := rw.(http.CloseNotifier); ok {
		// Fallback for a context that can never be canceled: cancel the
		// outgoing request when rw reports the client went away.
		// NOTE(review): http.CloseNotifier is deprecated in favor of
		// Request.Context; this branch exists for writers/contexts that
		// predate it.
		var cancel context.CancelFunc
		ctx, cancel = context.WithCancel(ctx)
		defer cancel()
		notifyChan := cn.CloseNotify()
		go func() {
			select {
			case <-notifyChan:
				cancel()
			case <-ctx.Done():
				// ServeHTTP returned (deferred cancel); the watcher exits.
			}
		}()
	}

	outreq := req.Clone(ctx)
	if req.ContentLength == 0 {
		// Nothing to forward: send a nil Body rather than the empty one.
		outreq.Body = nil
	}
	if outreq.Body != nil {
		// Close the forwarded body when the handler returns.
		// NOTE(review): presumably because the transport may still hold
		// the body after this handler returns — confirm against the
		// upstream issue history.
		defer outreq.Body.Close()
	}
	if outreq.Header == nil {
		outreq.Header = make(http.Header) // Director and the code below write to it
	}

	p.Director(outreq)
	if outreq.Form != nil {
		// Form is non-nil here when it was parsed (presumably by the
		// Director). Re-encode the query when it contains ';' or malformed
		// %-escapes so the backend sees the parameters as url.ParseQuery
		// parsed them.
		outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
	}
	// Connection management toward the backend is left to the transport;
	// the client's request to close its own connection is not forwarded.
	outreq.Close = false

	reqUpType := upgradeType(outreq.Header)
	if !ascii.IsPrint(reqUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
		return
	}
	removeConnectionHeaders(outreq.Header)

	// Drop the remaining hop-by-hop headers; they describe the client's
	// connection, not the one to the backend.
	for _, h := range hopHeaders {
		outreq.Header.Del(h)
	}

	// "Te" was deleted above. If the client advertised "trailers", send
	// exactly "Te: trailers" so the backend may still respond with trailers.
	if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
		outreq.Header.Set("Te", "trailers")
	}

	// Connection and Upgrade were also stripped; restore them for a
	// (validated) protocol-upgrade request.
	if reqUpType != "" {
		outreq.Header.Set("Connection", "Upgrade")
		outreq.Header.Set("Upgrade", reqUpType)
	}

	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		// Append the client IP to any X-Forwarded-For values already on the
		// outgoing request. A key that is present with a nil value (e.g. set
		// that way by the Director) suppresses the header entirely.
		prior, ok := outreq.Header["X-Forwarded-For"]
		omit := ok && prior == nil
		if len(prior) > 0 {
			clientIP = strings.Join(prior, ", ") + ", " + clientIP
		}
		if !omit {
			outreq.Header.Set("X-Forwarded-For", clientIP)
		}
	}

	res, err := transport.RoundTrip(outreq)
	if err != nil {
		p.getErrorHandler()(rw, outreq, err)
		return
	}

	// Protocol switch: handleUpgradeResponse takes over the connection.
	// ModifyResponse sees the 101 response before any header stripping.
	if res.StatusCode == http.StatusSwitchingProtocols {
		if !p.modifyResponse(rw, res, outreq) {
			return
		}
		p.handleUpgradeResponse(rw, outreq, res)
		return
	}

	removeConnectionHeaders(res.Header)

	for _, h := range hopHeaders {
		res.Header.Del(h)
	}

	if !p.modifyResponse(rw, res, outreq) {
		return
	}

	copyHeader(rw.Header(), res.Header)

	// Announce the trailer keys known so far in a "Trailer" header before
	// WriteHeader. The count is remembered so that trailers which only show
	// up after the body has been read can be detected below.
	announcedTrailers := len(res.Trailer)
	if announcedTrailers > 0 {
		trailerKeys := make([]string, 0, len(res.Trailer))
		for k := range res.Trailer {
			trailerKeys = append(trailerKeys, k)
		}
		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
	}

	rw.WriteHeader(res.StatusCode)

	err = p.copyResponse(rw, res.Body, p.flushInterval(res))
	if err != nil {
		defer res.Body.Close()
		// The status has already been sent, so the error cannot be turned
		// into an HTTP error response. Abort the handler by panicking with
		// http.ErrAbortHandler when shouldPanicOnCopyError allows it;
		// otherwise just log.
		if !shouldPanicOnCopyError(req) {
			p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
			return
		}
		panic(http.ErrAbortHandler)
	}
	// copyResponse succeeded, which means the body was read through io.EOF.
	res.Body.Close()

	if len(res.Trailer) > 0 {
		// Flush the body to the client before trailers are written.
		if fl, ok := rw.(http.Flusher); ok {
			fl.Flush()
		}
	}

	// Same number of trailers as announced: copy them under their own names,
	// which were listed in the "Trailer" header above.
	if len(res.Trailer) == announcedTrailers {
		copyHeader(rw.Header(), res.Trailer)
		return
	}

	// Trailers were added after WriteHeader; emit them with
	// http.TrailerPrefix so net/http sends them as trailers.
	for k, vv := range res.Trailer {
		k = http.TrailerPrefix + k
		for _, v := range vv {
			rw.Header().Add(k, v)
		}
	}
}
388
// inOurTests forces shouldPanicOnCopyError to report true. Nothing in this
// block assigns it; presumably the package's own tests set it.
var inOurTests bool

// shouldPanicOnCopyError reports whether ServeHTTP should abort with
// panic(http.ErrAbortHandler) after failing to copy a response body.
//
// It is true when inOurTests is set, or when req's context carries
// http.ServerContextKey, i.e. the request is being served by an http.Server
// (which recovers handler panics; see http.ErrAbortHandler). Otherwise, for
// example when ServeHTTP is called directly, it is false and the caller only
// logs the copy error.
func shouldPanicOnCopyError(req *http.Request) bool {
	return inOurTests || req.Context().Value(http.ServerContextKey) != nil
}
410
411
412
// removeConnectionHeaders deletes from h every header named by a
// comma-separated token of the Connection header (RFC 7230, section 6.1).
// Tokens are trimmed with textproto.TrimString and empty tokens are skipped.
// The Connection header itself is kept unless it lists its own name.
func removeConnectionHeaders(h http.Header) {
	for _, value := range h["Connection"] {
		rest := value
		for {
			token, tail, more := strings.Cut(rest, ",")
			if name := textproto.TrimString(token); name != "" {
				h.Del(name)
			}
			if !more {
				break
			}
			rest = tail
		}
	}
}
422
423
424
425 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
426 resCT := res.Header.Get("Content-Type")
427
428
429
430 if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
431 return -1
432 }
433
434
435 if res.ContentLength == -1 {
436 return -1
437 }
438
439 return p.FlushInterval
440 }
441
442 func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
443 if flushInterval != 0 {
444 if wf, ok := dst.(writeFlusher); ok {
445 mlw := &maxLatencyWriter{
446 dst: wf,
447 latency: flushInterval,
448 }
449 defer mlw.stop()
450
451
452 mlw.flushPending = true
453 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
454
455 dst = mlw
456 }
457 }
458
459 var buf []byte
460 if p.BufferPool != nil {
461 buf = p.BufferPool.Get()
462 defer p.BufferPool.Put(buf)
463 }
464 _, err := p.copyBuffer(dst, src, buf)
465 return err
466 }
467
468
469
470 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
471 if len(buf) == 0 {
472 buf = make([]byte, 32*1024)
473 }
474 var written int64
475 for {
476 nr, rerr := src.Read(buf)
477 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
478 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
479 }
480 if nr > 0 {
481 nw, werr := dst.Write(buf[:nr])
482 if nw > 0 {
483 written += int64(nw)
484 }
485 if werr != nil {
486 return written, werr
487 }
488 if nr != nw {
489 return written, io.ErrShortWrite
490 }
491 }
492 if rerr != nil {
493 if rerr == io.EOF {
494 rerr = nil
495 }
496 return written, rerr
497 }
498 }
499 }
500
501 func (p *ReverseProxy) logf(format string, args ...any) {
502 if p.ErrorLog != nil {
503 p.ErrorLog.Printf(format, args...)
504 } else {
505 log.Printf(format, args...)
506 }
507 }
508
// writeFlusher is implemented by destinations that accept body bytes and can
// push buffered output out on demand, such as an http.ResponseWriter that
// also implements http.Flusher.
type writeFlusher interface {
	http.Flusher
	io.Writer
}

// maxLatencyWriter wraps a writeFlusher so that written data is flushed at
// most latency after the first unflushed write. A negative latency flushes
// after every Write. mu serializes Write, delayedFlush and stop.
type maxLatencyWriter struct {
	dst     writeFlusher
	latency time.Duration // negative: flush synchronously after each Write

	mu           sync.Mutex
	t            *time.Timer // created lazily, then re-armed with Reset
	flushPending bool        // a timer-driven flush is scheduled
}

// Write forwards p to the destination, then either flushes immediately
// (negative latency) or makes sure a delayed flush is scheduled.
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	n, err = m.dst.Write(p)
	switch {
	case m.latency < 0:
		m.dst.Flush()
	case m.flushPending:
		// A flush is already on its way; keep the earlier deadline.
	default:
		if m.t != nil {
			m.t.Reset(m.latency)
		} else {
			m.t = time.AfterFunc(m.latency, m.delayedFlush)
		}
		m.flushPending = true
	}
	return n, err
}

// delayedFlush runs from the timer. It flushes only if a flush is still
// pending; stop clears the flag when the timer has already started this call.
func (m *maxLatencyWriter) delayedFlush() {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.flushPending {
		m.dst.Flush()
		m.flushPending = false
	}
}

// stop cancels any pending delayed flush and halts the timer.
func (m *maxLatencyWriter) stop() {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.flushPending = false
	if timer := m.t; timer != nil {
		timer.Stop()
	}
}
561
562 func upgradeType(h http.Header) string {
563 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
564 return ""
565 }
566 return h.Get("Upgrade")
567 }
568
569 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
570 reqUpType := upgradeType(req.Header)
571 resUpType := upgradeType(res.Header)
572 if !ascii.IsPrint(resUpType) {
573 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
574 }
575 if !ascii.EqualFold(reqUpType, resUpType) {
576 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
577 return
578 }
579
580 hj, ok := rw.(http.Hijacker)
581 if !ok {
582 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
583 return
584 }
585 backConn, ok := res.Body.(io.ReadWriteCloser)
586 if !ok {
587 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
588 return
589 }
590
591 backConnCloseCh := make(chan bool)
592 go func() {
593
594
595 select {
596 case <-req.Context().Done():
597 case <-backConnCloseCh:
598 }
599 backConn.Close()
600 }()
601
602 defer close(backConnCloseCh)
603
604 conn, brw, err := hj.Hijack()
605 if err != nil {
606 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
607 return
608 }
609 defer conn.Close()
610
611 copyHeader(rw.Header(), res.Header)
612
613 res.Header = rw.Header()
614 res.Body = nil
615 if err := res.Write(brw); err != nil {
616 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
617 return
618 }
619 if err := brw.Flush(); err != nil {
620 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
621 return
622 }
623 errc := make(chan error, 1)
624 spc := switchProtocolCopier{user: conn, backend: backConn}
625 go spc.copyToBackend(errc)
626 go spc.copyFromBackend(errc)
627 <-errc
628 }
629
630
631
// switchProtocolCopier copies bytes between the two ends of an upgraded
// connection. user is the hijacked client connection and backend is the
// backend connection.
type switchProtocolCopier struct {
	user, backend io.ReadWriter
}

// copyFromBackend copies backend to user until EOF or an error, then sends
// the io.Copy error (nil on a clean EOF) on errc.
func (s switchProtocolCopier) copyFromBackend(errc chan<- error) {
	dst, src := s.user, s.backend
	_, err := io.Copy(dst, src)
	errc <- err
}

// copyToBackend copies user to backend until EOF or an error, then sends
// the io.Copy error (nil on a clean EOF) on errc.
func (s switchProtocolCopier) copyToBackend(errc chan<- error) {
	dst, src := s.backend, s.user
	_, err := io.Copy(dst, src)
	errc <- err
}
645
// cleanQueryParams returns s unchanged when it is a well-formed raw query.
// If s contains a ';' or a '%' that is not followed by two hex digits, it is
// re-parsed with url.ParseQuery and re-encoded instead. ParseQuery drops
// pairs it cannot parse, and Encode sorts by key.
func cleanQueryParams(s string) string {
	normalize := func(q string) string {
		// The error is deliberately ignored: ParseQuery still returns every
		// pair it could parse, and dropping the malformed ones is the point.
		values, _ := url.ParseQuery(q)
		return values.Encode()
	}
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c == ';' {
			return normalize(s)
		}
		if c != '%' {
			continue
		}
		if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
			return normalize(s)
		}
		i += 2 // skip the two hex digits of a well-formed escape
	}
	return s
}

// ishex reports whether c is an ASCII hexadecimal digit (0-9, a-f, A-F).
func ishex(c byte) bool {
	return '0' <= c && c <= '9' ||
		'a' <= c && c <= 'f' ||
		'A' <= c && c <= 'F'
}
678
View as plain text