...

Source file src/crypto/tls/handshake_messages.go

Documentation: crypto/tls

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"fmt"
     9  	"strings"
    10  
    11  	"golang.org/x/crypto/cryptobyte"
    12  )
    13  
    14  // The marshalingFunction type is an adapter to allow the use of ordinary
    15  // functions as cryptobyte.MarshalingValue.
    16  type marshalingFunction func(b *cryptobyte.Builder) error
    17  
    18  func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
    19  	return f(b)
    20  }
    21  
    22  // addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
    23  // the length of the sequence is not the value specified, it produces an error.
    24  func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
    25  	b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
    26  		if len(v) != n {
    27  			return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
    28  		}
    29  		b.AddBytes(v)
    30  		return nil
    31  	}))
    32  }
    33  
    34  // addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
    35  func addUint64(b *cryptobyte.Builder, v uint64) {
    36  	b.AddUint32(uint32(v >> 32))
    37  	b.AddUint32(uint32(v))
    38  }
    39  
    40  // readUint64 decodes a big-endian, 64-bit value into out and advances over it.
    41  // It reports whether the read was successful.
    42  func readUint64(s *cryptobyte.String, out *uint64) bool {
    43  	var hi, lo uint32
    44  	if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
    45  		return false
    46  	}
    47  	*out = uint64(hi)<<32 | uint64(lo)
    48  	return true
    49  }
    50  
    51  // readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
    52  // []byte instead of a cryptobyte.String.
    53  func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    54  	return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
    55  }
    56  
    57  // readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
    58  // []byte instead of a cryptobyte.String.
    59  func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    60  	return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
    61  }
    62  
    63  // readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
    64  // []byte instead of a cryptobyte.String.
    65  func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    66  	return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
    67  }
    68  
// clientHelloMsg is the in-memory form of a ClientHello handshake message.
// marshal encodes it and unmarshal decodes it; both keep the wire encoding
// in raw.
type clientHelloMsg struct {
	// raw caches the wire encoding; unmarshal stores its input here and
	// marshal returns it unchanged when non-nil.
	raw                              []byte
	// Fixed-position fields written before the extensions block. marshal
	// requires random to be exactly 32 bytes.
	vers                             uint16
	random                           []byte
	sessionId                        []byte
	cipherSuites                     []uint16
	compressionMethods               []uint8
	// Extension fields. marshal writes an extension only when its field is
	// set (true or non-empty). The exceptions are sessionTicket,
	// secureRenegotiation and pskBinders, which are written only alongside
	// ticketSupported, secureRenegotiationSupported and pskIdentities
	// respectively.
	serverName                       string
	ocspStapling                     bool
	supportedCurves                  []CurveID
	supportedPoints                  []uint8
	ticketSupported                  bool
	sessionTicket                    []uint8
	supportedSignatureAlgorithms     []SignatureScheme
	supportedSignatureAlgorithmsCert []SignatureScheme
	// secureRenegotiationSupported is also set by unmarshal when the
	// scsvRenegotiation value appears among the cipher suites.
	secureRenegotiationSupported     bool
	secureRenegotiation              []byte
	alpnProtocols                    []string
	scts                             bool
	supportedVersions                []uint16
	cookie                           []byte
	keyShares                        []keyShare
	earlyData                        bool
	pskModes                         []uint8
	// pskIdentities and pskBinders make up the pre_shared_key extension,
	// which marshal writes last. The encoded binders therefore sit at the
	// very end of the encoding; marshalWithoutBinders and updateBinders rely
	// on that.
	pskIdentities                    []pskIdentity
	pskBinders                       [][]byte
}
    96  
    97  func (m *clientHelloMsg) marshal() []byte {
    98  	if m.raw != nil {
    99  		return m.raw
   100  	}
   101  
   102  	var b cryptobyte.Builder
   103  	b.AddUint8(typeClientHello)
   104  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   105  		b.AddUint16(m.vers)
   106  		addBytesWithLength(b, m.random, 32)
   107  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   108  			b.AddBytes(m.sessionId)
   109  		})
   110  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   111  			for _, suite := range m.cipherSuites {
   112  				b.AddUint16(suite)
   113  			}
   114  		})
   115  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   116  			b.AddBytes(m.compressionMethods)
   117  		})
   118  
   119  		// If extensions aren't present, omit them.
   120  		var extensionsPresent bool
   121  		bWithoutExtensions := *b
   122  
   123  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   124  			if len(m.serverName) > 0 {
   125  				// RFC 6066, Section 3
   126  				b.AddUint16(extensionServerName)
   127  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   128  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   129  						b.AddUint8(0) // name_type = host_name
   130  						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   131  							b.AddBytes([]byte(m.serverName))
   132  						})
   133  					})
   134  				})
   135  			}
   136  			if m.ocspStapling {
   137  				// RFC 4366, Section 3.6
   138  				b.AddUint16(extensionStatusRequest)
   139  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   140  					b.AddUint8(1)  // status_type = ocsp
   141  					b.AddUint16(0) // empty responder_id_list
   142  					b.AddUint16(0) // empty request_extensions
   143  				})
   144  			}
   145  			if len(m.supportedCurves) > 0 {
   146  				// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
   147  				b.AddUint16(extensionSupportedCurves)
   148  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   149  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   150  						for _, curve := range m.supportedCurves {
   151  							b.AddUint16(uint16(curve))
   152  						}
   153  					})
   154  				})
   155  			}
   156  			if len(m.supportedPoints) > 0 {
   157  				// RFC 4492, Section 5.1.2
   158  				b.AddUint16(extensionSupportedPoints)
   159  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   160  					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   161  						b.AddBytes(m.supportedPoints)
   162  					})
   163  				})
   164  			}
   165  			if m.ticketSupported {
   166  				// RFC 5077, Section 3.2
   167  				b.AddUint16(extensionSessionTicket)
   168  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   169  					b.AddBytes(m.sessionTicket)
   170  				})
   171  			}
   172  			if len(m.supportedSignatureAlgorithms) > 0 {
   173  				// RFC 5246, Section 7.4.1.4.1
   174  				b.AddUint16(extensionSignatureAlgorithms)
   175  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   176  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   177  						for _, sigAlgo := range m.supportedSignatureAlgorithms {
   178  							b.AddUint16(uint16(sigAlgo))
   179  						}
   180  					})
   181  				})
   182  			}
   183  			if len(m.supportedSignatureAlgorithmsCert) > 0 {
   184  				// RFC 8446, Section 4.2.3
   185  				b.AddUint16(extensionSignatureAlgorithmsCert)
   186  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   187  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   188  						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
   189  							b.AddUint16(uint16(sigAlgo))
   190  						}
   191  					})
   192  				})
   193  			}
   194  			if m.secureRenegotiationSupported {
   195  				// RFC 5746, Section 3.2
   196  				b.AddUint16(extensionRenegotiationInfo)
   197  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   198  					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   199  						b.AddBytes(m.secureRenegotiation)
   200  					})
   201  				})
   202  			}
   203  			if len(m.alpnProtocols) > 0 {
   204  				// RFC 7301, Section 3.1
   205  				b.AddUint16(extensionALPN)
   206  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   207  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   208  						for _, proto := range m.alpnProtocols {
   209  							b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   210  								b.AddBytes([]byte(proto))
   211  							})
   212  						}
   213  					})
   214  				})
   215  			}
   216  			if m.scts {
   217  				// RFC 6962, Section 3.3.1
   218  				b.AddUint16(extensionSCT)
   219  				b.AddUint16(0) // empty extension_data
   220  			}
   221  			if len(m.supportedVersions) > 0 {
   222  				// RFC 8446, Section 4.2.1
   223  				b.AddUint16(extensionSupportedVersions)
   224  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   225  					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   226  						for _, vers := range m.supportedVersions {
   227  							b.AddUint16(vers)
   228  						}
   229  					})
   230  				})
   231  			}
   232  			if len(m.cookie) > 0 {
   233  				// RFC 8446, Section 4.2.2
   234  				b.AddUint16(extensionCookie)
   235  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   236  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   237  						b.AddBytes(m.cookie)
   238  					})
   239  				})
   240  			}
   241  			if len(m.keyShares) > 0 {
   242  				// RFC 8446, Section 4.2.8
   243  				b.AddUint16(extensionKeyShare)
   244  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   245  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   246  						for _, ks := range m.keyShares {
   247  							b.AddUint16(uint16(ks.group))
   248  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   249  								b.AddBytes(ks.data)
   250  							})
   251  						}
   252  					})
   253  				})
   254  			}
   255  			if m.earlyData {
   256  				// RFC 8446, Section 4.2.10
   257  				b.AddUint16(extensionEarlyData)
   258  				b.AddUint16(0) // empty extension_data
   259  			}
   260  			if len(m.pskModes) > 0 {
   261  				// RFC 8446, Section 4.2.9
   262  				b.AddUint16(extensionPSKModes)
   263  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   264  					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   265  						b.AddBytes(m.pskModes)
   266  					})
   267  				})
   268  			}
   269  			if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
   270  				// RFC 8446, Section 4.2.11
   271  				b.AddUint16(extensionPreSharedKey)
   272  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   273  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   274  						for _, psk := range m.pskIdentities {
   275  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   276  								b.AddBytes(psk.label)
   277  							})
   278  							b.AddUint32(psk.obfuscatedTicketAge)
   279  						}
   280  					})
   281  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   282  						for _, binder := range m.pskBinders {
   283  							b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   284  								b.AddBytes(binder)
   285  							})
   286  						}
   287  					})
   288  				})
   289  			}
   290  
   291  			extensionsPresent = len(b.BytesOrPanic()) > 2
   292  		})
   293  
   294  		if !extensionsPresent {
   295  			*b = bWithoutExtensions
   296  		}
   297  	})
   298  
   299  	m.raw = b.BytesOrPanic()
   300  	return m.raw
   301  }
   302  
   303  // marshalWithoutBinders returns the ClientHello through the
   304  // PreSharedKeyExtension.identities field, according to RFC 8446, Section
   305  // 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
   306  func (m *clientHelloMsg) marshalWithoutBinders() []byte {
   307  	bindersLen := 2 // uint16 length prefix
   308  	for _, binder := range m.pskBinders {
   309  		bindersLen += 1 // uint8 length prefix
   310  		bindersLen += len(binder)
   311  	}
   312  
   313  	fullMessage := m.marshal()
   314  	return fullMessage[:len(fullMessage)-bindersLen]
   315  }
   316  
   317  // updateBinders updates the m.pskBinders field, if necessary updating the
   318  // cached marshaled representation. The supplied binders must have the same
   319  // length as the current m.pskBinders.
   320  func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) {
   321  	if len(pskBinders) != len(m.pskBinders) {
   322  		panic("tls: internal error: pskBinders length mismatch")
   323  	}
   324  	for i := range m.pskBinders {
   325  		if len(pskBinders[i]) != len(m.pskBinders[i]) {
   326  			panic("tls: internal error: pskBinders length mismatch")
   327  		}
   328  	}
   329  	m.pskBinders = pskBinders
   330  	if m.raw != nil {
   331  		lenWithoutBinders := len(m.marshalWithoutBinders())
   332  		b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders])
   333  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   334  			for _, binder := range m.pskBinders {
   335  				b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   336  					b.AddBytes(binder)
   337  				})
   338  			}
   339  		})
   340  		if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) {
   341  			panic("tls: internal error: failed to update binders")
   342  		}
   343  	}
   344  }
   345  
