1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "context"
12 "crypto/cipher"
13 "crypto/subtle"
14 "crypto/x509"
15 "errors"
16 "fmt"
17 "hash"
18 "io"
19 "net"
20 "sync"
21 "sync/atomic"
22 "time"
23 )
24
25
26
// A Conn represents a secured (TLS) connection layered over an underlying
// net.Conn. It provides Read, Write, Close, LocalAddr, RemoteAddr and the
// deadline setters, i.e. the method set of the net.Conn interface.
type Conn struct {
	// conn is the underlying transport that records are read from and
	// written to.
	conn net.Conn
	// isClient is true for the client side of the connection.
	isClient bool
	// handshakeFn runs the handshake; it is invoked by handshakeContext.
	// NOTE(review): presumably the client or server handshake routine —
	// confirm where it is assigned.
	handshakeFn func(context.Context) error

	// handshakeStatus is 1 once a handshake has completed successfully (see
	// handshakeComplete) and is reset to 0 while renegotiating. It is only
	// accessed with sync/atomic so it can be checked without taking locks.
	handshakeStatus uint32
	// handshakeMutex serializes handshakes and guards the handshake results
	// read by ConnectionState, OCSPResponse and VerifyHostname.
	handshakeMutex sync.Mutex
	handshakeErr   error   // sticky error from the last handshake, if any
	vers           uint16  // TLS version in use for the connection
	haveVers       bool    // whether vers has been set yet
	config         *Config // configuration for this connection

	// handshakes counts successful handshakes, including renegotiations.
	handshakes       int
	didResume        bool   // reported as ConnectionState.DidResume
	cipherSuite      uint16 // negotiated cipher suite ID
	ocspResponse     []byte // OCSP response recorded during the handshake
	scts             [][]byte
	peerCertificates []*x509.Certificate
	// verifiedChains holds the verified certificate chains; VerifyHostname
	// refuses to run when it is empty.
	verifiedChains [][]*x509.Certificate
	// serverName is reported as ConnectionState.ServerName.
	serverName string

	// secureRenegotiation — NOTE(review): not used in this chunk; presumably
	// records whether secure renegotiation was negotiated. Confirm.
	secureRenegotiation bool
	// ekm exports keying material for ConnectionState (unless
	// renegotiation is enabled; see connectionStateLocked).
	ekm func(label string, context []byte, length int) ([]byte, error)
	// resumptionSecret — NOTE(review): not used in this chunk; presumably
	// the TLS 1.3 resumption secret. Confirm.
	resumptionSecret []byte

	// ticketKeys — NOTE(review): not used in this chunk; presumably the
	// session ticket keys in use. Confirm.
	ticketKeys []ticketKey

	// clientFinishedIsFirst selects which Finished message is exposed as
	// ConnectionState.TLSUnique.
	clientFinishedIsFirst bool

	// closeNotifyErr is the result of the one close_notify attempt; it is
	// written by closeNotify under c.out's lock.
	closeNotifyErr error
	// closeNotifySent is true once closeNotify has attempted to send
	// close_notify. After that, Write returns errShutdown. It is protected by
	// c.out's lock.
	closeNotifySent bool

	// clientFinished and serverFinished hold the Finished message payloads;
	// one of them is exposed as TLSUnique.
	clientFinished [12]byte
	serverFinished [12]byte

	// clientProtocol is reported as ConnectionState.NegotiatedProtocol.
	clientProtocol string

	// input/output
	in, out  halfConn     // read and write record-layer state
	rawInput bytes.Buffer // raw bytes read from conn, not yet processed
	input    bytes.Reader // application data waiting to be returned by Read
	hand     bytes.Buffer // handshake bytes waiting to be parsed
	// buffering makes write append to sendBuf instead of writing to conn,
	// until flush is called.
	buffering bool
	sendBuf   []byte // data buffered while buffering is set

	// bytesSent counts bytes written to conn. packetsSent counts
	// application data records sized by maxPayloadSizeForWrite. Both drive
	// dynamic record sizing.
	bytesSent   int64
	packetsSent int64

	// retryCount counts consecutive records that did not advance the
	// connection (empty records, warning alerts, ignored TLS 1.3 CCS,
	// post-handshake messages). It is bounded by maxUselessRecords.
	retryCount int

	// activeCall's low bit is set once Close has been called. Each Write in
	// progress adds 2 for its duration.
	activeCall int32

	// tmp is scratch space; sendAlertLocked uses its first two bytes.
	tmp [16]byte
}
120
121
122
123
124
125
126 func (c *Conn) LocalAddr() net.Addr {
127 return c.conn.LocalAddr()
128 }
129
130
131 func (c *Conn) RemoteAddr() net.Addr {
132 return c.conn.RemoteAddr()
133 }
134
135
136
137
138 func (c *Conn) SetDeadline(t time.Time) error {
139 return c.conn.SetDeadline(t)
140 }
141
142
143
144 func (c *Conn) SetReadDeadline(t time.Time) error {
145 return c.conn.SetReadDeadline(t)
146 }
147
148
149
150
151 func (c *Conn) SetWriteDeadline(t time.Time) error {
152 return c.conn.SetWriteDeadline(t)
153 }
154
155
156
157
158 func (c *Conn) NetConn() net.Conn {
159 return c.conn
160 }
161
162
163
// A halfConn represents one direction of the record layer: c.in for reading
// and c.out for writing.
type halfConn struct {
	// The embedded mutex guards the fields below.
	sync.Mutex

	err     error     // first permanent error (sticky; see setErrorLocked)
	version uint16    // protocol version
	cipher  any       // nil, cipher.Stream, aead or cbcMode (see explicitNonceLen)
	mac     hash.Hash // MAC for non-AEAD ciphers, nil otherwise
	seq     [8]byte   // 64-bit big-endian record sequence number

	// scratchBuf holds MAC/AEAD additional data: 8 bytes of seq plus the
	// 5-byte record header.
	scratchBuf [13]byte

	nextCipher any       // next encryption state, activated by changeCipherSpec
	nextMac    hash.Hash // next MAC algorithm, activated by changeCipherSpec

	// trafficSecret is the current TLS 1.3 traffic secret, used to derive
	// the next one on KeyUpdate.
	trafficSecret []byte
}
180
// permanentError wraps a net.Error so that it never reports itself as
// temporary. It still exposes the wrapped error's message, its timeout
// status, and (via Unwrap) its identity.
type permanentError struct {
	err net.Error
}

