main
Raw Download raw file
  1// Copyright 2011 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package packet
  6
  7import (
  8	"crypto/cipher"
  9	"crypto/sha1"
 10	"crypto/subtle"
 11	"hash"
 12	"io"
 13	"strconv"
 14
 15	"github.com/ProtonMail/go-crypto/openpgp/errors"
 16)
 17
 18// seMdcReader wraps an io.Reader with a no-op Close method.
 19type seMdcReader struct {
 20	in io.Reader
 21}
 22
 23func (ser seMdcReader) Read(buf []byte) (int, error) {
 24	return ser.in.Read(buf)
 25}
 26
 27func (ser seMdcReader) Close() error {
 28	return nil
 29}
 30
 31func (se *SymmetricallyEncrypted) decryptMdc(c CipherFunction, key []byte) (io.ReadCloser, error) {
 32	if !c.IsSupported() {
 33		return nil, errors.UnsupportedError("unsupported cipher: " + strconv.Itoa(int(c)))
 34	}
 35
 36	if len(key) != c.KeySize() {
 37		return nil, errors.InvalidArgumentError("SymmetricallyEncrypted: incorrect key length")
 38	}
 39
 40	if se.prefix == nil {
 41		se.prefix = make([]byte, c.blockSize()+2)
 42		_, err := readFull(se.Contents, se.prefix)
 43		if err != nil {
 44			return nil, err
 45		}
 46	} else if len(se.prefix) != c.blockSize()+2 {
 47		return nil, errors.InvalidArgumentError("can't try ciphers with different block lengths")
 48	}
 49
 50	ocfbResync := OCFBResync
 51	if se.IntegrityProtected {
 52		// MDC packets use a different form of OCFB mode.
 53		ocfbResync = OCFBNoResync
 54	}
 55
 56	s := NewOCFBDecrypter(c.new(key), se.prefix, ocfbResync)
 57
 58	plaintext := cipher.StreamReader{S: s, R: se.Contents}
 59
 60	if se.IntegrityProtected {
 61		// IntegrityProtected packets have an embedded hash that we need to check.
 62		h := sha1.New()
 63		h.Write(se.prefix)
 64		return &seMDCReader{in: plaintext, h: h}, nil
 65	}
 66
 67	// Otherwise, we just need to wrap plaintext so that it's a valid ReadCloser.
 68	return seMdcReader{plaintext}, nil
 69}
 70
 71const mdcTrailerSize = 1 /* tag byte */ + 1 /* length byte */ + sha1.Size
 72
 73// An seMDCReader wraps an io.Reader, maintains a running hash and keeps hold
 74// of the most recent 22 bytes (mdcTrailerSize). Upon EOF, those bytes form an
 75// MDC packet containing a hash of the previous Contents which is checked
 76// against the running hash. See RFC 4880, section 5.13.
 77type seMDCReader struct {
 78	in          io.Reader
 79	h           hash.Hash
 80	trailer     [mdcTrailerSize]byte
 81	scratch     [mdcTrailerSize]byte
 82	trailerUsed int
 83	error       bool
 84	eof         bool
 85}
 86
 87func (ser *seMDCReader) Read(buf []byte) (n int, err error) {
 88	if ser.error {
 89		err = io.ErrUnexpectedEOF
 90		return
 91	}
 92	if ser.eof {
 93		err = io.EOF
 94		return
 95	}
 96
 97	// If we haven't yet filled the trailer buffer then we must do that
 98	// first.
 99	for ser.trailerUsed < mdcTrailerSize {
100		n, err = ser.in.Read(ser.trailer[ser.trailerUsed:])
101		ser.trailerUsed += n
102		if err == io.EOF {
103			if ser.trailerUsed != mdcTrailerSize {
104				n = 0
105				err = io.ErrUnexpectedEOF
106				ser.error = true
107				return
108			}
109			ser.eof = true
110			n = 0
111			return
112		}
113
114		if err != nil {
115			n = 0
116			return
117		}
118	}
119
120	// If it's a short read then we read into a temporary buffer and shift
121	// the data into the caller's buffer.
122	if len(buf) <= mdcTrailerSize {
123		n, err = readFull(ser.in, ser.scratch[:len(buf)])
124		copy(buf, ser.trailer[:n])
125		ser.h.Write(buf[:n])
126		copy(ser.trailer[:], ser.trailer[n:])
127		copy(ser.trailer[mdcTrailerSize-n:], ser.scratch[:])
128		if n < len(buf) {
129			ser.eof = true
130			err = io.EOF
131		}
132		return
133	}
134
135	n, err = ser.in.Read(buf[mdcTrailerSize:])
136	copy(buf, ser.trailer[:])
137	ser.h.Write(buf[:n])
138	copy(ser.trailer[:], buf[n:])
139
140	if err == io.EOF {
141		ser.eof = true
142	}
143	return
144}
145
// mdcPacketTagByte is the new-format packet tag byte (0xD3) for a type 19
// (Modification Detection Code) packet. Bit 7 is always set, bit 6 selects
// the new packet format, and the low six bits hold the packet type. Type 18,
// not 19, is the integrity-protected data packet that carries the MDC.
const mdcPacketTagByte = byte(0x80) | 0x40 | 19
148
149func (ser *seMDCReader) Close() error {
150	if ser.error {
151		return errors.ErrMDCHashMismatch
152	}
153
154	for !ser.eof {
155		// We haven't seen EOF so we need to read to the end
156		var buf [1024]byte
157		_, err := ser.Read(buf[:])
158		if err == io.EOF {
159			break
160		}
161		if err != nil {
162			return errors.ErrMDCHashMismatch
163		}
164	}
165
166	ser.h.Write(ser.trailer[:2])
167
168	final := ser.h.Sum(nil)
169	if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 {
170		return errors.ErrMDCHashMismatch
171	}
172	// The hash already includes the MDC header, but we still check its value
173	// to confirm encryption correctness
174	if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size {
175		return errors.ErrMDCHashMismatch
176	}
177	return nil
178}
179
180// An seMDCWriter writes through to an io.WriteCloser while maintains a running
181// hash of the data written. On close, it emits an MDC packet containing the
182// running hash.
183type seMDCWriter struct {
184	w io.WriteCloser
185	h hash.Hash
186}
187
188func (w *seMDCWriter) Write(buf []byte) (n int, err error) {
189	w.h.Write(buf)
190	return w.w.Write(buf)
191}
192
193func (w *seMDCWriter) Close() (err error) {
194	var buf [mdcTrailerSize]byte
195
196	buf[0] = mdcPacketTagByte
197	buf[1] = sha1.Size
198	w.h.Write(buf[:2])
199	digest := w.h.Sum(nil)
200	copy(buf[2:], digest)
201
202	_, err = w.w.Write(buf[:])
203	if err != nil {
204		return
205	}
206	return w.w.Close()
207}
208
// noOpCloser turns an io.Writer into an io.WriteCloser whose Close does
// nothing. It is the io.Writer counterpart of io.NopCloser.
type noOpCloser struct {
	w io.Writer
}