// unmarshal parses a ClientHello handshake message from data and reports
// whether it is well formed. data must start with the 4-byte handshake
// header (message type and uint24 length), which is skipped without being
// checked here. All previous contents of m are discarded, and m.raw is set
// to data itself (not a copy).
//
// Among other checks, it rejects:
//   - a duplicated extension type;
//   - bytes trailing the extensions block;
//   - unconsumed bytes inside a recognized extension;
//   - an SNI host_name that is repeated or ends in a dot;
//   - a pre_shared_key extension that is not the last extension.
//
// Unrecognized extensions are skipped.
//
// NOTE(review): the []byte fields read from data are presumably sub-slices
// of data (cryptobyte reads do not copy), so data should not be modified
// afterwards. Confirm against cryptobyte.String.ReadBytes.
func (m *clientHelloMsg) unmarshal(data []byte) bool {
	*m = clientHelloMsg{raw: data}
	s := cryptobyte.String(data)

	if !s.Skip(4) || // message type and uint24 length field
		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
		!readUint8LengthPrefixed(&s, &m.sessionId) {
		return false
	}

	var cipherSuites cryptobyte.String
	if !s.ReadUint16LengthPrefixed(&cipherSuites) {
		return false
	}
	// m.cipherSuites is left non-nil even when the list is empty. The
	// renegotiation SCSV is recorded as secure renegotiation support and is
	// also kept in the suite list. (secureRenegotiationSupported is already
	// false after the reset above.)
	m.cipherSuites = []uint16{}
	m.secureRenegotiationSupported = false
	for !cipherSuites.Empty() {
		var suite uint16
		if !cipherSuites.ReadUint16(&suite) {
			return false
		}
		if suite == scsvRenegotiation {
			m.secureRenegotiationSupported = true
		}
		m.cipherSuites = append(m.cipherSuites, suite)
	}

	if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
		return false
	}

	if s.Empty() {
		// ClientHello is optionally followed by extension data
		return true
	}

	// The extensions block must be the last thing in the message.
	var extensions cryptobyte.String
	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
		return false
	}

	// Each extension type may appear at most once, including unknown ones.
	seenExts := make(map[uint16]bool)
	for !extensions.Empty() {
		var extension uint16
		var extData cryptobyte.String
		if !extensions.ReadUint16(&extension) ||
			!extensions.ReadUint16LengthPrefixed(&extData) {
			return false
		}

		if seenExts[extension] {
			return false
		}
		seenExts[extension] = true

		switch extension {
		case extensionServerName:
			// RFC 6066, Section 3
			var nameList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
				return false
			}
			for !nameList.Empty() {
				var nameType uint8
				var serverName cryptobyte.String
				if !nameList.ReadUint8(&nameType) ||
					!nameList.ReadUint16LengthPrefixed(&serverName) ||
					serverName.Empty() {
					return false
				}
				// Only name_type 0 (host_name) is recorded; others are skipped.
				if nameType != 0 {
					continue
				}
				if len(m.serverName) != 0 {
					// Multiple names of the same name_type are prohibited.
					return false
				}
				m.serverName = string(serverName)
				// An SNI value may not include a trailing dot.
				if strings.HasSuffix(m.serverName, ".") {
					return false
				}
			}
		case extensionStatusRequest:
			// RFC 4366, Section 3.6
			// Both uint16-prefixed lists are read and discarded; only whether
			// status_type is OCSP is kept.
			var statusType uint8
			var ignored cryptobyte.String
			if !extData.ReadUint8(&statusType) ||
				!extData.ReadUint16LengthPrefixed(&ignored) ||
				!extData.ReadUint16LengthPrefixed(&ignored) {
				return false
			}
			m.ocspStapling = statusType == statusTypeOCSP
		case extensionSupportedCurves:
			// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
			var curves cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() {
				return false
			}
			for !curves.Empty() {
				var curve uint16
				if !curves.ReadUint16(&curve) {
					return false
				}
				m.supportedCurves = append(m.supportedCurves, CurveID(curve))
			}
		case extensionSupportedPoints:
			// RFC 4492, Section 5.1.2
			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
				len(m.supportedPoints) == 0 {
				return false
			}
		case extensionSessionTicket:
			// RFC 5077, Section 3.2
			// The whole extension body is the ticket; an empty body is allowed.
			m.ticketSupported = true
			extData.ReadBytes(&m.sessionTicket, len(extData))
		case extensionSignatureAlgorithms:
			// RFC 5246, Section 7.4.1.4.1
			var sigAndAlgs cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
				return false
			}
			for !sigAndAlgs.Empty() {
				var sigAndAlg uint16
				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
					return false
				}
				m.supportedSignatureAlgorithms = append(
					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
			}
		case extensionSignatureAlgorithmsCert:
			// RFC 8446, Section 4.2.3
			var sigAndAlgs cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
				return false
			}
			for !sigAndAlgs.Empty() {
				var sigAndAlg uint16
				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
					return false
				}
				m.supportedSignatureAlgorithmsCert = append(
					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
			}
		case extensionRenegotiationInfo:
			// RFC 5746, Section 3.2
			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
				return false
			}
			m.secureRenegotiationSupported = true
		case extensionALPN:
			// RFC 7301, Section 3.1
			var protoList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
				return false
			}
			for !protoList.Empty() {
				var proto cryptobyte.String
				if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
					return false
				}
				m.alpnProtocols = append(m.alpnProtocols, string(proto))
			}
		case extensionSCT:
			// RFC 6962, Section 3.3.1
			// The body must be empty; that is enforced by the extData check
			// after the switch.
			m.scts = true
		case extensionSupportedVersions:
			// RFC 8446, Section 4.2.1
			var versList cryptobyte.String
			if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
				return false
			}
			for !versList.Empty() {
				var vers uint16
				if !versList.ReadUint16(&vers) {
					return false
				}
				m.supportedVersions = append(m.supportedVersions, vers)
			}
		case extensionCookie:
			// RFC 8446, Section 4.2.2
			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
				len(m.cookie) == 0 {
				return false
			}
		case extensionKeyShare:
			// RFC 8446, Section 4.2.8
			// The client_shares list itself may be empty, but every entry must
			// carry non-empty key exchange data.
			var clientShares cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&clientShares) {
				return false
			}
			for !clientShares.Empty() {
				var ks keyShare
				if !clientShares.ReadUint16((*uint16)(&ks.group)) ||
					!readUint16LengthPrefixed(&clientShares, &ks.data) ||
					len(ks.data) == 0 {
					return false
				}
				m.keyShares = append(m.keyShares, ks)
			}
		case extensionEarlyData:
			// RFC 8446, Section 4.2.10
			// The body must be empty; that is enforced after the switch.
			m.earlyData = true
		case extensionPSKModes:
			// RFC 8446, Section 4.2.9
			if !readUint8LengthPrefixed(&extData, &m.pskModes) {
				return false
			}
		case extensionPreSharedKey:
			// RFC 8446, Section 4.2.11
			if !extensions.Empty() {
				return false // pre_shared_key must be the last extension
			}
			var identities cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() {
				return false
			}
			for !identities.Empty() {
				var psk pskIdentity
				if !readUint16LengthPrefixed(&identities, &psk.label) ||
					!identities.ReadUint32(&psk.obfuscatedTicketAge) ||
					len(psk.label) == 0 {
					return false
				}
				m.pskIdentities = append(m.pskIdentities, psk)
			}
			var binders cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() {
				return false
			}
			for !binders.Empty() {
				var binder []byte
				if !readUint8LengthPrefixed(&binders, &binder) ||
					len(binder) == 0 {
					return false
				}
				m.pskBinders = append(m.pskBinders, binder)
			}
		default:
			// Ignore unknown extensions. The continue also skips the
			// trailing-bytes check below.
			continue
		}

		// A recognized extension must be consumed completely.
		if !extData.Empty() {
			return false
		}
	}

	return true
}
   596  