// Error returns the wrapped error's message unchanged.
func (e *permanentError) Error() string {
	inner := e.err
	return inner.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *permanentError) Unwrap() error {
	inner := e.err
	return inner
}

// Timeout reports whether the wrapped error was a timeout.
func (e *permanentError) Timeout() bool {
	inner := e.err
	return inner.Timeout()
}

// Temporary always reports false: once recorded, the error is sticky.
func (e *permanentError) Temporary() bool {
	return false
}
189
190 func (hc *halfConn) setErrorLocked(err error) error {
191 if e, ok := err.(net.Error); ok {
192 hc.err = &permanentError{err: e}
193 } else {
194 hc.err = err
195 }
196 return hc.err
197 }
198
199
200
201 func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
202 hc.version = version
203 hc.nextCipher = cipher
204 hc.nextMac = mac
205 }
206
207
208
209 func (hc *halfConn) changeCipherSpec() error {
210 if hc.nextCipher == nil || hc.version == VersionTLS13 {
211 return alertInternalError
212 }
213 hc.cipher = hc.nextCipher
214 hc.mac = hc.nextMac
215 hc.nextCipher = nil
216 hc.nextMac = nil
217 for i := range hc.seq {
218 hc.seq[i] = 0
219 }
220 return nil
221 }
222
223 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
224 hc.trafficSecret = secret
225 key, iv := suite.trafficKey(secret)
226 hc.cipher = suite.aead(key, iv)
227 for i := range hc.seq {
228 hc.seq[i] = 0
229 }
230 }
231
232
233 func (hc *halfConn) incSeq() {
234 for i := 7; i >= 0; i-- {
235 hc.seq[i]++
236 if hc.seq[i] != 0 {
237 return
238 }
239 }
240
241
242
243
244 panic("TLS: sequence number wraparound")
245 }
246
247
248
249
250 func (hc *halfConn) explicitNonceLen() int {
251 if hc.cipher == nil {
252 return 0
253 }
254
255 switch c := hc.cipher.(type) {
256 case cipher.Stream:
257 return 0
258 case aead:
259 return c.explicitNonceLen()
260 case cbcMode:
261
262 if hc.version >= VersionTLS11 {
263 return c.BlockSize()
264 }
265 return 0
266 default:
267 panic("unknown cipher type")
268 }
269 }
270
271
272
273
// extractPadding inspects the CBC padding at the end of payload without
// data-dependent branches on the padding contents. It returns the number of
// trailing bytes to remove (padding plus the length byte) and good, which is
// 255 if the padding is well-formed and 0 otherwise. The length byte
// (payload[len-1]) gives the padding length, and every padding byte must
// equal it. On bad padding toRemove is 1, so that the remaining bytes stay
// covered by the caller's MAC check; callers must combine good with the MAC
// result (see decrypt).
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// If paddingLen <= len(payload)-1 then t is small and bit 31 of ^t is
	// set, so good becomes 0xff. Otherwise t wrapped and good becomes 0.
	good = byte(int32(^t) >> 31)

	// Check a fixed window: the maximum padding length (255) plus the length
	// byte itself, so the loop count does not depend on paddingLen.
	toCheck := 256
	// The length of the padded data is public, so a branch is fine here.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff if i <= paddingLen (byte is part of the padding), else 0.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear any bits of good where a padding byte differs from paddingLen.
		good &^= mask&paddingLen ^ mask&b
	}

	// AND all bits of good together into the top bit, then replicate that
	// bit across the byte: good ends up exactly 0xff or 0.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// On error, zero the padding length so that only the length byte is
	// removed and all other bytes are included in the MAC computation. An
	// attacker then cannot distinguish a padding failure from a MAC failure.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
