...

Source file src/crypto/rsa/pss.go

Documentation: crypto/rsa

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rsa
     6  
     7  // This file implements the RSASSA-PSS signature scheme according to RFC 8017.
     8  
     9  import (
    10  	"bytes"
    11  	"crypto"
    12  	"crypto/internal/boring"
    13  	"errors"
    14  	"hash"
    15  	"io"
    16  	"math/big"
    17  )
    18  
    19  // Per RFC 8017, Section 9.1
    20  //
    21  //     EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc
    22  //
    23  // where
    24  //
    25  //     DB = PS || 0x01 || salt
    26  //
    27  // and PS can be empty so
    28  //
    29  //     emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2
    30  //
    31  
    32  func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
    33  	// See RFC 8017, Section 9.1.1.
    34  
    35  	hLen := hash.Size()
    36  	sLen := len(salt)
    37  	emLen := (emBits + 7) / 8
    38  
    39  	// 1.  If the length of M is greater than the input limitation for the
    40  	//     hash function (2^61 - 1 octets for SHA-1), output "message too
    41  	//     long" and stop.
    42  	//
    43  	// 2.  Let mHash = Hash(M), an octet string of length hLen.
    44  
    45  	if len(mHash) != hLen {
    46  		return nil, errors.New("crypto/rsa: input must be hashed with given hash")
    47  	}
    48  
    49  	// 3.  If emLen < hLen + sLen + 2, output "encoding error" and stop.
    50  
    51  	if emLen < hLen+sLen+2 {
    52  		return nil, errors.New("crypto/rsa: key size too small for PSS signature")
    53  	}
    54  
    55  	em := make([]byte, emLen)
    56  	psLen := emLen - sLen - hLen - 2
    57  	db := em[:psLen+1+sLen]
    58  	h := em[psLen+1+sLen : emLen-1]
    59  
    60  	// 4.  Generate a random octet string salt of length sLen; if sLen = 0,
    61  	//     then salt is the empty string.
    62  	//
    63  	// 5.  Let
    64  	//       M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
    65  	//
    66  	//     M' is an octet string of length 8 + hLen + sLen with eight
    67  	//     initial zero octets.
    68  	//
    69  	// 6.  Let H = Hash(M'), an octet string of length hLen.
    70  
    71  	var prefix [8]byte
    72  
    73  	hash.Write(prefix[:])
    74  	hash.Write(mHash)
    75  	hash.Write(salt)
    76  
    77  	h = hash.Sum(h[:0])
    78  	hash.Reset()
    79  
    80  	// 7.  Generate an octet string PS consisting of emLen - sLen - hLen - 2
    81  	//     zero octets. The length of PS may be 0.
    82  	//
    83  	// 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
    84  	//     emLen - hLen - 1.
    85  
    86  	db[psLen] = 0x01
    87  	copy(db[psLen+1:], salt)
    88  
    89  	// 9.  Let dbMask = MGF(H, emLen - hLen - 1).
    90  	//
    91  	// 10. Let maskedDB = DB \xor dbMask.
    92  
    93  	mgf1XOR(db, hash, h)
    94  
    95  	// 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
    96  	//     maskedDB to zero.
    97  
    98  	db[0] &= 0xff >> (8*emLen - emBits)
    99  
   100  	// 12. Let EM = maskedDB || H || 0xbc.
   101  	em[emLen-1] = 0xbc
   102  
   103  	// 13. Output EM.
   104  	return em, nil
   105  }
   106  
   107  func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
   108  	// See RFC 8017, Section 9.1.2.
   109  
   110  	hLen := hash.Size()
   111  	if sLen == PSSSaltLengthEqualsHash {
   112  		sLen = hLen
   113  	}
   114  	emLen := (emBits + 7) / 8
   115  	if emLen != len(em) {
   116  		return errors.New("rsa: internal error: inconsistent length")
   117  	}
   118  
   119  	// 1.  If the length of M is greater than the input limitation for the
   120  	//     hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
   121  	//     and stop.
   122  	//
   123  	// 2.  Let mHash = Hash(M), an octet string of length hLen.
   124  	if hLen != len(mHash) {
   125  		return ErrVerification
   126  	}
   127  
   128  	// 3.  If emLen < hLen + sLen + 2, output "inconsistent" and stop.
   129  	if emLen < hLen+sLen+2 {
   130  		return ErrVerification
   131  	}
   132  
   133  	// 4.  If the rightmost octet of EM does not have hexadecimal value
   134  	//     0xbc, output "inconsistent" and stop.
   135  	if em[emLen-1] != 0xbc {
   136  		return ErrVerification
   137  	}
   138  
   139  	// 5.  Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
   140  	//     let H be the next hLen octets.
   141  	db := em[:emLen-hLen-1]
   142  	h := em[emLen-hLen-1 : emLen-1]
   143  
   144  	// 6.  If the leftmost 8 * emLen - emBits bits of the leftmost octet in
   145  	//     maskedDB are not all equal to zero, output "inconsistent" and
   146  	//     stop.
   147  	var bitMask byte = 0xff >> (8*emLen - emBits)
   148  	if em[0] & ^bitMask != 0 {
   149  		return ErrVerification
   150  	}
   151  
   152  	// 7.  Let dbMask = MGF(H, emLen - hLen - 1).
   153  	//
   154  	// 8.  Let DB = maskedDB \xor dbMask.
   155  	mgf1XOR(db, hash, h)
   156  
   157  	// 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
   158  	//     to zero.
   159  	db[0] &= bitMask
   160  
   161  	// If we don't know the salt length, look for the 0x01 delimiter.
   162  	if sLen == PSSSaltLengthAuto {
   163  		psLen := bytes.IndexByte(db, 0x01)
   164  		if psLen < 0 {
   165  			return ErrVerification
   166  		}
   167  		sLen = len(db) - psLen - 1
   168  	}
   169  
   170  	// 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
   171  	//     or if the octet at position emLen - hLen - sLen - 1 (the leftmost
   172  	//     position is "position 1") does not have hexadecimal value 0x01,
   173  	//     output "inconsistent" and stop.
   174  	psLen := emLen - hLen - sLen - 2
   175  	for _, e := range db[:psLen] {
   176  		if e != 0x00 {
   177  			return ErrVerification
   178  		}
   179  	}
   180  	if db[psLen] != 0x01 {
   181  		return ErrVerification
   182  	}
   183  
   184  	// 11.  Let salt be the last sLen octets of DB.
   185  	salt := db[len(db)-sLen:]
   186  
   187  	// 12.  Let
   188  	//          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
   189  	//     M' is an octet string of length 8 + hLen + sLen with eight
   190  	//     initial zero octets.
   191  	//
   192  	// 13. Let H' = Hash(M'), an octet string of length hLen.
   193  	var prefix [8]byte
   194  	hash.Write(prefix[:])
   195  	hash.Write(mHash)
   196  	hash.Write(salt)
   197  
   198  	h0 := hash.Sum(nil)
   199  
   200  	// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
   201  	if !bytes.Equal(h0, h) { // TODO: constant time?
   202  		return ErrVerification
   203  	}
   204  	return nil
   205  }
   206  
   207  // signPSSWithSalt calculates the signature of hashed using PSS with specified salt.
   208  // Note that hashed must be the result of hashing the input message using the
   209  // given hash function. salt is a random sequence of bytes whose length will be
   210  // later used to verify the signature.
   211  func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
   212  	emBits := priv.N.BitLen() - 1
   213  	em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  
   218  	if boring.Enabled {
   219  		bkey, err := boringPrivateKey(priv)
   220  		if err != nil {
   221  			return nil, err
   222  		}
   223  		// Note: BoringCrypto takes care of the "AndCheck" part of "decryptAndCheck".
   224  		// (It's not just decrypt.)
   225  		s, err := boring.DecryptRSANoPadding(bkey, em)
   226  		if err != nil {
   227  			return nil, err
   228  		}
   229  		return s, nil
   230  	}
   231  
   232  	m := new(big.Int).SetBytes(em)
   233  	c, err := decryptAndCheck(rand, priv, m)
   234  	if err != nil {
   235  		return nil, err
   236  	}
   237  	s := make([]byte, priv.Size())
   238  	return c.FillBytes(s), nil
   239  }
   240  
   241  const (
   242  	// PSSSaltLengthAuto causes the salt in a PSS signature to be as large
   243  	// as possible when signing, and to be auto-detected when verifying.
   244  	PSSSaltLengthAuto = 0
   245  	// PSSSaltLengthEqualsHash causes the salt length to equal the length
   246  	// of the hash used in the signature.
   247  	PSSSaltLengthEqualsHash = -1
   248  )
   249  
   250  // PSSOptions contains options for creating and verifying PSS signatures.
   251  type PSSOptions struct {
   252  	// SaltLength controls the length of the salt used in the PSS
   253  	// signature. It can either be a number of bytes, or one of the special
   254  	// PSSSaltLength constants.
   255  	SaltLength int
   256  
   257  	// Hash is the hash function used to generate the message digest. If not
   258  	// zero, it overrides the hash function passed to SignPSS. It's required
   259  	// when using PrivateKey.Sign.
   260  	Hash crypto.Hash
   261  }
   262  
   263  // HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts.
   264  func (opts *PSSOptions) HashFunc() crypto.Hash {
   265  	return opts.Hash
   266  }
   267  
   268  func (opts *PSSOptions) saltLength() int {
   269  	if opts == nil {
   270  		return PSSSaltLengthAuto
   271  	}
   272  	return opts.SaltLength
   273  }
   274  
   275  // SignPSS calculates the signature of digest using PSS.
   276  //
   277  // digest must be the result of hashing the input message using the given hash
   278  // function. The opts argument may be nil, in which case sensible defaults are
   279  // used. If opts.Hash is set, it overrides hash.
   280  func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) {
   281  	if opts != nil && opts.Hash != 0 {
   282  		hash = opts.Hash
   283  	}
   284  
   285  	saltLength := opts.saltLength()
   286  	switch saltLength {
   287  	case PSSSaltLengthAuto:
   288  		saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
   289  	case PSSSaltLengthEqualsHash:
   290  		saltLength = hash.Size()
   291  	}
   292  
   293  	if boring.Enabled && rand == boring.RandReader {
   294  		bkey, err := boringPrivateKey(priv)
   295  		if err != nil {
   296  			return nil, err
   297  		}
   298  		return boring.SignRSAPSS(bkey, hash, digest, saltLength)
   299  	}
   300  	boring.UnreachableExceptTests()
   301  
   302  	salt := make([]byte, saltLength)
   303  	if _, err := io.ReadFull(rand, salt); err != nil {
   304  		return nil, err
   305  	}
   306  	return signPSSWithSalt(rand, priv, hash, digest, salt)
   307  }
   308  
   309  // VerifyPSS verifies a PSS signature.
   310  //
   311  // A valid signature is indicated by returning a nil error. digest must be the
   312  // result of hashing the input message using the given hash function. The opts
   313  // argument may be nil, in which case sensible defaults are used. opts.Hash is
   314  // ignored.
   315  func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error {
   316  	if boring.Enabled {
   317  		bkey, err := boringPublicKey(pub)
   318  		if err != nil {
   319  			return err
   320  		}
   321  		if err := boring.VerifyRSAPSS(bkey, hash, digest, sig, opts.saltLength()); err != nil {
   322  			return ErrVerification
   323  		}
   324  		return nil
   325  	}
   326  	if len(sig) != pub.Size() {
   327  		return ErrVerification
   328  	}
   329  	s := new(big.Int).SetBytes(sig)
   330  	m := encrypt(new(big.Int), pub, s)
   331  	emBits := pub.N.BitLen() - 1
   332  	emLen := (emBits + 7) / 8
   333  	if m.BitLen() > emLen*8 {
   334  		return ErrVerification
   335  	}
   336  	em := m.FillBytes(make([]byte, emLen))
   337  	return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
   338  }
   339  

View as plain text