// serverHelloMsg is the in-memory form of a ServerHello handshake message.
// The same type also carries the HelloRetryRequest-only extensions
// (cookie and selectedGroup).
type serverHelloMsg struct {
	// raw caches the wire encoding; unmarshal stores its input here and
	// marshal returns it unchanged when non-nil.
	raw                          []byte
	// Fixed-position fields written before the extensions block. marshal
	// requires random to be exactly 32 bytes.
	vers                         uint16
	random                       []byte
	sessionId                    []byte
	cipherSuite                  uint16
	compressionMethod            uint8
	// Extension fields. marshal writes an extension only when its field is
	// set (true, non-empty or non-zero). Further rules:
	//   - ocspStapling and ticketSupported are sent as empty extensions;
	//   - secureRenegotiation is written only when
	//     secureRenegotiationSupported is true;
	//   - selectedIdentity is written only when selectedIdentityPresent is
	//     true.
	ocspStapling                 bool
	ticketSupported              bool
	secureRenegotiationSupported bool
	secureRenegotiation          []byte
	alpnProtocol                 string
	scts                         [][]byte
	supportedVersion             uint16
	// serverShare is the ServerHello form of key_share (group plus key
	// exchange data); it is written when serverShare.group is non-zero.
	serverShare                  keyShare
	selectedIdentityPresent      bool
	selectedIdentity             uint16
	supportedPoints              []uint8

	// HelloRetryRequest extensions. selectedGroup is the HelloRetryRequest
	// form of key_share (a bare group); unmarshal fills it when the key_share
	// body is exactly 2 bytes long.
	cookie        []byte
	selectedGroup CurveID
}
   620  
   621  func (m *serverHelloMsg) marshal() []byte {
   622  	if m.raw != nil {
   623  		return m.raw
   624  	}
   625  
   626  	var b cryptobyte.Builder
   627  	b.AddUint8(typeServerHello)
   628  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   629  		b.AddUint16(m.vers)
   630  		addBytesWithLength(b, m.random, 32)
   631  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   632  			b.AddBytes(m.sessionId)
   633  		})
   634  		b.AddUint16(m.cipherSuite)
   635  		b.AddUint8(m.compressionMethod)
   636  
   637  		// If extensions aren't present, omit them.
   638  		var extensionsPresent bool
   639  		bWithoutExtensions := *b
   640  
   641  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   642  			if m.ocspStapling {
   643  				b.AddUint16(extensionStatusRequest)
   644  				b.AddUint16(0) // empty extension_data
   645  			}
   646  			if m.ticketSupported {
   647  				b.AddUint16(extensionSessionTicket)
   648  				b.AddUint16(0) // empty extension_data
   649  			}
   650  			if m.secureRenegotiationSupported {
   651  				b.AddUint16(extensionRenegotiationInfo)
   652  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   653  					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   654  						b.AddBytes(m.secureRenegotiation)
   655  					})
   656  				})
   657  			}
   658  			if len(m.alpnProtocol) > 0 {
   659  				b.AddUint16(extensionALPN)
   660  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   661  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   662  						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   663  							b.AddBytes([]byte(m.alpnProtocol))
   664  						})
   665  					})
   666  				})
   667  			}
   668  			if len(m.scts) > 0 {
   669  				b.AddUint16(extensionSCT)
   670  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   671  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   672  						for _, sct := range m.scts {
   673  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   674  								b.AddBytes(sct)
   675  							})
   676  						}
   677  					})
   678  				})
   679  			}
   680  			if m.supportedVersion != 0 {
   681  				b.AddUint16(extensionSupportedVersions)
   682  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   683  					b.AddUint16(m.supportedVersion)
   684  				})
   685  			}
   686  			if m.serverShare.group != 0 {
   687  				b.AddUint16(extensionKeyShare)
   688  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   689  					b.AddUint16(uint16(m.serverShare.group))
   690  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   691  						b.AddBytes(m.serverShare.data)
   692  					})
   693  				})
   694  			}
   695  			if m.selectedIdentityPresent {
   696  				b.AddUint16(extensionPreSharedKey)
   697  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   698  					b.AddUint16(m.selectedIdentity)
   699  				})
   700  			}
   701  
   702  			if len(m.cookie) > 0 {
   703  				b.AddUint16(extensionCookie)
   704  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   705  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   706  						b.AddBytes(m.cookie)
   707  					})
   708  				})
   709  			}
   710  			if m.selectedGroup != 0 {
   711  				b.AddUint16(extensionKeyShare)
   712  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   713  					b.AddUint16(uint16(m.selectedGroup))
   714  				})
   715  			}
   716  			if len(m.supportedPoints) > 0 {
   717  				b.AddUint16(extensionSupportedPoints)
   718  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   719  					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   720  						b.AddBytes(m.supportedPoints)
   721  					})
   722  				})
   723  			}
   724  
   725  			extensionsPresent = len(b.BytesOrPanic()) > 2
   726  		})
   727  
   728  		if !extensionsPresent {
   729  			*b = bWithoutExtensions
   730  		}
   731  	})
   732  
   733  	m.raw = b.BytesOrPanic()
   734  	return m.raw
   735  }
   736  
