main
1// Copyright 2017 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package ecdh implements ECDH encryption, suitable for OpenPGP,
6// as specified in RFC 6637, section 8.
7package ecdh
8
9import (
10 "bytes"
11 "errors"
12 "io"
13
14 "github.com/ProtonMail/go-crypto/openpgp/aes/keywrap"
15 "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm"
16 "github.com/ProtonMail/go-crypto/openpgp/internal/ecc"
17)
18
19type KDF struct {
20 Hash algorithm.Hash
21 Cipher algorithm.Cipher
22}
23
24type PublicKey struct {
25 curve ecc.ECDHCurve
26 Point []byte
27 KDF
28}
29
30type PrivateKey struct {
31 PublicKey
32 D []byte
33}
34
35func NewPublicKey(curve ecc.ECDHCurve, kdfHash algorithm.Hash, kdfCipher algorithm.Cipher) *PublicKey {
36 return &PublicKey{
37 curve: curve,
38 KDF: KDF{
39 Hash: kdfHash,
40 Cipher: kdfCipher,
41 },
42 }
43}
44
45func NewPrivateKey(key PublicKey) *PrivateKey {
46 return &PrivateKey{
47 PublicKey: key,
48 }
49}
50
51func (pk *PublicKey) GetCurve() ecc.ECDHCurve {
52 return pk.curve
53}
54
55func (pk *PublicKey) MarshalPoint() []byte {
56 return pk.curve.MarshalBytePoint(pk.Point)
57}
58
59func (pk *PublicKey) UnmarshalPoint(p []byte) error {
60 pk.Point = pk.curve.UnmarshalBytePoint(p)
61 if pk.Point == nil {
62 return errors.New("ecdh: failed to parse EC point")
63 }
64 return nil
65}
66
67func (sk *PrivateKey) MarshalByteSecret() []byte {
68 return sk.curve.MarshalByteSecret(sk.D)
69}
70
71func (sk *PrivateKey) UnmarshalByteSecret(d []byte) error {
72 sk.D = sk.curve.UnmarshalByteSecret(d)
73
74 if sk.D == nil {
75 return errors.New("ecdh: failed to parse scalar")
76 }
77 return nil
78}
79
80func GenerateKey(rand io.Reader, c ecc.ECDHCurve, kdf KDF) (priv *PrivateKey, err error) {
81 priv = new(PrivateKey)
82 priv.PublicKey.curve = c
83 priv.PublicKey.KDF = kdf
84 priv.PublicKey.Point, priv.D, err = c.GenerateECDH(rand)
85 return
86}
87
88func Encrypt(random io.Reader, pub *PublicKey, msg, curveOID, fingerprint []byte) (vsG, c []byte, err error) {
89 if len(msg) > 40 {
90 return nil, nil, errors.New("ecdh: message too long")
91 }
92 // the sender MAY use 21, 13, and 5 bytes of padding for AES-128,
93 // AES-192, and AES-256, respectively, to provide the same number of
94 // octets, 40 total, as an input to the key wrapping method.
95 padding := make([]byte, 40-len(msg))
96 for i := range padding {
97 padding[i] = byte(40 - len(msg))
98 }
99 m := append(msg, padding...)
100
101 ephemeral, zb, err := pub.curve.Encaps(random, pub.Point)
102 if err != nil {
103 return nil, nil, err
104 }
105
106 vsG = pub.curve.MarshalBytePoint(ephemeral)
107
108 z, err := buildKey(pub, zb, curveOID, fingerprint, false, false)
109 if err != nil {
110 return nil, nil, err
111 }
112
113 if c, err = keywrap.Wrap(z, m); err != nil {
114 return nil, nil, err
115 }
116
117 return vsG, c, nil
118
119}
120
121func Decrypt(priv *PrivateKey, vsG, c, curveOID, fingerprint []byte) (msg []byte, err error) {
122 var m []byte
123 zb, err := priv.PublicKey.curve.Decaps(priv.curve.UnmarshalBytePoint(vsG), priv.D)
124
125 // Try buildKey three times to workaround an old bug, see comments in buildKey.
126 for i := 0; i < 3; i++ {
127 var z []byte
128 // RFC6637 §8: "Compute Z = KDF( S, Z_len, Param );"
129 z, err = buildKey(&priv.PublicKey, zb, curveOID, fingerprint, i == 1, i == 2)
130 if err != nil {
131 return nil, err
132 }
133
134 // RFC6637 §8: "Compute C = AESKeyWrap( Z, c ) as per [RFC3394]"
135 m, err = keywrap.Unwrap(z, c)
136 if err == nil {
137 break
138 }
139 }
140
141 // Only return an error after we've tried all (required) variants of buildKey.
142 if err != nil {
143 return nil, err
144 }
145
146 // RFC6637 §8: "m = symm_alg_ID || session key || checksum || pkcs5_padding"
147 // The last byte should be the length of the padding, as per PKCS5; strip it off.
148 return m[:len(m)-int(m[len(m)-1])], nil
149}
150
151func buildKey(pub *PublicKey, zb []byte, curveOID, fingerprint []byte, stripLeading, stripTrailing bool) ([]byte, error) {
152 // Param = curve_OID_len || curve_OID || public_key_alg_ID || 03
153 // || 01 || KDF_hash_ID || KEK_alg_ID for AESKeyWrap
154 // || "Anonymous Sender " || recipient_fingerprint;
155 param := new(bytes.Buffer)
156 if _, err := param.Write(curveOID); err != nil {
157 return nil, err
158 }
159 algKDF := []byte{18, 3, 1, pub.KDF.Hash.Id(), pub.KDF.Cipher.Id()}
160 if _, err := param.Write(algKDF); err != nil {
161 return nil, err
162 }
163 if _, err := param.Write([]byte("Anonymous Sender ")); err != nil {
164 return nil, err
165 }
166 if _, err := param.Write(fingerprint[:]); err != nil {
167 return nil, err
168 }
169
170 // MB = Hash ( 00 || 00 || 00 || 01 || ZB || Param );
171 h := pub.KDF.Hash.New()
172 if _, err := h.Write([]byte{0x0, 0x0, 0x0, 0x1}); err != nil {
173 return nil, err
174 }
175 zbLen := len(zb)
176 i := 0
177 j := zbLen - 1
178 if stripLeading {
179 // Work around old go crypto bug where the leading zeros are missing.
180 for i < zbLen && zb[i] == 0 {
181 i++
182 }
183 }
184 if stripTrailing {
185 // Work around old OpenPGP.js bug where insignificant trailing zeros in
186 // this little-endian number are missing.
187 // (See https://github.com/openpgpjs/openpgpjs/pull/853.)
188 for j >= 0 && zb[j] == 0 {
189 j--
190 }
191 }
192 if _, err := h.Write(zb[i : j+1]); err != nil {
193 return nil, err
194 }
195 if _, err := h.Write(param.Bytes()); err != nil {
196 return nil, err
197 }
198 mb := h.Sum(nil)
199
200 return mb[:pub.KDF.Cipher.KeySize()], nil // return oBits leftmost bits of MB.
201
202}
203
204func Validate(priv *PrivateKey) error {
205 return priv.curve.ValidateECDH(priv.Point, priv.D)
206}