main
Raw Download raw file
  1// Copyright 2011 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package ssh
  6
  7import (
  8	"bufio"
  9	"bytes"
 10	"errors"
 11	"io"
 12	"log"
 13)
 14
 // debugTransport, when true, makes the transport log the message type byte of
// every packet it reads or writes. Only that first byte is logged; payloads
// are never decoded, which keeps the effect on timing as small as possible.
const debugTransport = false
 18
 // SSH algorithm names for ciphers that are referred to by constant rather
// than by string literal (presumably by the cipher tables elsewhere in the
// package — confirm at the use sites).
const (
	gcm128CipherID = "aes128-gcm@openssh.com" // AES-128, GCM mode
	gcm256CipherID = "aes256-gcm@openssh.com" // AES-256, GCM mode
	aes128cbcID    = "aes128-cbc"             // AES-128, CBC mode
	tripledescbcID = "3des-cbc"               // triple DES, CBC mode
)
 25
 // packetConn is the packet-level view of an SSH transport: callers exchange
// whole packets with the peer rather than raw bytes.
type packetConn interface {
	// writePacket encrypts packet and sends it to the remote peer.
	writePacket(packet []byte) error

	// readPacket blocks until the next packet arrives. Whenever the
	// returned error is nil, the returned slice holds at least one byte.
	readPacket() ([]byte, error)

	// Close shuts down the write side of the connection.
	Close() error
}
 40
 41// transport is the keyingTransport that implements the SSH packet
 42// protocol.
 43type transport struct {
 44	reader connectionState
 45	writer connectionState
 46
 47	bufReader *bufio.Reader
 48	bufWriter *bufio.Writer
 49	rand      io.Reader
 50	isClient  bool
 51	io.Closer
 52
 53	strictMode     bool
 54	initialKEXDone bool
 55}
 56
 // packetCipher is the encryption/MAC combination (or AEAD cipher) that
// protects packets travelling in one direction of an SSH connection. An
// instance must only ever be used for a single direction.
type packetCipher interface {
	// writeCipherPacket encrypts packet, which carries sequence number
	// seqnum, and writes the result to w, taking any randomness it needs
	// from rand. The contents of packet are generally scrambled by the
	// call, so callers must not rely on them afterwards.
	writeCipherPacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error

	// readCipherPacket reads the packet with sequence number seqnum from r
	// and returns its decrypted payload. The returned slice may be
	// overwritten by later calls to readCipherPacket, so callers that need
	// to keep it must copy it.
	readCipherPacket(seqnum uint32, r io.Reader) ([]byte, error)
}
 69
 70// connectionState represents one side (read or write) of the
 71// connection. This is necessary because each direction has its own
 72// keys, and can even have its own algorithms
 73type connectionState struct {
 74	packetCipher
 75	seqNum           uint32
 76	dir              direction
 77	pendingKeyChange chan packetCipher
 78}
 79
 80func (t *transport) setStrictMode() error {
 81	if t.reader.seqNum != 1 {
 82		return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
 83	}
 84	t.strictMode = true
 85	return nil
 86}
 87
 88func (t *transport) setInitialKEXDone() {
 89	t.initialKEXDone = true
 90}
 91
 92// prepareKeyChange sets up key material for a keychange. The key changes in
 93// both directions are triggered by reading and writing a msgNewKey packet
 94// respectively.
 95func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
 96	ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
 97	if err != nil {
 98		return err
 99	}
100	t.reader.pendingKeyChange <- ciph
101
102	ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
103	if err != nil {
104		return err
105	}
106	t.writer.pendingKeyChange <- ciph
107
108	return nil
109}
110
111func (t *transport) printPacket(p []byte, write bool) {
112	if len(p) == 0 {
113		return
114	}
115	who := "server"
116	if t.isClient {
117		who = "client"
118	}
119	what := "read"
120	if write {
121		what = "write"
122	}
123
124	log.Println(what, who, p[0])
125}
126
127// Read and decrypt next packet.
128func (t *transport) readPacket() (p []byte, err error) {
129	for {
130		p, err = t.reader.readPacket(t.bufReader, t.strictMode)
131		if err != nil {
132			break
133		}
134		// in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX
135		if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) {
136			break
137		}
138	}
139	if debugTransport {
140		t.printPacket(p, false)
141	}
142
143	return p, err
144}
145
146func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
147	packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
148	s.seqNum++
149	if err == nil && len(packet) == 0 {
150		err = errors.New("ssh: zero length packet")
151	}
152
153	if len(packet) > 0 {
154		switch packet[0] {
155		case msgNewKeys:
156			select {
157			case cipher := <-s.pendingKeyChange:
158				s.packetCipher = cipher
159				if strictMode {
160					s.seqNum = 0
161				}
162			default:
163				return nil, errors.New("ssh: got bogus newkeys message")
164			}
165
166		case msgDisconnect:
167			// Transform a disconnect message into an
168			// error. Since this is lowest level at which
169			// we interpret message types, doing it here
170			// ensures that we don't have to handle it
171			// elsewhere.
172			var msg disconnectMsg
173			if err := Unmarshal(packet, &msg); err != nil {
174				return nil, err
175			}
176			return nil, &msg
177		}
178	}
179
180	// The packet may point to an internal buffer, so copy the
181	// packet out here.
182	fresh := make([]byte, len(packet))
183	copy(fresh, packet)
184
185	return fresh, err
186}
187
188func (t *transport) writePacket(packet []byte) error {
189	if debugTransport {
190		t.printPacket(packet, true)
191	}
192	return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
193}
194
195func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
196	changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
197
198	err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
199	if err != nil {
200		return err
201	}
202	if err = w.Flush(); err != nil {
203		return err
204	}
205	s.seqNum++
206	if changeKeys {
207		select {
208		case cipher := <-s.pendingKeyChange:
209			s.packetCipher = cipher
210			if strictMode {
211				s.seqNum = 0
212			}
213		default:
214			panic("ssh: no key material for msgNewKeys")
215		}
216	}
217	return err
218}
219
220func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
221	t := &transport{
222		bufReader: bufio.NewReader(rwc),
223		bufWriter: bufio.NewWriter(rwc),
224		rand:      rand,
225		reader: connectionState{
226			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
227			pendingKeyChange: make(chan packetCipher, 1),
228		},
229		writer: connectionState{
230			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
231			pendingKeyChange: make(chan packetCipher, 1),
232		},
233		Closer: rwc,
234	}
235	t.isClient = isClient
236
237	if isClient {
238		t.reader.dir = serverKeys
239		t.writer.dir = clientKeys
240	} else {
241		t.reader.dir = clientKeys
242		t.writer.dir = serverKeys
243	}
244
245	return t
246}
247
// direction holds the single-letter labels that key derivation mixes into the
// hash for one direction of traffic: one label each for the IV, the
// encryption key and the MAC key.
type direction struct {
	ivTag     []byte
	keyTag    []byte
	macKeyTag []byte
}