// unmarshal parses a ServerHello (or HelloRetryRequest) handshake message
// from data and reports whether it is well formed. data must start with the
// 4-byte handshake header (message type and uint24 length), which is
// skipped without being checked here. All previous contents of m are
// discarded, and m.raw is set to data itself (not a copy).
//
// Among other checks, it rejects duplicated extension types, bytes
// trailing the extensions block, and unconsumed bytes inside a recognized
// extension. Unrecognized extensions are skipped.
func (m *serverHelloMsg) unmarshal(data []byte) bool {
	*m = serverHelloMsg{raw: data}
	s := cryptobyte.String(data)

	if !s.Skip(4) || // message type and uint24 length field
		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
		!readUint8LengthPrefixed(&s, &m.sessionId) ||
		!s.ReadUint16(&m.cipherSuite) ||
		!s.ReadUint8(&m.compressionMethod) {
		return false
	}

	if s.Empty() {
		// ServerHello is optionally followed by extension data
		return true
	}

	// The extensions block must be the last thing in the message.
	var extensions cryptobyte.String
	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
		return false
	}

	// Each extension type may appear at most once, including unknown ones.
	seenExts := make(map[uint16]bool)
	for !extensions.Empty() {
		var extension uint16
		var extData cryptobyte.String
		if !extensions.ReadUint16(&extension) ||
			!extensions.ReadUint16LengthPrefixed(&extData) {
			return false
		}

		if seenExts[extension] {
			return false
		}
		seenExts[extension] = true

		switch extension {
		case extensionStatusRequest:
			// The body must be empty; that is enforced after the switch.
			m.ocspStapling = true
		case extensionSessionTicket:
			// The body must be empty; that is enforced after the switch.
			m.ticketSupported = true
		case extensionRenegotiationInfo:
			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
				return false
			}
			m.secureRenegotiationSupported = true
		case extensionALPN:
			// The list must contain exactly one non-empty protocol name.
			var protoList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
				return false
			}
			var proto cryptobyte.String
			if !protoList.ReadUint8LengthPrefixed(&proto) ||
				proto.Empty() || !protoList.Empty() {
				return false
			}
			m.alpnProtocol = string(proto)
		case extensionSCT:
			// A non-empty list of non-empty, uint16-prefixed SCTs.
			var sctList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
				return false
			}
			for !sctList.Empty() {
				var sct []byte
				if !readUint16LengthPrefixed(&sctList, &sct) ||
					len(sct) == 0 {
					return false
				}
				m.scts = append(m.scts, sct)
			}
		case extensionSupportedVersions:
			// A single selected version, not a list.
			if !extData.ReadUint16(&m.supportedVersion) {
				return false
			}
		case extensionCookie:
			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
				len(m.cookie) == 0 {
				return false
			}
		case extensionKeyShare:
			// This extension has different formats in SH and HRR, accept either
			// and let the handshake logic decide. See RFC 8446, Section 4.2.8.
			if len(extData) == 2 {
				if !extData.ReadUint16((*uint16)(&m.selectedGroup)) {
					return false
				}
			} else {
				if !extData.ReadUint16((*uint16)(&m.serverShare.group)) ||
					!readUint16LengthPrefixed(&extData, &m.serverShare.data) {
					return false
				}
			}
		case extensionPreSharedKey:
			m.selectedIdentityPresent = true
			if !extData.ReadUint16(&m.selectedIdentity) {
				return false
			}
		case extensionSupportedPoints:
			// RFC 4492, Section 5.1.2
			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
				len(m.supportedPoints) == 0 {
				return false
			}
		default:
			// Ignore unknown extensions. The continue also skips the
			// trailing-bytes check below.
			continue
		}

		// A recognized extension must be consumed completely.
		if !extData.Empty() {
			return false
		}
	}

	return true
}
   852  
// encryptedExtensionsMsg is the in-memory form of an EncryptedExtensions
// handshake message. Only the ALPN extension is modeled.
type encryptedExtensionsMsg struct {
	// raw caches the wire encoding; unmarshal stores its input here and
	// marshal returns it unchanged when non-nil.
	raw          []byte
	// alpnProtocol is the single protocol name carried by the ALPN
	// extension; empty means the extension is absent.
	alpnProtocol string
}
   857  
   858  func (m *encryptedExtensionsMsg) marshal() []byte {
   859  	if m.raw != nil {
   860  		return m.raw
   861  	}
   862  
   863  	var b cryptobyte.Builder
   864  	b.AddUint8(typeEncryptedExtensions)
   865  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   866  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   867  			if len(m.alpnProtocol) > 0 {
   868  				b.AddUint16(extensionALPN)
   869  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   870  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   871  						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   872  							b.AddBytes([]byte(m.alpnProtocol))
   873  						})
   874  					})
   875  				})
   876  			}
   877  		})
   878  	})
   879  
   880  	m.raw = b.BytesOrPanic()
   881  	return m.raw
   882  }
   883  
   884  func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
   885  	*m = encryptedExtensionsMsg{raw: data}
   886  	s := cryptobyte.String(data)
   887  
   888  	var extensions cryptobyte.String
   889  	if !s.Skip(4) || // message type and uint24 length field
   890  		!s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   891  		return false
   892  	}
   893  
   894  	for !extensions.Empty() {
   895  		var extension uint16
   896  		var extData cryptobyte.String
   897  		if !extensions.ReadUint16(&extension) ||
   898  			!extensions.ReadUint16LengthPrefixed(&extData) {
   899  			return false
   900  		}
   901  
   902  		switch extension {
   903  		case extensionALPN:
   904  			var protoList cryptobyte.String
   905  			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
   906  				return false
   907  			}
   908  			var proto cryptobyte.String
   909  			if !protoList.ReadUint8LengthPrefixed(&proto) ||
   910  				proto.Empty() || !protoList.Empty() {
   911  				return false
   912  			}
   913  			m.alpnProtocol = string(proto)
   914  		default:
   915  			// Ignore unknown extensions.
   916  			continue
   917  		}
   918  
   919  		if !extData.Empty() {
   920  			return false
   921  		}
   922  	}
   923  
   924  	return true
   925  }
   926  
   927  type endOfEarlyDataMsg struct{}
   928  
   929  func (m *endOfEarlyDataMsg) marshal() []byte {
   930  	x := make([]byte, 4)
   931  	x[0] = typeEndOfEarlyData
   932  	return x
   933  }
   934  
   935  func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
   936  	return len(data) == 4
   937  }
   938  
// keyUpdateMsg is the KeyUpdate handshake message (RFC 8446, Section 4.6.3).
type keyUpdateMsg struct {
	// raw is the wire encoding; marshal returns it when non-nil.
	raw []byte
	// updateRequested is encoded as a single byte: 1 when true, 0 when false.
	// unmarshal rejects any other value.
	updateRequested bool
}
   943  
   944  func (m *keyUpdateMsg) marshal() []byte {
   945  	if m.raw != nil {
   946  		return m.raw
   947  	}
   948  
   949  	var b cryptobyte.Builder
   950  	b.AddUint8(typeKeyUpdate)
   951  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   952  		if m.updateRequested {
   953  			b.AddUint8(1)
   954  		} else {
   955  			b.AddUint8(0)
   956  		}
   957  	})
   958  
   959  	m.raw = b.BytesOrPanic()
   960  	return m.raw
   961  }
   962  
   963  func (m *keyUpdateMsg) unmarshal(data []byte) bool {
   964  	m.raw = data
   965  	s := cryptobyte.String(data)
   966  
   967  	var updateRequested uint8
   968  	if !s.Skip(4) || // message type and uint24 length field
   969  		!s.ReadUint8(&updateRequested) || !s.Empty() {
   970  		return false
   971  	}
   972  	switch updateRequested {
   973  	case 0:
   974  		m.updateRequested = false
   975  	case 1:
   976  		m.updateRequested = true
   977  	default:
   978  		return false
   979  	}
   980  	return true
   981  }
   982  
