main
1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package packet
6
7import (
8 "crypto/cipher"
9 "crypto/sha1"
10 "crypto/subtle"
11 "hash"
12 "io"
13 "strconv"
14
15 "github.com/ProtonMail/go-crypto/openpgp/errors"
16)
17
// seMdcReader adapts an io.Reader to an io.ReadCloser. Reads are passed
// straight through to the wrapped reader and Close does nothing.
type seMdcReader struct {
	// in is the reader that every Read call is forwarded to.
	in io.Reader
}

// Read fills buf from the wrapped reader and reports that reader's result
// unchanged.
func (r seMdcReader) Read(buf []byte) (int, error) {
	n, err := r.in.Read(buf)
	return n, err
}

// Close always returns nil; there is nothing to release or verify.
func (r seMdcReader) Close() error {
	return nil
}
30
31func (se *SymmetricallyEncrypted) decryptMdc(c CipherFunction, key []byte) (io.ReadCloser, error) {
32 if !c.IsSupported() {
33 return nil, errors.UnsupportedError("unsupported cipher: " + strconv.Itoa(int(c)))
34 }
35
36 if len(key) != c.KeySize() {
37 return nil, errors.InvalidArgumentError("SymmetricallyEncrypted: incorrect key length")
38 }
39
40 if se.prefix == nil {
41 se.prefix = make([]byte, c.blockSize()+2)
42 _, err := readFull(se.Contents, se.prefix)
43 if err != nil {
44 return nil, err
45 }
46 } else if len(se.prefix) != c.blockSize()+2 {
47 return nil, errors.InvalidArgumentError("can't try ciphers with different block lengths")
48 }
49
50 ocfbResync := OCFBResync
51 if se.IntegrityProtected {
52 // MDC packets use a different form of OCFB mode.
53 ocfbResync = OCFBNoResync
54 }
55
56 s := NewOCFBDecrypter(c.new(key), se.prefix, ocfbResync)
57
58 plaintext := cipher.StreamReader{S: s, R: se.Contents}
59
60 if se.IntegrityProtected {
61 // IntegrityProtected packets have an embedded hash that we need to check.
62 h := sha1.New()
63 h.Write(se.prefix)
64 return &seMDCReader{in: plaintext, h: h}, nil
65 }
66
67 // Otherwise, we just need to wrap plaintext so that it's a valid ReadCloser.
68 return seMdcReader{plaintext}, nil
69}
70
// mdcTrailerSize is the size in bytes of the MDC packet that ends the
// protected data: one tag byte, one length byte and a SHA-1 digest.
const mdcTrailerSize = 1 /* tag byte */ + 1 /* length byte */ + sha1.Size
72
// An seMDCReader wraps an io.Reader, maintains a running hash and keeps hold
// of the most recent 22 bytes (mdcTrailerSize). Upon EOF, those bytes form an
// MDC packet containing a hash of the previous Contents which is checked
// against the running hash. See RFC 4880, section 5.13.
type seMDCReader struct {
	// in is the underlying reader that data is pulled from.
	in io.Reader
	// h is the running hash; Read feeds it every byte it hands to the caller.
	h hash.Hash
	// trailer holds the most recent mdcTrailerSize bytes read from in. They
	// are withheld from the caller because they may be the MDC packet.
	trailer [mdcTrailerSize]byte
	// scratch is a staging buffer Read uses when the caller's buffer is no
	// larger than the trailer.
	scratch [mdcTrailerSize]byte
	// trailerUsed counts how many bytes of trailer have been filled so far.
	trailerUsed int
	// error is set when in reached EOF before a full trailer was read; Read
	// then returns io.ErrUnexpectedEOF and Close reports a hash mismatch.
	error bool
	// eof is set once the end of in has been reached; Read then returns
	// io.EOF.
	eof bool
}
86
87func (ser *seMDCReader) Read(buf []byte) (n int, err error) {
88 if ser.error {
89 err = io.ErrUnexpectedEOF
90 return
91 }
92 if ser.eof {
93 err = io.EOF
94 return
95 }
96
97 // If we haven't yet filled the trailer buffer then we must do that
98 // first.
99 for ser.trailerUsed < mdcTrailerSize {
100 n, err = ser.in.Read(ser.trailer[ser.trailerUsed:])
101 ser.trailerUsed += n
102 if err == io.EOF {
103 if ser.trailerUsed != mdcTrailerSize {
104 n = 0
105 err = io.ErrUnexpectedEOF
106 ser.error = true
107 return
108 }
109 ser.eof = true
110 n = 0
111 return
112 }
113
114 if err != nil {
115 n = 0
116 return
117 }
118 }
119
120 // If it's a short read then we read into a temporary buffer and shift
121 // the data into the caller's buffer.
122 if len(buf) <= mdcTrailerSize {
123 n, err = readFull(ser.in, ser.scratch[:len(buf)])
124 copy(buf, ser.trailer[:n])
125 ser.h.Write(buf[:n])
126 copy(ser.trailer[:], ser.trailer[n:])
127 copy(ser.trailer[mdcTrailerSize-n:], ser.scratch[:])
128 if n < len(buf) {
129 ser.eof = true
130 err = io.EOF
131 }
132 return
133 }
134
135 n, err = ser.in.Read(buf[mdcTrailerSize:])
136 copy(buf, ser.trailer[:])
137 ser.h.Write(buf[:n])
138 copy(ser.trailer[:], buf[n:])
139
140 if err == io.EOF {
141 ser.eof = true
142 }
143 return
144}
145
// mdcPacketTagByte is the new-format packet tag byte for a type 19
// (Modification Detection Code) packet. It is the first byte of the MDC
// trailer.
const mdcPacketTagByte = byte(0x80) | 0x40 | 19
148
149func (ser *seMDCReader) Close() error {
150 if ser.error {
151 return errors.ErrMDCHashMismatch
152 }
153
154 for !ser.eof {
155 // We haven't seen EOF so we need to read to the end
156 var buf [1024]byte
157 _, err := ser.Read(buf[:])
158 if err == io.EOF {
159 break
160 }
161 if err != nil {
162 return errors.ErrMDCHashMismatch
163 }
164 }
165
166 ser.h.Write(ser.trailer[:2])
167
168 final := ser.h.Sum(nil)
169 if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 {
170 return errors.ErrMDCHashMismatch
171 }
172 // The hash already includes the MDC header, but we still check its value
173 // to confirm encryption correctness
174 if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size {
175 return errors.ErrMDCHashMismatch
176 }
177 return nil
178}
179
180// An seMDCWriter writes through to an io.WriteCloser while maintains a running
181// hash of the data written. On close, it emits an MDC packet containing the
182// running hash.
183type seMDCWriter struct {
184 w io.WriteCloser
185 h hash.Hash
186}
187
188func (w *seMDCWriter) Write(buf []byte) (n int, err error) {
189 w.h.Write(buf)
190 return w.w.Write(buf)
191}
192
193func (w *seMDCWriter) Close() (err error) {
194 var buf [mdcTrailerSize]byte
195
196 buf[0] = mdcPacketTagByte
197 buf[1] = sha1.Size
198 w.h.Write(buf[:2])
199 digest := w.h.Sum(nil)
200 copy(buf[2:], digest)
201
202 _, err = w.w.Write(buf[:])
203 if err != nil {
204 return
205 }
206 return w.w.Close()
207}
208
// noOpCloser gives an io.Writer a Close method that does nothing, in the
// manner of ioutil.NopCloser for readers.
type noOpCloser struct {
	// w receives every Write unchanged.
	w io.Writer
}