320
// roundUp rounds a up to the next multiple of b. It is intended for
// non-negative a and positive b.
func roundUp(a, b int) int {
	shortfall := (b - a%b) % b
	return a + shortfall
}
324
325
// cbcMode is the interface for block ciphers used in cipher block chaining
// mode. It extends cipher.BlockMode with SetIV, which encrypt and decrypt
// use to install the explicit per-record IV.
type cbcMode interface {
	cipher.BlockMode
	SetIV([]byte)
}
330
331
332
// decrypt authenticates and decrypts record (header included) if protection
// is active, and returns the plaintext and its record type. Decryption is
// done in place, so the returned plaintext may overlap with record. For
// TLS 1.3 the returned type is the inner content type. Authentication
// failures are reported as alertBadRecordMAC. The sequence number is
// incremented on success.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
	var plaintext []byte
	typ := recordType(record[0])
	payload := record[recordHeaderLen:]

	// In TLS 1.3, change_cipher_spec records are passed through without
	// being decrypted (and without advancing the sequence number).
	if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
		return payload, typ, nil
	}

	// Defaults for the non-CBC paths: padding is "good" (255) and empty, so
	// the combined MAC/padding check below reduces to the MAC check.
	paddingGood := byte(255)
	paddingLen := 0

	explicitNonceLen := hc.explicitNonceLen()

	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			// Decrypt in place; the MAC is verified below.
			c.XORKeyStream(payload, payload)
		case aead:
			if len(payload) < explicitNonceLen {
				return nil, 0, alertBadRecordMAC
			}
			// Use the explicit nonce carried in the record, if any, or
			// else the implicit sequence number.
			nonce := payload[:explicitNonceLen]
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}
			payload = payload[explicitNonceLen:]

			// Additional data: the record header in TLS 1.3. Before that it
			// is seq || type || version || plaintext length (13 bytes, built
			// in scratchBuf to avoid allocating).
			var additionalData []byte
			if hc.version == VersionTLS13 {
				additionalData = record[:recordHeaderLen]
			} else {
				additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
				additionalData = append(additionalData, record[:3]...)
				n := len(payload) - c.Overhead()
				additionalData = append(additionalData, byte(n>>8), byte(n))
			}

			var err error
			// Open in place, reusing payload's storage for the plaintext.
			plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
			if err != nil {
				return nil, 0, alertBadRecordMAC
			}
		case cbcMode:
			blockSize := c.BlockSize()
			// The ciphertext must be whole blocks and long enough for the
			// explicit IV, the MAC and the padding-length byte.
			minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
			if len(payload)%blockSize != 0 || len(payload) < minPayload {
				return nil, 0, alertBadRecordMAC
			}

			if explicitNonceLen > 0 {
				c.SetIV(payload[:explicitNonceLen])
				payload = payload[explicitNonceLen:]
			}
			c.CryptBlocks(payload, payload)

			// extractPadding checks the padding without branching on its
			// contents. Its verdict is only acted on together with the MAC
			// check below, so padding and MAC failures look identical to the
			// peer. NOTE(review): the bytes after the MAC are also passed to
			// tls10MAC below, presumably so that the MAC's running time does
			// not depend on the secret padding length (a Lucky13-style
			// mitigation); confirm in tls10MAC.
			paddingLen, paddingGood = extractPadding(payload)
		default:
			panic("unknown cipher type")
		}

		if hc.version == VersionTLS13 {
			// Protected TLS 1.3 records always use an outer type of
			// application_data.
			if typ != recordTypeApplicationData {
				return nil, 0, alertUnexpectedMessage
			}
			// The inner plaintext may be one byte (the content type) longer
			// than maxPlaintext.
			if len(plaintext) > maxPlaintext+1 {
				return nil, 0, alertRecordOverflow
			}
			// Strip trailing zero padding. The last non-zero byte is the real
			// content type, and an all-zero plaintext is rejected.
			for i := len(plaintext) - 1; i >= 0; i-- {
				if plaintext[i] != 0 {
					typ = recordType(plaintext[i])
					plaintext = plaintext[:i]
					break
				}
				if i == 0 {
					return nil, 0, alertUnexpectedMessage
				}
			}
		}
	} else {
		plaintext = payload
	}

	if hc.mac != nil {
		macSize := hc.mac.Size()
		if len(payload) < macSize {
			return nil, 0, alertBadRecordMAC
		}

		// n is the plaintext length. With bogus padding it can go negative.
		// Clamp it to 0 without a data-dependent branch so the slicing below
		// cannot panic; the padding failure is reported via paddingGood.
		n := len(payload) - macSize - paddingLen
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
		// The MAC covers a header carrying the plaintext length, so rewrite
		// the length field in place before computing it.
		record[3] = byte(n >> 8)
		record[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// Fold the padding verdict into the constant-time MAC comparison so
		// that a single decision covers both. ConstantTimeCompare returns 1
		// or 0, and paddingGood is 255 or 0, so the result is 1 only if both
		// passed.
		macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
		if macAndPaddingGood != 1 {
			return nil, 0, alertBadRecordMAC
		}

		plaintext = payload[:n]
	}

	hc.incSeq()
	return plaintext, typ, nil
}
456
457
458
459
// sliceForAppend extends in by n bytes. head is the extended slice and tail
// is its final n bytes. If in has enough capacity, head shares in's backing
// array. Otherwise a new array is allocated and in's contents are copied
// into it.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	total := len(in) + n
	if cap(in) < total {
		head = make([]byte, total)
		copy(head, in)
	} else {
		head = in[:total]
	}
	return head, head[len(in):]
}
470
471
472
// encrypt protects payload and appends it to record, which must already
// contain the 5-byte record header. It adds the explicit nonce/IV, MAC and
// padding that the active cipher requires. With an active cipher it rewrites
// the header's length field to the final length and increments the sequence
// number. With no cipher, payload is appended in plaintext and nothing else
// changes. rand supplies random explicit IVs/nonces where needed.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// Short explicit nonces of non-CBC (AEAD) ciphers are filled with
			// the sequence number, which is unique per record. NOTE(review):
			// presumably because a random nonce of this size would risk
			// collisions; confirm.
			copy(explicitNonce, hc.seq[:])
		} else {
			// CBC IVs must be unpredictable, so they (and nonces of 16 bytes
			// or more) are read from rand.
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		// MAC over the header and plaintext, then encrypt payload || MAC.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// Encrypt the actual content type and replace the plaintext one
			// with application_data.
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header is the additional data, so its length field must
			// already hold the final ciphertext length before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			// Seal in place over the inner plaintext.
			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			// Additional data: seq || record header (13 bytes in scratchBuf).
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		// Layout: payload || MAC || padding. paddingLen is in [1, blockSize],
		// so the total is a whole number of blocks. Every padding byte,
		// including the final length byte, equals paddingLen-1.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Update the length field to include the nonce, MAC and any padding.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
557
558
// RecordHeaderError is returned when a TLS record header is invalid.
type RecordHeaderError struct {
	// Msg describes the problem in human-readable form.
	Msg string
	// RecordHeader holds the five bytes of the record header that
	// triggered the error.
	RecordHeader [5]byte
	// Conn is the underlying net.Conn when the peer's first record did not
	// look like a TLS handshake at all. It is nil in every other case,
	// including those where a TLS alert was sent.
	Conn net.Conn
}

// Error implements the error interface by prefixing Msg with "tls: ".
func (e RecordHeaderError) Error() string {
	const prefix = "tls: "
	return prefix + e.Msg
}
573
574 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
575 err.Msg = msg
576 err.Conn = conn
577 copy(err.RecordHeader[:], c.rawInput.Bytes())
578 return err
579 }
580
581 func (c *Conn) readRecord() error {
582 return c.readRecordOrCCS(false)
583 }
584
585 func (c *Conn) readChangeCipherSpec() error {
586 return c.readRecordOrCCS(true)
587 }
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
// readRecordOrCCS reads one or more TLS records from the connection and
// updates the record layer state. Invariants on entry:
//   - c.in must be locked
//   - c.input must be empty
//
// During the handshake exactly one of the following happens:
//   - c.hand grows
//   - c.in.changeCipherSpec is called
//   - an error is returned
//
// After the handshake exactly one of the following happens:
//   - c.hand grows
//   - c.input is set
//   - an error is returned
//
// Records that carry nothing useful (warning alerts, empty application data,
// TLS 1.3 compatibility CCS) are skipped via retryReadRecord, which bounds
// how many may arrive in a row.
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.handshakeComplete()

	// This function modifies c.rawInput, which owns the c.input memory.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	// Read header, payload.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// EOF exactly at a record boundary is reported as a clean io.EOF
		// rather than io.ErrUnexpectedEOF, even without close_notify.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		// Temporary network errors are returned without becoming sticky.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// No valid TLS record has a type of 0x80. SSLv2 handshakes, however,
	// start with a 2-byte length whose MSB is set, and their first record is
	// short, so typ == 0x80 strongly suggests an SSLv2 client.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	n := int(hdr[3])<<8 | int(hdr[4])
	// Once the version is known, every record must carry it (TLS 1.3 is
	// exempt because its records use a legacy version number).
	if c.haveVers && c.vers != VersionTLS13 && vers != c.vers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First record: be extra suspicious, since the peer may not be a
		// TLS client at all. Bail out before reading a full body, and hand
		// back the raw conn (in RecordHeaderError.Conn) with no alert sent.
		// A version of 0x1000 or above is not a plausible TLS version.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	// Parsed as (TLS 1.3 && n > maxCiphertextTLS13) || n > maxCiphertext.
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Process message.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application Data messages are always protected.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// This is a successful read of meaningful data: reset the counter
		// of consecutive useless records.
		c.retryCount = 0
	}

	// Handshake messages MUST NOT be interleaved with other record types in
	// TLS 1.3.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		// In TLS 1.3 every alert other than close_notify is fatal,
		// regardless of its level byte.
		if c.vers == VersionTLS13 {
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Drop the record on the floor and retry.
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// Handshake messages are not allowed to fragment across the CCS.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// In TLS 1.3, change_cipher_spec records carry no key change; they
		// are ignored (counting toward maxUselessRecords).
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Skip empty application data records, up to maxUselessRecords in
		// a row (see retryReadRecord).
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// data aliases c.rawInput's storage (after the Next call above) to
		// avoid copying the plaintext. This is safe because c.rawInput is
		// not read into again until c.input is drained (checked on entry).
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
763
764
765
766 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
767 c.retryCount++
768 if c.retryCount > maxUselessRecords {
769 c.sendAlert(alertUnexpectedMessage)
770 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
771 }
772 return c.readRecordOrCCS(expectChangeCipherSpec)
773 }
774
775
776
777
// atLeastReader reads from R and stops with io.EOF once at least N bytes
// have been read, or earlier if R returns an error. If R hits EOF before N
// bytes arrive, io.ErrUnexpectedEOF is returned instead. It lets
// bytes.Buffer.ReadFrom read a bounded amount without an extra Read call.
type atLeastReader struct {
	R io.Reader
	N int64
}

