1
2
3
4
5 package rsa
6
7
8
9 import (
10 "bytes"
11 "crypto"
12 "crypto/internal/boring"
13 "errors"
14 "hash"
15 "io"
16 "math/big"
17 )
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32 func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
33
34
35 hLen := hash.Size()
36 sLen := len(salt)
37 emLen := (emBits + 7) / 8
38
39
40
41
42
43
44
45 if len(mHash) != hLen {
46 return nil, errors.New("crypto/rsa: input must be hashed with given hash")
47 }
48
49
50
51 if emLen < hLen+sLen+2 {
52 return nil, errors.New("crypto/rsa: key size too small for PSS signature")
53 }
54
55 em := make([]byte, emLen)
56 psLen := emLen - sLen - hLen - 2
57 db := em[:psLen+1+sLen]
58 h := em[psLen+1+sLen : emLen-1]
59
60
61
62
63
64
65
66
67
68
69
70
71 var prefix [8]byte
72
73 hash.Write(prefix[:])
74 hash.Write(mHash)
75 hash.Write(salt)
76
77 h = hash.Sum(h[:0])
78 hash.Reset()
79
80
81
82
83
84
85
86 db[psLen] = 0x01
87 copy(db[psLen+1:], salt)
88
89
90
91
92
93 mgf1XOR(db, hash, h)
94
95
96
97
98 db[0] &= 0xff >> (8*emLen - emBits)
99
100
101 em[emLen-1] = 0xbc
102
103
104 return em, nil
105 }
106
107 func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
108
109
110 hLen := hash.Size()
111 if sLen == PSSSaltLengthEqualsHash {
112 sLen = hLen
113 }
114 emLen := (emBits + 7) / 8
115 if emLen != len(em) {
116 return errors.New("rsa: internal error: inconsistent length")
117 }
118
119
120
121
122
123
124 if hLen != len(mHash) {
125 return ErrVerification
126 }
127
128
129 if emLen < hLen+sLen+2 {
130 return ErrVerification
131 }
132
133
134
135 if em[emLen-1] != 0xbc {
136 return ErrVerification
137 }
138
139
140
141 db := em[:emLen-hLen-1]
142 h := em[emLen-hLen-1 : emLen-1]
143
144
145
146
147 var bitMask byte = 0xff >> (8*emLen - emBits)
148 if em[0] & ^bitMask != 0 {
149 return ErrVerification
150 }
151
152
153
154
155 mgf1XOR(db, hash, h)
156
157
158
159 db[0] &= bitMask
160
161
162 if sLen == PSSSaltLengthAuto {
163 psLen := bytes.IndexByte(db, 0x01)
164 if psLen < 0 {
165 return ErrVerification
166 }
167 sLen = len(db) - psLen - 1
168 }
169
170
171
172
173
174 psLen := emLen - hLen - sLen - 2
175 for _, e := range db[:psLen] {
176 if e != 0x00 {
177 return ErrVerification
178 }
179 }
180 if db[psLen] != 0x01 {
181 return ErrVerification
182 }
183
184
185 salt := db[len(db)-sLen:]
186
187
188
189
190
191
192
193 var prefix [8]byte
194 hash.Write(prefix[:])
195 hash.Write(mHash)
196 hash.Write(salt)
197
198 h0 := hash.Sum(nil)
199
200
201 if !bytes.Equal(h0, h) {
202 return ErrVerification
203 }
204 return nil
205 }
206
207
208
209
210
211 func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
212 emBits := priv.N.BitLen() - 1
213 em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
214 if err != nil {
215 return nil, err
216 }
217
218 if boring.Enabled {
219 bkey, err := boringPrivateKey(priv)
220 if err != nil {
221 return nil, err
222 }
223
224
225 s, err := boring.DecryptRSANoPadding(bkey, em)
226 if err != nil {
227 return nil, err
228 }
229 return s, nil
230 }
231
232 m := new(big.Int).SetBytes(em)
233 c, err := decryptAndCheck(rand, priv, m)
234 if err != nil {
235 return nil, err
236 }
237 s := make([]byte, priv.Size())
238 return c.FillBytes(s), nil
239 }
240
// Special values for PSSOptions.SaltLength. Any other positive value is
// used directly as the salt length in bytes.
const (
	// PSSSaltLengthAuto makes the salt as large as the key allows when
	// signing, and makes verification detect the salt length from the
	// signature itself.
	PSSSaltLengthAuto = 0
	// PSSSaltLengthEqualsHash makes the salt length equal to the output
	// size of the hash used for the signature.
	PSSSaltLengthEqualsHash = -1
)

// PSSOptions holds the parameters used when creating or verifying a PSS
// signature.
type PSSOptions struct {
	// SaltLength is the salt length in bytes, or one of the special
	// PSSSaltLength constants.
	SaltLength int

	// Hash is the hash function used to produce the digest. When non-zero,
	// SignPSS uses it instead of its hash argument. VerifyPSS does not
	// consult it.
	Hash crypto.Hash
}

// HashFunc returns opts.Hash, which lets a *PSSOptions be passed wherever a
// crypto.SignerOpts is expected. It dereferences opts, so opts must not be
// nil.
func (opts *PSSOptions) HashFunc() crypto.Hash {
	h := opts.Hash
	return h
}

// saltLength returns the configured salt length. A nil *PSSOptions is
// accepted and treated as PSSSaltLengthAuto.
func (opts *PSSOptions) saltLength() int {
	if opts != nil {
		return opts.SaltLength
	}
	return PSSSaltLengthAuto
}
274
275
276
277
278
279
280 func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) {
281 if opts != nil && opts.Hash != 0 {
282 hash = opts.Hash
283 }
284
285 saltLength := opts.saltLength()
286 switch saltLength {
287 case PSSSaltLengthAuto:
288 saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
289 case PSSSaltLengthEqualsHash:
290 saltLength = hash.Size()
291 }
292
293 if boring.Enabled && rand == boring.RandReader {
294 bkey, err := boringPrivateKey(priv)
295 if err != nil {
296 return nil, err
297 }
298 return boring.SignRSAPSS(bkey, hash, digest, saltLength)
299 }
300 boring.UnreachableExceptTests()
301
302 salt := make([]byte, saltLength)
303 if _, err := io.ReadFull(rand, salt); err != nil {
304 return nil, err
305 }
306 return signPSSWithSalt(rand, priv, hash, digest, salt)
307 }
308
309
310
311
312
313
314
315 func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error {
316 if boring.Enabled {
317 bkey, err := boringPublicKey(pub)
318 if err != nil {
319 return err
320 }
321 if err := boring.VerifyRSAPSS(bkey, hash, digest, sig, opts.saltLength()); err != nil {
322 return ErrVerification
323 }
324 return nil
325 }
326 if len(sig) != pub.Size() {
327 return ErrVerification
328 }
329 s := new(big.Int).SetBytes(sig)
330 m := encrypt(new(big.Int), pub, s)
331 emBits := pub.N.BitLen() - 1
332 emLen := (emBits + 7) / 8
333 if m.BitLen() > emLen*8 {
334 return ErrVerification
335 }
336 em := m.FillBytes(make([]byte, emLen))
337 return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
338 }
339
View as plain text