// Write forwards data to the wrapped writer and returns its result.
func (c noOpCloser) Write(data []byte) (n int, err error) {
	n, err = c.w.Write(data)
	return
}

// Close always returns nil and leaves the wrapped writer usable.
func (noOpCloser) Close() error { return nil }
221
222func serializeSymmetricallyEncryptedMdc(ciphertext io.WriteCloser, c CipherFunction, key []byte, config *Config) (Contents io.WriteCloser, err error) {
223	// Disallow old cipher suites
224	if !c.IsSupported() || c < CipherAES128 {
225		return nil, errors.InvalidArgumentError("invalid mdc cipher function")
226	}
227
228	if c.KeySize() != len(key) {
229		return nil, errors.InvalidArgumentError("error in mdc serialization: bad key length")
230	}
231
232	_, err = ciphertext.Write([]byte{symmetricallyEncryptedVersionMdc})
233	if err != nil {
234		return
235	}
236
237	block := c.new(key)
238	blockSize := block.BlockSize()
239	iv := make([]byte, blockSize)
240	_, err = io.ReadFull(config.Random(), iv)
241	if err != nil {
242		return nil, err
243	}
244	s, prefix := NewOCFBEncrypter(block, iv, OCFBNoResync)
245	_, err = ciphertext.Write(prefix)
246	if err != nil {
247		return
248	}
249	plaintext := cipher.StreamWriter{S: s, W: ciphertext}
250
251	h := sha1.New()
252	h.Write(iv)
253	h.Write(iv[blockSize-2:])
254	Contents = &seMDCWriter{w: plaintext, h: h}
255	return
256}