// newSessionTicketMsgTLS13 is the TLS 1.3 NewSessionTicket message
// (RFC 8446, Section 4.6.1).
type newSessionTicketMsgTLS13 struct {
	// raw is the wire encoding; marshal returns it when non-nil.
	raw []byte
	// lifetime is written as a uint32 (ticket_lifetime).
	lifetime uint32
	// ageAdd is written as a uint32 (ticket_age_add).
	ageAdd uint32
	// nonce is written with a uint8 length prefix.
	nonce []byte
	// label is written with a uint16 length prefix, in the position of the
	// RFC's ticket field.
	label []byte
	// maxEarlyData is carried in the extensionEarlyData extension. marshal
	// omits that extension when the value is zero.
	maxEarlyData uint32
}
   991  
   992  func (m *newSessionTicketMsgTLS13) marshal() []byte {
   993  	if m.raw != nil {
   994  		return m.raw
   995  	}
   996  
   997  	var b cryptobyte.Builder
   998  	b.AddUint8(typeNewSessionTicket)
   999  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1000  		b.AddUint32(m.lifetime)
  1001  		b.AddUint32(m.ageAdd)
  1002  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
  1003  			b.AddBytes(m.nonce)
  1004  		})
  1005  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1006  			b.AddBytes(m.label)
  1007  		})
  1008  
  1009  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1010  			if m.maxEarlyData > 0 {
  1011  				b.AddUint16(extensionEarlyData)
  1012  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1013  					b.AddUint32(m.maxEarlyData)
  1014  				})
  1015  			}
  1016  		})
  1017  	})
  1018  
  1019  	m.raw = b.BytesOrPanic()
  1020  	return m.raw
  1021  }
  1022  
  1023  func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
  1024  	*m = newSessionTicketMsgTLS13{raw: data}
  1025  	s := cryptobyte.String(data)
  1026  
  1027  	var extensions cryptobyte.String
  1028  	if !s.Skip(4) || // message type and uint24 length field
  1029  		!s.ReadUint32(&m.lifetime) ||
  1030  		!s.ReadUint32(&m.ageAdd) ||
  1031  		!readUint8LengthPrefixed(&s, &m.nonce) ||
  1032  		!readUint16LengthPrefixed(&s, &m.label) ||
  1033  		!s.ReadUint16LengthPrefixed(&extensions) ||
  1034  		!s.Empty() {
  1035  		return false
  1036  	}
  1037  
  1038  	for !extensions.Empty() {
  1039  		var extension uint16
  1040  		var extData cryptobyte.String
  1041  		if !extensions.ReadUint16(&extension) ||
  1042  			!extensions.ReadUint16LengthPrefixed(&extData) {
  1043  			return false
  1044  		}
  1045  
  1046  		switch extension {
  1047  		case extensionEarlyData:
  1048  			if !extData.ReadUint32(&m.maxEarlyData) {
  1049  				return false
  1050  			}
  1051  		default:
  1052  			// Ignore unknown extensions.
  1053  			continue
  1054  		}
  1055  
  1056  		if !extData.Empty() {
  1057  			return false
  1058  		}
  1059  	}
  1060  
  1061  	return true
  1062  }
  1063  
// certificateRequestMsgTLS13 is the TLS 1.3 CertificateRequest message
// (RFC 8446, Section 4.3.2). marshal always writes an empty
// certificate_request_context, and unmarshal rejects a non-empty one.
type certificateRequestMsgTLS13 struct {
	// raw is the wire encoding; marshal returns it when non-nil.
	raw []byte
	// ocspStapling corresponds to an empty extensionStatusRequest extension.
	ocspStapling bool
	// scts corresponds to an empty extensionSCT extension.
	scts bool
	// supportedSignatureAlgorithms is the extensionSignatureAlgorithms list.
	// marshal omits the extension when the slice is empty.
	supportedSignatureAlgorithms []SignatureScheme
	// supportedSignatureAlgorithmsCert is the
	// extensionSignatureAlgorithmsCert list. marshal omits the extension
	// when the slice is empty.
	supportedSignatureAlgorithmsCert []SignatureScheme
	// certificateAuthorities holds each extensionCertificateAuthorities entry
	// as raw bytes. marshal omits the extension when the slice is empty.
	certificateAuthorities [][]byte
}
  1072  
  1073  func (m *certificateRequestMsgTLS13) marshal() []byte {
  1074  	if m.raw != nil {
  1075  		return m.raw
  1076  	}
  1077  
  1078  	var b cryptobyte.Builder
  1079  	b.AddUint8(typeCertificateRequest)
  1080  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1081  		// certificate_request_context (SHALL be zero length unless used for
  1082  		// post-handshake authentication)
  1083  		b.AddUint8(0)
  1084  
  1085  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1086  			if m.ocspStapling {
  1087  				b.AddUint16(extensionStatusRequest)
  1088  				b.AddUint16(0) // empty extension_data
  1089  			}
  1090  			if m.scts {
  1091  				// RFC 8446, Section 4.4.2.1 makes no mention of
  1092  				// signed_certificate_timestamp in CertificateRequest, but
  1093  				// "Extensions in the Certificate message from the client MUST
  1094  				// correspond to extensions in the CertificateRequest message
  1095  				// from the server." and it appears in the table in Section 4.2.
  1096  				b.AddUint16(extensionSCT)
  1097  				b.AddUint16(0) // empty extension_data
  1098  			}
  1099  			if len(m.supportedSignatureAlgorithms) > 0 {
  1100  				b.AddUint16(extensionSignatureAlgorithms)
  1101  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1102  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1103  						for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1104  							b.AddUint16(uint16(sigAlgo))
  1105  						}
  1106  					})
  1107  				})
  1108  			}
  1109  			if len(m.supportedSignatureAlgorithmsCert) > 0 {
  1110  				b.AddUint16(extensionSignatureAlgorithmsCert)
  1111  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1112  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1113  						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
  1114  							b.AddUint16(uint16(sigAlgo))
  1115  						}
  1116  					})
  1117  				})
  1118  			}
  1119  			if len(m.certificateAuthorities) > 0 {
  1120  				b.AddUint16(extensionCertificateAuthorities)
  1121  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1122  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1123  						for _, ca := range m.certificateAuthorities {
  1124  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1125  								b.AddBytes(ca)
  1126  							})
  1127  						}
  1128  					})
  1129  				})
  1130  			}
  1131  		})
  1132  	})
  1133  
  1134  	m.raw = b.BytesOrPanic()
  1135  	return m.raw
  1136  }
  1137  
  1138  func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
  1139  	*m = certificateRequestMsgTLS13{raw: data}
  1140  	s := cryptobyte.String(data)
  1141  
  1142  	var context, extensions cryptobyte.String
  1143  	if !s.Skip(4) || // message type and uint24 length field
  1144  		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
  1145  		!s.ReadUint16LengthPrefixed(&extensions) ||
  1146  		!s.Empty() {
  1147  		return false
  1148  	}
  1149  
  1150  	for !extensions.Empty() {
  1151  		var extension uint16
  1152  		var extData cryptobyte.String
  1153  		if !extensions.ReadUint16(&extension) ||
  1154  			!extensions.ReadUint16LengthPrefixed(&extData) {
  1155  			return false
  1156  		}
  1157  
  1158  		switch extension {
  1159  		case extensionStatusRequest:
  1160  			m.ocspStapling = true
  1161  		case extensionSCT:
  1162  			m.scts = true
  1163  		case extensionSignatureAlgorithms:
  1164  			var sigAndAlgs cryptobyte.String
  1165  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
  1166  				return false
  1167  			}
  1168  			for !sigAndAlgs.Empty() {
  1169  				var sigAndAlg uint16
  1170  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
  1171  					return false
  1172  				}
  1173  				m.supportedSignatureAlgorithms = append(
  1174  					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
  1175  			}
  1176  		case extensionSignatureAlgorithmsCert:
  1177  			var sigAndAlgs cryptobyte.String
  1178  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
  1179  				return false
  1180  			}
  1181  			for !sigAndAlgs.Empty() {
  1182  				var sigAndAlg uint16
  1183  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
  1184  					return false
  1185  				}
  1186  				m.supportedSignatureAlgorithmsCert = append(
  1187  					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
  1188  			}
  1189  		case extensionCertificateAuthorities:
  1190  			var auths cryptobyte.String
  1191  			if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() {
  1192  				return false
  1193  			}
  1194  			for !auths.Empty() {
  1195  				var ca []byte
  1196  				if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 {
  1197  					return false
  1198  				}
  1199  				m.certificateAuthorities = append(m.certificateAuthorities, ca)
  1200  			}
  1201  		default:
  1202  			// Ignore unknown extensions.
  1203  			continue
  1204  		}
  1205  
  1206  		if !extData.Empty() {
  1207  			return false
  1208  		}
  1209  	}
  1210  
  1211  	return true
  1212  }
  1213  
