main
Raw Download raw file
  1// Copyright (C) 2019 ProtonTech AG
  2
  3// Package ocb provides an implementation of the OCB (offset codebook) mode of
  4// operation, as described in RFC-7253 of the IRTF and in Rogaway, Bellare,
  5// Black and Krovetz - OCB: A BLOCK-CIPHER MODE OF OPERATION FOR EFFICIENT
  6// AUTHENTICATED ENCRYPTION (2003).
// Security considerations (from RFC-7253): A key MUST NOT be used to
// encrypt more than 2^48 blocks. Tag length should be at least 12 bytes (a
// brute-force forging adversary succeeds after about 2^{tag length in bits}
// attempts). A single key SHOULD NOT be used to decrypt ciphertext with
// different tag lengths. Nonces need not be secret, but MUST NOT be reused.
 12// This package only supports underlying block ciphers with 128-bit blocks,
 13// such as AES-{128, 192, 256}, but may be extended to other sizes.
 14package ocb
 15
 16import (
 17	"bytes"
 18	"crypto/cipher"
 19	"crypto/subtle"
 20	"errors"
 21	"math/bits"
 22
 23	"github.com/ProtonMail/go-crypto/internal/byteutil"
 24)
 25
 26type ocb struct {
 27	block     cipher.Block
 28	tagSize   int
 29	nonceSize int
 30	mask      mask
 31	// Optimized en/decrypt: For each nonce N used to en/decrypt, the 'Ktop'
 32	// internal variable can be reused for en/decrypting with nonces sharing
 33	// all but the last 6 bits with N. The prefix of the first nonce used to
 34	// compute the new Ktop, and the Ktop value itself, are stored in
 35	// reusableKtop. If using incremental nonces, this saves one block cipher
 36	// call every 63 out of 64 OCB encryptions, and stores one nonce and one
 37	// output of the block cipher in memory only.
 38	reusableKtop reusableKtop
 39}
 40
 41type mask struct {
 42	// L_*, L_$, (L_i)_{i ∈ N}
 43	lAst []byte
 44	lDol []byte
 45	L    [][]byte
 46}
 47
 48type reusableKtop struct {
 49	noncePrefix []byte
 50	Ktop        []byte
 51}
 52
 53const (
 54	defaultTagSize   = 16
 55	defaultNonceSize = 15
 56)
 57
 58const (
 59	enc = iota
 60	dec
 61)
 62
 63func (o *ocb) NonceSize() int {
 64	return o.nonceSize
 65}
 66
 67func (o *ocb) Overhead() int {
 68	return o.tagSize
 69}
 70
 71// NewOCB returns an OCB instance with the given block cipher and default
 72// tag and nonce sizes.
 73func NewOCB(block cipher.Block) (cipher.AEAD, error) {
 74	return NewOCBWithNonceAndTagSize(block, defaultNonceSize, defaultTagSize)
 75}
 76
 77// NewOCBWithNonceAndTagSize returns an OCB instance with the given block
 78// cipher, nonce length, and tag length. Panics on zero nonceSize and
 79// exceedingly long tag size.
 80//
 81// It is recommended to use at least 12 bytes as tag length.
 82func NewOCBWithNonceAndTagSize(
 83	block cipher.Block, nonceSize, tagSize int) (cipher.AEAD, error) {
 84	if block.BlockSize() != 16 {
 85		return nil, ocbError("Block cipher must have 128-bit blocks")
 86	}
 87	if nonceSize < 1 {
 88		return nil, ocbError("Incorrect nonce length")
 89	}
 90	if nonceSize >= block.BlockSize() {
 91		return nil, ocbError("Nonce length exceeds blocksize - 1")
 92	}
 93	if tagSize > block.BlockSize() {
 94		return nil, ocbError("Custom tag length exceeds blocksize")
 95	}
 96	return &ocb{
 97		block:     block,
 98		tagSize:   tagSize,
 99		nonceSize: nonceSize,
100		mask:      initializeMaskTable(block),
101		reusableKtop: reusableKtop{
102			noncePrefix: nil,
103			Ktop:        nil,
104		},
105	}, nil
106}
107
108func (o *ocb) Seal(dst, nonce, plaintext, adata []byte) []byte {
109	if len(nonce) > o.nonceSize {
110		panic("crypto/ocb: Incorrect nonce length given to OCB")
111	}
112	sep := len(plaintext)
113	ret, out := byteutil.SliceForAppend(dst, sep+o.tagSize)
114	tag := o.crypt(enc, out[:sep], nonce, adata, plaintext)
115	copy(out[sep:], tag)
116	return ret
117}
118
119func (o *ocb) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) {
120	if len(nonce) > o.nonceSize {
121		panic("Nonce too long for this instance")
122	}
123	if len(ciphertext) < o.tagSize {
124		return nil, ocbError("Ciphertext shorter than tag length")
125	}
126	sep := len(ciphertext) - o.tagSize
127	ret, out := byteutil.SliceForAppend(dst, sep)
128	ciphertextData := ciphertext[:sep]
129	tag := o.crypt(dec, out, nonce, adata, ciphertextData)
130	if subtle.ConstantTimeCompare(tag, ciphertext[sep:]) == 1 {
131		return ret, nil
132	}
133	for i := range out {
134		out[i] = 0
135	}
136	return nil, ocbError("Tag authentication failed")
137}
138
139// On instruction enc (resp. dec), crypt is the encrypt (resp. decrypt)
140// function. It writes the resulting plain/ciphertext into Y and returns
141// the tag.
142func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte {
143	//
144	// Consider X as a sequence of 128-bit blocks
145	//
146	// Note: For encryption (resp. decryption), X is the plaintext (resp., the
147	// ciphertext without the tag).
148	blockSize := o.block.BlockSize()
149
150	//
151	// Nonce-dependent and per-encryption variables
152	//
153	// Zero out the last 6 bits of the nonce into truncatedNonce to see if Ktop
154	// is already computed.
155	truncatedNonce := make([]byte, len(nonce))
156	copy(truncatedNonce, nonce)
157	truncatedNonce[len(truncatedNonce)-1] &= 192
158	var Ktop []byte
159	if bytes.Equal(truncatedNonce, o.reusableKtop.noncePrefix) {
160		Ktop = o.reusableKtop.Ktop
161	} else {
162		// Nonce = num2str(TAGLEN mod 128, 7) || zeros(120 - bitlen(N)) || 1 || N
163		paddedNonce := append(make([]byte, blockSize-1-len(nonce)), 1)
164		paddedNonce = append(paddedNonce, truncatedNonce...)
165		paddedNonce[0] |= byte(((8 * o.tagSize) % (8 * blockSize)) << 1)
166		// Last 6 bits of paddedNonce are already zero. Encrypt into Ktop
167		paddedNonce[blockSize-1] &= 192
168		Ktop = paddedNonce
169		o.block.Encrypt(Ktop, Ktop)
170		o.reusableKtop.noncePrefix = truncatedNonce
171		o.reusableKtop.Ktop = Ktop
172	}
173
174	// Stretch = Ktop || ((lower half of Ktop) XOR (lower half of Ktop << 8))
175	xorHalves := make([]byte, blockSize/2)
176	byteutil.XorBytes(xorHalves, Ktop[:blockSize/2], Ktop[1:1+blockSize/2])
177	stretch := append(Ktop, xorHalves...)
178	bottom := int(nonce[len(nonce)-1] & 63)
179	offset := make([]byte, len(stretch))
180	byteutil.ShiftNBytesLeft(offset, stretch, bottom)
181	offset = offset[:blockSize]
182
183	//
184	// Process any whole blocks
185	//
186	// Note: For encryption Y is ciphertext || tag, for decryption Y is
187	// plaintext || tag.
188	checksum := make([]byte, blockSize)
189	m := len(X) / blockSize
190	for i := 0; i < m; i++ {
191		index := bits.TrailingZeros(uint(i + 1))
192		if len(o.mask.L)-1 < index {
193			o.mask.extendTable(index)
194		}
195		byteutil.XorBytesMut(offset, o.mask.L[bits.TrailingZeros(uint(i+1))])
196		blockX := X[i*blockSize : (i+1)*blockSize]
197		blockY := Y[i*blockSize : (i+1)*blockSize]
198		switch instruction {
199		case enc:
200			byteutil.XorBytesMut(checksum, blockX)
201			byteutil.XorBytes(blockY, blockX, offset)
202			o.block.Encrypt(blockY, blockY)
203			byteutil.XorBytesMut(blockY, offset)
204		case dec:
205			byteutil.XorBytes(blockY, blockX, offset)
206			o.block.Decrypt(blockY, blockY)
207			byteutil.XorBytesMut(blockY, offset)
208			byteutil.XorBytesMut(checksum, blockY)
209		}
210	}
211	//
212	// Process any final partial block and compute raw tag
213	//
214	tag := make([]byte, blockSize)
215	if len(X)%blockSize != 0 {
216		byteutil.XorBytesMut(offset, o.mask.lAst)
217		pad := make([]byte, blockSize)
218		o.block.Encrypt(pad, offset)
219		chunkX := X[blockSize*m:]
220		chunkY := Y[blockSize*m : len(X)]
221		switch instruction {
222		case enc:
223			byteutil.XorBytesMut(checksum, chunkX)
224			checksum[len(chunkX)] ^= 128
225			byteutil.XorBytes(chunkY, chunkX, pad[:len(chunkX)])
226			// P_* || bit(1) || zeroes(127) - len(P_*)
227		case dec:
228			byteutil.XorBytes(chunkY, chunkX, pad[:len(chunkX)])
229			// P_* || bit(1) || zeroes(127) - len(P_*)
230			byteutil.XorBytesMut(checksum, chunkY)
231			checksum[len(chunkY)] ^= 128
232		}
233	}
234	byteutil.XorBytes(tag, checksum, offset)
235	byteutil.XorBytesMut(tag, o.mask.lDol)
236	o.block.Encrypt(tag, tag)
237	byteutil.XorBytesMut(tag, o.hash(adata))
238	return tag[:o.tagSize]
239}
240
241// This hash function is used to compute the tag. Per design, on empty input it
242// returns a slice of zeros, of the same length as the underlying block cipher
243// block size.
244func (o *ocb) hash(adata []byte) []byte {
245	//
246	// Consider A as a sequence of 128-bit blocks
247	//
248	A := make([]byte, len(adata))
249	copy(A, adata)
250	blockSize := o.block.BlockSize()
251
252	//
253	// Process any whole blocks
254	//
255	sum := make([]byte, blockSize)
256	offset := make([]byte, blockSize)
257	m := len(A) / blockSize
258	for i := 0; i < m; i++ {
259		chunk := A[blockSize*i : blockSize*(i+1)]
260		index := bits.TrailingZeros(uint(i + 1))
261		// If the mask table is too short
262		if len(o.mask.L)-1 < index {
263			o.mask.extendTable(index)
264		}
265		byteutil.XorBytesMut(offset, o.mask.L[index])
266		byteutil.XorBytesMut(chunk, offset)
267		o.block.Encrypt(chunk, chunk)
268		byteutil.XorBytesMut(sum, chunk)
269	}
270
271	//
272	// Process any final partial block; compute final hash value
273	//
274	if len(A)%blockSize != 0 {
275		byteutil.XorBytesMut(offset, o.mask.lAst)
276		// Pad block with 1 || 0 ^ 127 - bitlength(a)
277		ending := make([]byte, blockSize-len(A)%blockSize)
278		ending[0] = 0x80
279		encrypted := append(A[blockSize*m:], ending...)
280		byteutil.XorBytesMut(encrypted, offset)
281		o.block.Encrypt(encrypted, encrypted)
282		byteutil.XorBytesMut(sum, encrypted)
283	}
284	return sum
285}
286
287func initializeMaskTable(block cipher.Block) mask {
288	//
289	// Key-dependent variables
290	//
291	lAst := make([]byte, block.BlockSize())
292	block.Encrypt(lAst, lAst)
293	lDol := byteutil.GfnDouble(lAst)
294	L := make([][]byte, 1)
295	L[0] = byteutil.GfnDouble(lDol)
296
297	return mask{
298		lAst: lAst,
299		lDol: lDol,
300		L:    L,
301	}
302}
303
304// Extends the L array of mask m up to L[limit], with L[i] = GfnDouble(L[i-1])
305func (m *mask) extendTable(limit int) {
306	for i := len(m.L); i <= limit; i++ {
307		m.L = append(m.L, byteutil.GfnDouble(m.L[i-1]))
308	}
309}
310
// ocbError returns an error whose text is msg prefixed with "crypto/ocb: ".
func ocbError(msg string) error {
	prefixed := "crypto/ocb: " + msg
	return errors.New(prefixed)
}