// Key-derivation labels from RFC 4253, section 7.2. Server-to-client traffic
// uses 'B' (IV), 'D' (key) and 'F' (MAC key); client-to-server traffic uses
// 'A', 'C' and 'E'.
var (
	serverKeys = direction{ivTag: []byte("B"), keyTag: []byte("D"), macKeyTag: []byte("F")}
	clientKeys = direction{ivTag: []byte("A"), keyTag: []byte("C"), macKeyTag: []byte("E")}
)
258
259// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
260// described in RFC 4253, section 6.4. direction should either be serverKeys
261// (to setup server->client keys) or clientKeys (for client->server keys).
262func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
263	cipherMode := cipherModes[algs.Cipher]
264
265	iv := make([]byte, cipherMode.ivSize)
266	key := make([]byte, cipherMode.keySize)
267
268	generateKeyMaterial(iv, d.ivTag, kex)
269	generateKeyMaterial(key, d.keyTag, kex)
270
271	var macKey []byte
272	if !aeadCiphers[algs.Cipher] {
273		macMode := macModes[algs.MAC]
274		macKey = make([]byte, macMode.keySize)
275		generateKeyMaterial(macKey, d.macKeyTag, kex)
276	}
277
278	return cipherModes[algs.Cipher].create(key, iv, macKey, algs)
279}
280
281// generateKeyMaterial fills out with key material generated from tag, K, H
282// and sessionId, as specified in RFC 4253, section 7.2.
283func generateKeyMaterial(out, tag []byte, r *kexResult) {
284	var digestsSoFar []byte
285
286	h := r.Hash.New()
287	for len(out) > 0 {
288		h.Reset()
289		h.Write(r.K)
290		h.Write(r.H)
291
292		if len(digestsSoFar) == 0 {
293			h.Write(tag)
294			h.Write(r.SessionID)
295		} else {
296			h.Write(digestsSoFar)
297		}
298
299		digest := h.Sum(nil)
300		n := copy(out, digest)
301		out = out[n:]
302		if len(out) > 0 {
303			digestsSoFar = append(digestsSoFar, digest...)
304		}
305	}
306}
307
// packageVersion is this package's SSH identification string (protocol 2.0,
// software "Go"); presumably used as the default local version line — confirm
// at the call sites.
const packageVersion = "SSH-2.0-Go"
309
310// Sends and receives a version line.  The versionLine string should
311// be US ASCII, start with "SSH-2.0-", and should not include a
312// newline. exchangeVersions returns the other side's version line.
313func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
314	// Contrary to the RFC, we do not ignore lines that don't
315	// start with "SSH-2.0-" to make the library usable with
316	// nonconforming servers.
317	for _, c := range versionLine {
318		// The spec disallows non US-ASCII chars, and
319		// specifically forbids null chars.
320		if c < 32 {
321			return nil, errors.New("ssh: junk character in version line")
322		}
323	}
324	if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil {
325		return
326	}
327
328	them, err = readVersion(rw)
329	return them, err
330}
331
// maxVersionStringBytes bounds the number of bytes readVersion consumes while
// looking for the peer's version line. RFC 4253, section 4.2, limits the
// version string to 255 characters.
const maxVersionStringBytes = 255

// readVersion reads the peer's version line as specified by RFC 4253,
// section 4.2.
//
// The RFC lets a peer send other lines before the version line. Any line not
// starting with "SSH-" is skipped, but the maxVersionStringBytes budget covers
// all lines together. Lines may end in "\r\n" or, for lenient compatibility
// with some servers, a bare "\n". The returned line excludes the terminator.
// Other bytes are accepted as-is, and any comment after the version stays in
// the result, since all of it goes into the session hash.
//
// Input is consumed one byte at a time, so r is never advanced past the
// version line's '\n'.
func readVersion(r io.Reader) ([]byte, error) {
	line := make([]byte, 0, 64)
	sshPrefix := []byte("SSH-")
	var one [1]byte

	for consumed := 0; consumed < maxVersionStringBytes; consumed++ {
		if _, err := io.ReadFull(r, one[:]); err != nil {
			return nil, err
		}

		b := one[0]
		if b != '\n' {
			line = append(line, b)
			continue
		}
		if bytes.HasPrefix(line, sshPrefix) {
			return bytes.TrimSuffix(line, []byte{'\r'}), nil
		}
		line = line[:0] // not the version line: discard and keep looking
	}

	return nil, errors.New("ssh: overflow reading version string")
}