// certificateMsg is the Certificate handshake message in the form without the
// TLS 1.3 request context and per-entry extensions (certificateMsgTLS13
// handles that form).
type certificateMsg struct {
	// raw is the wire encoding; marshal returns it when non-nil.
	raw []byte
	// certificates holds each certificate's bytes in message order, each
	// written with a uint24 length prefix.
	certificates [][]byte
}
  1218  
  1219  func (m *certificateMsg) marshal() (x []byte) {
  1220  	if m.raw != nil {
  1221  		return m.raw
  1222  	}
  1223  
  1224  	var i int
  1225  	for _, slice := range m.certificates {
  1226  		i += len(slice)
  1227  	}
  1228  
  1229  	length := 3 + 3*len(m.certificates) + i
  1230  	x = make([]byte, 4+length)
  1231  	x[0] = typeCertificate
  1232  	x[1] = uint8(length >> 16)
  1233  	x[2] = uint8(length >> 8)
  1234  	x[3] = uint8(length)
  1235  
  1236  	certificateOctets := length - 3
  1237  	x[4] = uint8(certificateOctets >> 16)
  1238  	x[5] = uint8(certificateOctets >> 8)
  1239  	x[6] = uint8(certificateOctets)
  1240  
  1241  	y := x[7:]
  1242  	for _, slice := range m.certificates {
  1243  		y[0] = uint8(len(slice) >> 16)
  1244  		y[1] = uint8(len(slice) >> 8)
  1245  		y[2] = uint8(len(slice))
  1246  		copy(y[3:], slice)
  1247  		y = y[3+len(slice):]
  1248  	}
  1249  
  1250  	m.raw = x
  1251  	return
  1252  }
  1253  
// unmarshal parses a Certificate message: a 4-byte handshake header followed
// by a uint24-length-prefixed list of uint24-length-prefixed certificates.
// The entries of m.certificates are sub-slices of data, not copies. The
// uint24 length in the handshake header is not checked here. It reports
// whether data is well-formed.
func (m *certificateMsg) unmarshal(data []byte) bool {
	// 4-byte handshake header plus the 3-byte certificate list length.
	if len(data) < 7 {
		return false
	}

	m.raw = data
	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
	// The certificate list must extend exactly to the end of data.
	if uint32(len(data)) != certsLen+7 {
		return false
	}

	// First pass: validate every entry and count them, so the result slice
	// can be allocated once. len(d) == certsLen holds at the top of each
	// iteration, so the subtraction below cannot underflow.
	numCerts := 0
	d := data[7:]
	for certsLen > 0 {
		// NOTE(review): only 3 bytes are needed for the length field. "< 4"
		// rejects a zero-length final entry, while a zero-length entry
		// elsewhere in the list is accepted. Confirm this is intended.
		if len(d) < 4 {
			return false
		}
		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
		if uint32(len(d)) < 3+certLen {
			return false
		}
		d = d[3+certLen:]
		certsLen -= 3 + certLen
		numCerts++
	}

	// Second pass: slice out each certificate. Bounds were checked above.
	// make (rather than append onto nil) yields a non-nil empty slice when
	// there are no certificates.
	m.certificates = make([][]byte, numCerts)
	d = data[7:]
	for i := 0; i < numCerts; i++ {
		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
		m.certificates[i] = d[3 : 3+certLen]
		d = d[3+certLen:]
	}

	return true
}
  1290  
// certificateMsgTLS13 is the TLS 1.3 Certificate message (RFC 8446,
// Section 4.4.2).
type certificateMsgTLS13 struct {
	// raw is the wire encoding; marshal returns it when non-nil.
	raw []byte
	// certificate holds the chain plus any OCSP staple and SCTs for the
	// first (leaf) entry.
	certificate Certificate
	// ocspStapling controls whether marshal includes certificate.OCSPStaple.
	// unmarshal sets it when a staple was present.
	ocspStapling bool
	// scts controls whether marshal includes
	// certificate.SignedCertificateTimestamps. unmarshal sets it when SCTs
	// were present.
	scts bool
}
  1297  
  1298  func (m *certificateMsgTLS13) marshal() []byte {
  1299  	if m.raw != nil {
  1300  		return m.raw
  1301  	}
  1302  
  1303  	var b cryptobyte.Builder
  1304  	b.AddUint8(typeCertificate)
  1305  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1306  		b.AddUint8(0) // certificate_request_context
  1307  
  1308  		certificate := m.certificate
  1309  		if !m.ocspStapling {
  1310  			certificate.OCSPStaple = nil
  1311  		}
  1312  		if !m.scts {
  1313  			certificate.SignedCertificateTimestamps = nil
  1314  		}
  1315  		marshalCertificate(b, certificate)
  1316  	})
  1317  
  1318  	m.raw = b.BytesOrPanic()
  1319  	return m.raw
  1320  }
  1321  
  1322  func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
  1323  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1324  		for i, cert := range certificate.Certificate {
  1325  			b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1326  				b.AddBytes(cert)
  1327  			})
  1328  			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1329  				if i > 0 {
  1330  					// This library only supports OCSP and SCT for leaf certificates.
  1331  					return
  1332  				}
  1333  				if certificate.OCSPStaple != nil {
  1334  					b.AddUint16(extensionStatusRequest)
  1335  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1336  						b.AddUint8(statusTypeOCSP)
  1337  						b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1338  							b.AddBytes(certificate.OCSPStaple)
  1339  						})
  1340  					})
  1341  				}
  1342  				if certificate.SignedCertificateTimestamps != nil {
  1343  					b.AddUint16(extensionSCT)
  1344  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1345  						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1346  							for _, sct := range certificate.SignedCertificateTimestamps {
  1347  								b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1348  									b.AddBytes(sct)
  1349  								})
  1350  							}
  1351  						})
  1352  					})
  1353  				}
  1354  			})
  1355  		}
  1356  	})
  1357  }
  1358  
  1359  func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
  1360  	*m = certificateMsgTLS13{raw: data}
  1361  	s := cryptobyte.String(data)
  1362  
  1363  	var context cryptobyte.String
  1364  	if !s.Skip(4) || // message type and uint24 length field
  1365  		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
  1366  		!unmarshalCertificate(&s, &m.certificate) ||
  1367  		!s.Empty() {
  1368  		return false
  1369  	}
  1370  
  1371  	m.scts = m.certificate.SignedCertificateTimestamps != nil
  1372  	m.ocspStapling = m.certificate.OCSPStaple != nil
  1373  
  1374  	return true
  1375  }
  1376  
  1377  func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
  1378  	var certList cryptobyte.String
  1379  	if !s.ReadUint24LengthPrefixed(&certList) {
  1380  		return false
  1381  	}
  1382  	for !certList.Empty() {
  1383  		var cert []byte
  1384  		var extensions cryptobyte.String
  1385  		if !readUint24LengthPrefixed(&certList, &cert) ||
  1386  			!certList.ReadUint16LengthPrefixed(&extensions) {
  1387  			return false
  1388  		}
  1389  		certificate.Certificate = append(certificate.Certificate, cert)
  1390  		for !extensions.Empty() {
  1391  			var extension uint16
  1392  			var extData cryptobyte.String
  1393  			if !extensions.ReadUint16(&extension) ||
  1394  				!extensions.ReadUint16LengthPrefixed(&extData) {
  1395  				return false
  1396  			}
  1397  			if len(certificate.Certificate) > 1 {
  1398  				// This library only supports OCSP and SCT for leaf certificates.
  1399  				continue
  1400  			}
  1401  
  1402  			switch extension {
  1403  			case extensionStatusRequest:
  1404  				var statusType uint8
  1405  				if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
  1406  					!readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) ||
  1407  					len(certificate.OCSPStaple) == 0 {
  1408  					return false
  1409  				}
  1410  			case extensionSCT:
  1411  				var sctList cryptobyte.String
  1412  				if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
  1413  					return false
  1414  				}
  1415  				for !sctList.Empty() {
  1416  					var sct []byte
  1417  					if !readUint16LengthPrefixed(&sctList, &sct) ||
  1418  						len(sct) == 0 {
  1419  						return false
  1420  					}
  1421  					certificate.SignedCertificateTimestamps = append(
  1422  						certificate.SignedCertificateTimestamps, sct)
  1423  				}
  1424  			default:
  1425  				// Ignore unknown extensions.
  1426  				continue
  1427  			}
  1428  
  1429  			if !extData.Empty() {
  1430  				return false
  1431  			}
  1432  		}
  1433  	}
  1434  	return true
  1435  }
  1436  