// Write forwards data to the wrapped writer and returns its result.
func (c noOpCloser) Write(data []byte) (n int, err error) {
	n, err = c.w.Write(data)
	return n, err
}

// Close does nothing and always returns nil; the wrapped writer is left
// untouched.
func (c noOpCloser) Close() error {
	return nil
}
221
222func serializeSymmetricallyEncryptedMdc(ciphertext io.WriteCloser, c CipherFunction, key []byte, config *Config) (Contents io.WriteCloser, err error) {
223 // Disallow old cipher suites
224 if !c.IsSupported() || c < CipherAES128 {
225 return nil, errors.InvalidArgumentError("invalid mdc cipher function")
226 }
227
228 if c.KeySize() != len(key) {
229 return nil, errors.InvalidArgumentError("error in mdc serialization: bad key length")
230 }
231
232 _, err = ciphertext.Write([]byte{symmetricallyEncryptedVersionMdc})
233 if err != nil {
234 return
235 }
236
237 block := c.new(key)
238 blockSize := block.BlockSize()
239 iv := make([]byte, blockSize)
240 _, err = io.ReadFull(config.Random(), iv)
241 if err != nil {
242 return nil, err
243 }
244 s, prefix := NewOCFBEncrypter(block, iv, OCFBNoResync)
245 _, err = ciphertext.Write(prefix)
246 if err != nil {
247 return
248 }
249 plaintext := cipher.StreamWriter{S: s, W: ciphertext}
250
251 h := sha1.New()
252 h.Write(iv)
253 h.Write(iv[blockSize-2:])
254 Contents = &seMDCWriter{w: plaintext, h: h}
255 return
256}