// Read implements io.Reader, decrementing N by the number of bytes read.
func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}

	got, readErr := r.R.Read(p)
	r.N -= int64(got)

	if r.N > 0 {
		// Still short of the target: EOF now means truncated input.
		if readErr == io.EOF {
			return got, io.ErrUnexpectedEOF
		}
		return got, readErr
	}
	// Target reached: signal EOF to stop the caller, unless R already
	// reported a (more informative) error.
	if readErr == nil {
		return got, io.EOF
	}
	return got, readErr
}
797
798
799
800 func (c *Conn) readFromUntil(r io.Reader, n int) error {
801 if c.rawInput.Len() >= n {
802 return nil
803 }
804 needs := n - c.rawInput.Len()
805
806
807
808 c.rawInput.Grow(needs + bytes.MinRead)
809 _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
810 return err
811 }
812
813
814 func (c *Conn) sendAlertLocked(err alert) error {
815 switch err {
816 case alertNoRenegotiation, alertCloseNotify:
817 c.tmp[0] = alertLevelWarning
818 default:
819 c.tmp[0] = alertLevelError
820 }
821 c.tmp[1] = byte(err)
822
823 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
824 if err == alertCloseNotify {
825
826 return writeErr
827 }
828
829 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
830 }
831
832
833 func (c *Conn) sendAlert(err alert) error {
834 c.out.Lock()
835 defer c.out.Unlock()
836 return c.sendAlertLocked(err)
837 }
838
const (
	// tcpMSSEstimate is a conservative estimate of the TCP maximum segment
	// size. maxPayloadSizeForWrite sizes early application data records so
	// that the whole record (header, nonce and cipher overhead included)
	// fits in this many bytes. NOTE(review): presumably chosen to fit common
	// link MTUs after IP/TCP headers and options; confirm the derivation.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes written to the
	// underlying connection (c.bytesSent) after which dynamic record sizing
	// stops and full maxPlaintext records are used.
	recordSizeBoostThreshold = 128 * 1024
)
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869 func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
870 if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
871 return maxPlaintext
872 }
873
874 if c.bytesSent >= recordSizeBoostThreshold {
875 return maxPlaintext
876 }
877
878
879 payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
880 if c.out.cipher != nil {
881 switch ciph := c.out.cipher.(type) {
882 case cipher.Stream:
883 payloadBytes -= c.out.mac.Size()
884 case cipher.AEAD:
885 payloadBytes -= ciph.Overhead()
886 case cbcMode:
887 blockSize := ciph.BlockSize()
888
889
890 payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
891
892
893 payloadBytes -= c.out.mac.Size()
894 default:
895 panic("unknown cipher type")
896 }
897 }
898 if c.vers == VersionTLS13 {
899 payloadBytes--
900 }
901
902
903 pkt := c.packetsSent
904 c.packetsSent++
905 if pkt > 1000 {
906 return maxPlaintext
907 }
908
909 n := payloadBytes * int(pkt+1)
910 if n > maxPlaintext {
911 n = maxPlaintext
912 }
913 return n
914 }
915
916 func (c *Conn) write(data []byte) (int, error) {
917 if c.buffering {
918 c.sendBuf = append(c.sendBuf, data...)
919 return len(data), nil
920 }
921
922 n, err := c.conn.Write(data)
923 c.bytesSent += int64(n)
924 return n, err
925 }
926
927 func (c *Conn) flush() (int, error) {
928 if len(c.sendBuf) == 0 {
929 return 0, nil
930 }
931
932 n, err := c.conn.Write(c.sendBuf)
933 c.bytesSent += int64(n)
934 c.sendBuf = nil
935 c.buffering = false
936 return n, err
937 }
938
939
// outBufPool pools the scratch buffers in which writeRecordLocked assembles
// outgoing records. It stores *[]byte rather than []byte because a pointer
// can be boxed into the pool's interface without an extra allocation.
var outBufPool = sync.Pool{
	New: func() any {
		var buf []byte
		return &buf
	},
}
945
946
947
948 func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
949 outBufPtr := outBufPool.Get().(*[]byte)
950 outBuf := *outBufPtr
951 defer func() {
952
953
954
955
956
957 *outBufPtr = outBuf
958 outBufPool.Put(outBufPtr)
959 }()
960
961 var n int
962 for len(data) > 0 {
963 m := len(data)
964 if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
965 m = maxPayload
966 }
967
968 _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
969 outBuf[0] = byte(typ)
970 vers := c.vers
971 if vers == 0 {
972
973
974 vers = VersionTLS10
975 } else if vers == VersionTLS13 {
976
977
978 vers = VersionTLS12
979 }
980 outBuf[1] = byte(vers >> 8)
981 outBuf[2] = byte(vers)
982 outBuf[3] = byte(m >> 8)
983 outBuf[4] = byte(m)
984
985 var err error
986 outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
987 if err != nil {
988 return n, err
989 }
990 if _, err := c.write(outBuf); err != nil {
991 return n, err
992 }
993 n += m
994 data = data[m:]
995 }
996
997 if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
998 if err := c.out.changeCipherSpec(); err != nil {
999 return n, c.sendAlertLocked(err.(alert))
1000 }
1001 }
1002
1003 return n, nil
1004 }
1005
1006
1007
1008 func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
1009 c.out.Lock()
1010 defer c.out.Unlock()
1011
1012 return c.writeRecordLocked(typ, data)
1013 }
1014
1015
1016
1017 func (c *Conn) readHandshake() (any, error) {
1018 for c.hand.Len() < 4 {
1019 if err := c.readRecord(); err != nil {
1020 return nil, err
1021 }
1022 }
1023
1024 data := c.hand.Bytes()
1025 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1026 if n > maxHandshake {
1027 c.sendAlertLocked(alertInternalError)
1028 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
1029 }
1030 for c.hand.Len() < 4+n {
1031 if err := c.readRecord(); err != nil {
1032 return nil, err
1033 }
1034 }
1035 data = c.hand.Next(4 + n)
1036 var m handshakeMessage
1037 switch data[0] {
1038 case typeHelloRequest:
1039 m = new(helloRequestMsg)
1040 case typeClientHello:
1041 m = new(clientHelloMsg)
1042 case typeServerHello:
1043 m = new(serverHelloMsg)
1044 case typeNewSessionTicket:
1045 if c.vers == VersionTLS13 {
1046 m = new(newSessionTicketMsgTLS13)
1047 } else {
1048 m = new(newSessionTicketMsg)
1049 }
1050 case typeCertificate:
1051 if c.vers == VersionTLS13 {
1052 m = new(certificateMsgTLS13)
1053 } else {
1054 m = new(certificateMsg)
1055 }
1056 case typeCertificateRequest:
1057 if c.vers == VersionTLS13 {
1058 m = new(certificateRequestMsgTLS13)
1059 } else {
1060 m = &certificateRequestMsg{
1061 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1062 }
1063 }
1064 case typeCertificateStatus:
1065 m = new(certificateStatusMsg)
1066 case typeServerKeyExchange:
1067 m = new(serverKeyExchangeMsg)
1068 case typeServerHelloDone:
1069 m = new(serverHelloDoneMsg)
1070 case typeClientKeyExchange:
1071 m = new(clientKeyExchangeMsg)
1072 case typeCertificateVerify:
1073 m = &certificateVerifyMsg{
1074 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1075 }
1076 case typeFinished:
1077 m = new(finishedMsg)
1078 case typeEncryptedExtensions:
1079 m = new(encryptedExtensionsMsg)
1080 case typeEndOfEarlyData:
1081 m = new(endOfEarlyDataMsg)
1082 case typeKeyUpdate:
1083 m = new(keyUpdateMsg)
1084 default:
1085 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1086 }
1087
1088
1089
1090
1091 data = append([]byte(nil), data...)
1092
1093 if !m.unmarshal(data) {
1094 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1095 }
1096 return m, nil
1097 }
1098
var (
	// errShutdown is returned by Write once a close_notify alert has been
	// sent on the connection (c.closeNotifySent is set by closeNotify).
	errShutdown = errors.New("tls: protocol is shutdown")
)
1102
1103
1104
1105
1106
1107
1108
1109 func (c *Conn) Write(b []byte) (int, error) {
1110
1111 for {
1112 x := atomic.LoadInt32(&c.activeCall)
1113 if x&1 != 0 {
1114 return 0, net.ErrClosed
1115 }
1116 if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
1117 break
1118 }
1119 }
1120 defer atomic.AddInt32(&c.activeCall, -2)
1121
1122 if err := c.Handshake(); err != nil {
1123 return 0, err
1124 }
1125
1126 c.out.Lock()
1127 defer c.out.Unlock()
1128
1129 if err := c.out.err; err != nil {
1130 return 0, err
1131 }
1132
1133 if !c.handshakeComplete() {
1134 return 0, alertInternalError
1135 }
1136
1137 if c.closeNotifySent {
1138 return 0, errShutdown
1139 }
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150 var m int
1151 if len(b) > 1 && c.vers == VersionTLS10 {
1152 if _, ok := c.out.cipher.(cipher.BlockMode); ok {
1153 n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
1154 if err != nil {
1155 return n, c.out.setErrorLocked(err)
1156 }
1157 m, b = 1, b[1:]
1158 }
1159 }
1160
1161 n, err := c.writeRecordLocked(recordTypeApplicationData, b)
1162 return n + m, c.out.setErrorLocked(err)
1163 }
1164
1165
1166 func (c *Conn) handleRenegotiation() error {
1167 if c.vers == VersionTLS13 {
1168 return errors.New("tls: internal error: unexpected renegotiation")
1169 }
1170
1171 msg, err := c.readHandshake()
1172 if err != nil {
1173 return err
1174 }
1175
1176 helloReq, ok := msg.(*helloRequestMsg)
1177 if !ok {
1178 c.sendAlert(alertUnexpectedMessage)
1179 return unexpectedMessageError(helloReq, msg)
1180 }
1181
1182 if !c.isClient {
1183 return c.sendAlert(alertNoRenegotiation)
1184 }
1185
1186 switch c.config.Renegotiation {
1187 case RenegotiateNever:
1188 return c.sendAlert(alertNoRenegotiation)
1189 case RenegotiateOnceAsClient:
1190 if c.handshakes > 1 {
1191 return c.sendAlert(alertNoRenegotiation)
1192 }
1193 case RenegotiateFreelyAsClient:
1194
1195 default:
1196 c.sendAlert(alertInternalError)
1197 return errors.New("tls: unknown Renegotiation value")
1198 }
1199
1200 c.handshakeMutex.Lock()
1201 defer c.handshakeMutex.Unlock()
1202
1203 atomic.StoreUint32(&c.handshakeStatus, 0)
1204 if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
1205 c.handshakes++
1206 }
1207 return c.handshakeErr
1208 }
1209
1210
1211
1212 func (c *Conn) handlePostHandshakeMessage() error {
1213 if c.vers != VersionTLS13 {
1214 return c.handleRenegotiation()
1215 }
1216
1217 msg, err := c.readHandshake()
1218 if err != nil {
1219 return err
1220 }
1221
1222 c.retryCount++
1223 if c.retryCount > maxUselessRecords {
1224 c.sendAlert(alertUnexpectedMessage)
1225 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1226 }
1227
1228 switch msg := msg.(type) {
1229 case *newSessionTicketMsgTLS13:
1230 return c.handleNewSessionTicket(msg)
1231 case *keyUpdateMsg:
1232 return c.handleKeyUpdate(msg)
1233 default:
1234 c.sendAlert(alertUnexpectedMessage)
1235 return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1236 }
1237 }
1238
1239 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1240 cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1241 if cipherSuite == nil {
1242 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1243 }
1244
1245 newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1246 c.in.setTrafficSecret(cipherSuite, newSecret)
1247
1248 if keyUpdate.updateRequested {
1249 c.out.Lock()
1250 defer c.out.Unlock()
1251
1252 msg := &keyUpdateMsg{}
1253 _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal())
1254 if err != nil {
1255
1256 c.out.setErrorLocked(err)
1257 return nil
1258 }
1259
1260 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1261 c.out.setTrafficSecret(cipherSuite, newSecret)
1262 }
1263
1264 return nil
1265 }
1266
1267
1268
1269
1270
1271
1272
1273 func (c *Conn) Read(b []byte) (int, error) {
1274 if err := c.Handshake(); err != nil {
1275 return 0, err
1276 }
1277 if len(b) == 0 {
1278
1279
1280 return 0, nil
1281 }
1282
1283 c.in.Lock()
1284 defer c.in.Unlock()
1285
1286 for c.input.Len() == 0 {
1287 if err := c.readRecord(); err != nil {
1288 return 0, err
1289 }
1290 for c.hand.Len() > 0 {
1291 if err := c.handlePostHandshakeMessage(); err != nil {
1292 return 0, err
1293 }
1294 }
1295 }
1296
1297 n, _ := c.input.Read(b)
1298
1299
1300
1301
1302
1303
1304
1305
1306 if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
1307 recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
1308 if err := c.readRecord(); err != nil {
1309 return n, err
1310 }
1311 }
1312
1313 return n, nil
1314 }
1315
1316
1317 func (c *Conn) Close() error {
1318
1319 var x int32
1320 for {
1321 x = atomic.LoadInt32(&c.activeCall)
1322 if x&1 != 0 {
1323 return net.ErrClosed
1324 }
1325 if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
1326 break
1327 }
1328 }
1329 if x != 0 {
1330
1331
1332
1333
1334
1335
1336 return c.conn.Close()
1337 }
1338
1339 var alertErr error
1340 if c.handshakeComplete() {
1341 if err := c.closeNotify(); err != nil {
1342 alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
1343 }
1344 }
1345
1346 if err := c.conn.Close(); err != nil {
1347 return err
1348 }
1349 return alertErr
1350 }
1351
1352 var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1353
1354
1355
1356
1357 func (c *Conn) CloseWrite() error {
1358 if !c.handshakeComplete() {
1359 return errEarlyCloseWrite
1360 }
1361
1362 return c.closeNotify()
1363 }
1364
1365 func (c *Conn) closeNotify() error {
1366 c.out.Lock()
1367 defer c.out.Unlock()
1368
1369 if !c.closeNotifySent {
1370
1371 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1372 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1373 c.closeNotifySent = true
1374
1375 c.SetWriteDeadline(time.Now())
1376 }
1377 return c.closeNotifyErr
1378 }
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388 func (c *Conn) Handshake() error {
1389 return c.HandshakeContext(context.Background())
1390 }
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402 func (c *Conn) HandshakeContext(ctx context.Context) error {
1403
1404
1405 return c.handshakeContext(ctx)
1406 }
1407
// handshakeContext implements HandshakeContext. It runs c.handshakeFn at most
// once successfully, under handshakeMutex and c.in's lock, and records the
// result in c.handshakeErr (sticky on failure). The named result ret lets
// the deferred interrupter cleanup replace the returned error with ctx's
// error when ctx was done first.
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
	// Fast, lock-free exit if a handshake has already succeeded. This avoids
	// the context setup and mutex on most Read and Write calls.
	if c.handshakeComplete() {
		return nil
	}

	handshakeCtx, cancel := context.WithCancel(ctx)
	// This defer is registered before the interrupter cleanup below, so
	// (defers run in reverse order) cancel only runs after the interrupter
	// goroutine has reported its result. The interrupter therefore only ever
	// observes cancellation that originated from ctx.
	defer cancel()

	// Start the "interrupter" goroutine if ctx can be canceled at all (a nil
	// Done channel means it never can, e.g. context.Background()). If ctx is
	// done before this function returns, the goroutine closes the underlying
	// connection, unblocking handshake I/O. The deferred function then
	// returns the context's error in place of the handshake's result.
	if ctx.Done() != nil {
		done := make(chan struct{})
		interruptRes := make(chan error, 1)
		defer func() {
			close(done)
			if ctxErr := <-interruptRes; ctxErr != nil {
				// Return context error to user.
				ret = ctxErr
			}
		}()
		go func() {
			select {
			case <-handshakeCtx.Done():
				// Close the connection, discarding the error.
				_ = c.conn.Close()
				interruptRes <- handshakeCtx.Err()
			case <-done:
				interruptRes <- nil
			}
		}()
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// A previous handshake failure is sticky.
	if err := c.handshakeErr; err != nil {
		return err
	}
	// Another goroutine may have completed the handshake while we waited for
	// handshakeMutex.
	if c.handshakeComplete() {
		return nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	c.handshakeErr = c.handshakeFn(handshakeCtx)
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// If an error occurred during the handshake, try to flush any output
		// still buffered (presumably an alert) so the peer sees it.
		c.flush()
	}

	// Sanity checks: handshakeFn's result must agree with handshakeStatus.
	if c.handshakeErr == nil && !c.handshakeComplete() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}
	if c.handshakeErr != nil && c.handshakeComplete() {
		panic("tls: internal error: handshake returned an error but is marked successful")
	}

	return c.handshakeErr
}
1480
1481
1482 func (c *Conn) ConnectionState() ConnectionState {
1483 c.handshakeMutex.Lock()
1484 defer c.handshakeMutex.Unlock()
1485 return c.connectionStateLocked()
1486 }
1487
1488 func (c *Conn) connectionStateLocked() ConnectionState {
1489 var state ConnectionState
1490 state.HandshakeComplete = c.handshakeComplete()
1491 state.Version = c.vers
1492 state.NegotiatedProtocol = c.clientProtocol
1493 state.DidResume = c.didResume
1494 state.NegotiatedProtocolIsMutual = true
1495 state.ServerName = c.serverName
1496 state.CipherSuite = c.cipherSuite
1497 state.PeerCertificates = c.peerCertificates
1498 state.VerifiedChains = c.verifiedChains
1499 state.SignedCertificateTimestamps = c.scts
1500 state.OCSPResponse = c.ocspResponse
1501 if !c.didResume && c.vers != VersionTLS13 {
1502 if c.clientFinishedIsFirst {
1503 state.TLSUnique = c.clientFinished[:]
1504 } else {
1505 state.TLSUnique = c.serverFinished[:]
1506 }
1507 }
1508 if c.config.Renegotiation != RenegotiateNever {
1509 state.ekm = noExportedKeyingMaterial
1510 } else {
1511 state.ekm = c.ekm
1512 }
1513 return state
1514 }
1515
1516
1517
1518 func (c *Conn) OCSPResponse() []byte {
1519 c.handshakeMutex.Lock()
1520 defer c.handshakeMutex.Unlock()
1521
1522 return c.ocspResponse
1523 }
1524
1525
1526
1527
1528 func (c *Conn) VerifyHostname(host string) error {
1529 c.handshakeMutex.Lock()
1530 defer c.handshakeMutex.Unlock()
1531 if !c.isClient {
1532 return errors.New("tls: VerifyHostname called on TLS server connection")
1533 }
1534 if !c.handshakeComplete() {
1535 return errors.New("tls: handshake has not yet been performed")
1536 }
1537 if len(c.verifiedChains) == 0 {
1538 return errors.New("tls: handshake did not verify certificate chain")
1539 }
1540 return c.peerCertificates[0].VerifyHostname(host)
1541 }
1542
1543 func (c *Conn) handshakeComplete() bool {
1544 return atomic.LoadUint32(&c.handshakeStatus) == 1
1545 }
1546
View as plain text