// serverKeyExchangeMsg is the ServerKeyExchange handshake message. Its body
// is kept opaque at this layer.
type serverKeyExchangeMsg struct {
	// raw is the wire encoding; marshal returns it when non-nil.
	raw []byte
	// key is the entire message body after the 4-byte handshake header.
	key []byte
}
  1441  
  1442  func (m *serverKeyExchangeMsg) marshal() []byte {
  1443  	if m.raw != nil {
  1444  		return m.raw
  1445  	}
  1446  	length := len(m.key)
  1447  	x := make([]byte, length+4)
  1448  	x[0] = typeServerKeyExchange
  1449  	x[1] = uint8(length >> 16)
  1450  	x[2] = uint8(length >> 8)
  1451  	x[3] = uint8(length)
  1452  	copy(x[4:], m.key)
  1453  
  1454  	m.raw = x
  1455  	return x
  1456  }
  1457  
  1458  func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
  1459  	m.raw = data
  1460  	if len(data) < 4 {
  1461  		return false
  1462  	}
  1463  	m.key = data[4:]
  1464  	return true
  1465  }
  1466  
// certificateStatusMsg is the CertificateStatus handshake message carrying an
// OCSP response (status type statusTypeOCSP; see RFC 6066, Section 8).
type certificateStatusMsg struct {
	// raw is the wire encoding; marshal returns it when non-nil.
	raw []byte
	// response is the OCSP response, written with a uint24 length prefix.
	// unmarshal rejects an empty one.
	response []byte
}
  1471  
  1472  func (m *certificateStatusMsg) marshal() []byte {
  1473  	if m.raw != nil {
  1474  		return m.raw
  1475  	}
  1476  
  1477  	var b cryptobyte.Builder
  1478  	b.AddUint8(typeCertificateStatus)
  1479  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1480  		b.AddUint8(statusTypeOCSP)
  1481  		b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1482  			b.AddBytes(m.response)
  1483  		})
  1484  	})
  1485  
  1486  	m.raw = b.BytesOrPanic()
  1487  	return m.raw
  1488  }
  1489  
  1490  func (m *certificateStatusMsg) unmarshal(data []byte) bool {
  1491  	m.raw = data
  1492  	s := cryptobyte.String(data)
  1493  
  1494  	var statusType uint8
  1495  	if !s.Skip(4) || // message type and uint24 length field
  1496  		!s.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
  1497  		!readUint24LengthPrefixed(&s, &m.response) ||
  1498  		len(m.response) == 0 || !s.Empty() {
  1499  		return false
  1500  	}
  1501  	return true
  1502  }
  1503  
  1504  type serverHelloDoneMsg struct{}
  1505  
  1506  func (m *serverHelloDoneMsg) marshal() []byte {
  1507  	x := make([]byte, 4)
  1508  	x[0] = typeServerHelloDone
  1509  	return x
  1510  }
  1511  
  1512  func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
  1513  	return len(data) == 4
  1514  }
  1515  
// clientKeyExchangeMsg is the ClientKeyExchange handshake message. Its body
// is kept opaque at this layer.
type clientKeyExchangeMsg struct {
	// raw is the wire encoding; marshal returns it when non-nil.
	raw []byte
	// ciphertext is the entire message body after the 4-byte handshake header.
	ciphertext []byte
}
  1520  
  1521  func (m *clientKeyExchangeMsg) marshal() []byte {
  1522  	if m.raw != nil {
  1523  		return m.raw
  1524  	}
  1525  	length := len(m.ciphertext)
  1526  	x := make([]byte, length+4)
  1527  	x[0] = typeClientKeyExchange
  1528  	x[1] = uint8(length >> 16)
  1529  	x[2] = uint8(length >> 8)
  1530  	x[3] = uint8(length)
  1531  	copy(x[4:], m.ciphertext)
  1532  
  1533  	m.raw = x
  1534  	return x
  1535  }
  1536  
  1537  func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
  1538  	m.raw = data
  1539  	if len(data) < 4 {
  1540  		return false
  1541  	}
  1542  	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  1543  	if l != len(data)-4 {
  1544  		return false
  1545  	}
  1546  	m.ciphertext = data[4:]
  1547  	return true
  1548  }
  1549  
// finishedMsg is the Finished handshake message.
type finishedMsg struct {
	// raw is the wire encoding; marshal returns it when non-nil.
	raw []byte
	// verifyData is the entire message body, whose length is given by the
	// uint24 handshake length.
	verifyData []byte
}
  1554  
  1555  func (m *finishedMsg) marshal() []byte {
  1556  	if m.raw != nil {
  1557  		return m.raw
  1558  	}
  1559  
  1560  	var b cryptobyte.Builder
  1561  	b.AddUint8(typeFinished)
  1562  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1563  		b.AddBytes(m.verifyData)
  1564  	})
  1565  
  1566  	m.raw = b.BytesOrPanic()
  1567  	return m.raw
  1568  }
  1569  
  1570  func (m *finishedMsg) unmarshal(data []byte) bool {
  1571  	m.raw = data
  1572  	s := cryptobyte.String(data)
  1573  	return s.Skip(1) &&
  1574  		readUint24LengthPrefixed(&s, &m.verifyData) &&
  1575  		s.Empty()
  1576  }
  1577  
// certificateRequestMsg is the CertificateRequest handshake message in the
// form without TLS 1.3 extensions (certificateRequestMsgTLS13 handles that
// form).
type certificateRequestMsg struct {
	// raw is the wire encoding; marshal returns it when non-nil.
	raw []byte
	// hasSignatureAlgorithm indicates whether this message includes a list of
	// supported signature algorithms. This change was introduced with TLS 1.2.
	// unmarshal reads it to choose the format, so callers must set it before
	// parsing.
	hasSignatureAlgorithm bool

	// certificateTypes is written with a uint8 count prefix.
	certificateTypes []byte
	// supportedSignatureAlgorithms is encoded and decoded only when
	// hasSignatureAlgorithm is set.
	supportedSignatureAlgorithms []SignatureScheme
	// certificateAuthorities holds each entry as raw bytes. Each entry has a
	// uint16 length prefix inside a uint16-length-prefixed list.
	certificateAuthorities [][]byte
}
  1588  
  1589  func (m *certificateRequestMsg) marshal() (x []byte) {
  1590  	if m.raw != nil {
  1591  		return m.raw
  1592  	}
  1593  
  1594  	// See RFC 4346, Section 7.4.4.
  1595  	length := 1 + len(m.certificateTypes) + 2
  1596  	casLength := 0
  1597  	for _, ca := range m.certificateAuthorities {
  1598  		casLength += 2 + len(ca)
  1599  	}
  1600  	length += casLength
  1601  
  1602  	if m.hasSignatureAlgorithm {
  1603  		length += 2 + 2*len(m.supportedSignatureAlgorithms)
  1604  	}
  1605  
  1606  	x = make([]byte, 4+length)
  1607  	x[0] = typeCertificateRequest
  1608  	x[1] = uint8(length >> 16)
  1609  	x[2] = uint8(length >> 8)
  1610  	x[3] = uint8(length)
  1611  
  1612  	x[4] = uint8(len(m.certificateTypes))
  1613  
  1614  	copy(x[5:], m.certificateTypes)
  1615  	y := x[5+len(m.certificateTypes):]
  1616  
  1617  	if m.hasSignatureAlgorithm {
  1618  		n := len(m.supportedSignatureAlgorithms) * 2
  1619  		y[0] = uint8(n >> 8)
  1620  		y[1] = uint8(n)
  1621  		y = y[2:]
  1622  		for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1623  			y[0] = uint8(sigAlgo >> 8)
  1624  			y[1] = uint8(sigAlgo)
  1625  			y = y[2:]
  1626  		}
  1627  	}
  1628  
  1629  	y[0] = uint8(casLength >> 8)
  1630  	y[1] = uint8(casLength)
  1631  	y = y[2:]
  1632  	for _, ca := range m.certificateAuthorities {
  1633  		y[0] = uint8(len(ca) >> 8)
  1634  		y[1] = uint8(len(ca))
  1635  		y = y[2:]
  1636  		copy(y, ca)
  1637  		y = y[len(ca):]
  1638  	}
  1639  
  1640  	m.raw = x
  1641  	return
  1642  }
  1643  
// unmarshal parses a CertificateRequest without TLS 1.3 extensions. The
// caller must set m.hasSignatureAlgorithm beforehand, because it is read here
// to decide whether a signature algorithm list is present. It reports
// whether data is well-formed. On failure m may be partially populated.
func (m *certificateRequestMsg) unmarshal(data []byte) bool {
	m.raw = data

	// Need the 4-byte handshake header plus the certificate types count.
	if len(data) < 5 {
		return false
	}

	// The uint24 length in the handshake header must cover the rest of data.
	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
	if uint32(len(data))-4 != length {
		return false
	}

	// certificate_types: a uint8 count of one-byte entries. At least one type
	// is required, and data must extend past the list.
	numCertTypes := int(data[4])
	data = data[5:]
	if numCertTypes == 0 || len(data) <= numCertTypes {
		return false
	}

	// Copy so that m.certificateTypes does not alias the input.
	m.certificateTypes = make([]byte, numCertTypes)
	if copy(m.certificateTypes, data) != numCertTypes {
		return false
	}

	data = data[numCertTypes:]

	if m.hasSignatureAlgorithm {
		// A uint16 byte length followed by two-byte SignatureScheme values.
		if len(data) < 2 {
			return false
		}
		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
		data = data[2:]
		// Each entry is two bytes, so an odd byte length is malformed.
		if sigAndHashLen&1 != 0 {
			return false
		}
		if len(data) < int(sigAndHashLen) {
			return false
		}
		numSigAlgos := sigAndHashLen / 2
		m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
		for i := range m.supportedSignatureAlgorithms {
			m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
			data = data[2:]
		}
	}

	// certificate_authorities: a uint16 byte length followed by entries that
	// each carry their own uint16 length.
	if len(data) < 2 {
		return false
	}
	casLength := uint16(data[0])<<8 | uint16(data[1])
	data = data[2:]
	if len(data) < int(casLength) {
		return false
	}
	// The list is copied, so the entries appended below alias cas rather
	// than the caller's data.
	cas := make([]byte, casLength)
	copy(cas, data)
	data = data[casLength:]

	m.certificateAuthorities = nil
	for len(cas) > 0 {
		if len(cas) < 2 {
			return false
		}
		caLen := uint16(cas[0])<<8 | uint16(cas[1])
		cas = cas[2:]

		if len(cas) < int(caLen) {
			return false
		}

		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
		cas = cas[caLen:]
	}

	// No bytes may follow the certificate authorities list.
	return len(data) == 0
}
  1719  
// certificateVerifyMsg is the CertificateVerify handshake message.
type certificateVerifyMsg struct {
	// raw is the wire encoding; marshal returns it when non-nil.
	raw []byte
	// hasSignatureAlgorithm selects the wire format. unmarshal reads it
	// rather than setting it, so callers must set it before parsing.
	hasSignatureAlgorithm bool // format change introduced in TLS 1.2
	// signatureAlgorithm is present on the wire, as a uint16, only when
	// hasSignatureAlgorithm is set.
	signatureAlgorithm SignatureScheme
	// signature is written with a uint16 length prefix.
	signature []byte
}
  1726  
  1727  func (m *certificateVerifyMsg) marshal() (x []byte) {
  1728  	if m.raw != nil {
  1729  		return m.raw
  1730  	}
  1731  
  1732  	var b cryptobyte.Builder
  1733  	b.AddUint8(typeCertificateVerify)
  1734  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1735  		if m.hasSignatureAlgorithm {
  1736  			b.AddUint16(uint16(m.signatureAlgorithm))
  1737  		}
  1738  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1739  			b.AddBytes(m.signature)
  1740  		})
  1741  	})
  1742  
  1743  	m.raw = b.BytesOrPanic()
  1744  	return m.raw
  1745  }
  1746  
  1747  func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1748  	m.raw = data
  1749  	s := cryptobyte.String(data)
  1750  
  1751  	if !s.Skip(4) { // message type and uint24 length field
  1752  		return false
  1753  	}
  1754  	if m.hasSignatureAlgorithm {
  1755  		if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) {
  1756  			return false
  1757  		}
  1758  	}
  1759  	return readUint16LengthPrefixed(&s, &m.signature) && s.Empty()
  1760  }
  1761  
// newSessionTicketMsg is the NewSessionTicket message of RFC 5077,
// Section 3.3.
type newSessionTicketMsg struct {
	// raw is the wire encoding; marshal returns it when non-nil.
	raw []byte
	// ticket is written with a uint16 length prefix. The 4-byte field before
	// it is always written as zero and is ignored when parsing.
	ticket []byte
}
  1766  
  1767  func (m *newSessionTicketMsg) marshal() (x []byte) {
  1768  	if m.raw != nil {
  1769  		return m.raw
  1770  	}
  1771  
  1772  	// See RFC 5077, Section 3.3.
  1773  	ticketLen := len(m.ticket)
  1774  	length := 2 + 4 + ticketLen
  1775  	x = make([]byte, 4+length)
  1776  	x[0] = typeNewSessionTicket
  1777  	x[1] = uint8(length >> 16)
  1778  	x[2] = uint8(length >> 8)
  1779  	x[3] = uint8(length)
  1780  	x[8] = uint8(ticketLen >> 8)
  1781  	x[9] = uint8(ticketLen)
  1782  	copy(x[10:], m.ticket)
  1783  
  1784  	m.raw = x
  1785  
  1786  	return
  1787  }
  1788  
  1789  func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1790  	m.raw = data
  1791  
  1792  	if len(data) < 10 {
  1793  		return false
  1794  	}
  1795  
  1796  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1797  	if uint32(len(data))-4 != length {
  1798  		return false
  1799  	}
  1800  
  1801  	ticketLen := int(data[8])<<8 + int(data[9])
  1802  	if len(data)-10 != ticketLen {
  1803  		return false
  1804  	}
  1805  
  1806  	m.ticket = data[10:]
  1807  
  1808  	return true
  1809  }
  1810  
  1811  type helloRequestMsg struct {
  1812  }
  1813  
  1814  func (*helloRequestMsg) marshal() []byte {
  1815  	return []byte{typeHelloRequest, 0, 0, 0}
  1816  }
  1817  
  1818  func (*helloRequestMsg) unmarshal(data []byte) bool {
  1819  	return len(data) == 4
  1820  }
  1821  

View as plain text