main
Raw Download raw file
  1// Copyright 2013 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package ssh
  6
  7import (
  8	"encoding/binary"
  9	"fmt"
 10	"io"
 11	"log"
 12	"sync"
 13	"sync/atomic"
 14)
 15
 16// debugMux, if set, causes messages in the connection protocol to be
 17// logged.
 18const debugMux = false
 19
 20// chanList is a thread safe channel list.
 21type chanList struct {
 22	// protects concurrent access to chans
 23	sync.Mutex
 24
 25	// chans are indexed by the local id of the channel, which the
 26	// other side should send in the PeersId field.
 27	chans []*channel
 28
 29	// This is a debugging aid: it offsets all IDs by this
 30	// amount. This helps distinguish otherwise identical
 31	// server/client muxes
 32	offset uint32
 33}
 34
 35// Assigns a channel ID to the given channel.
 36func (c *chanList) add(ch *channel) uint32 {
 37	c.Lock()
 38	defer c.Unlock()
 39	for i := range c.chans {
 40		if c.chans[i] == nil {
 41			c.chans[i] = ch
 42			return uint32(i) + c.offset
 43		}
 44	}
 45	c.chans = append(c.chans, ch)
 46	return uint32(len(c.chans)-1) + c.offset
 47}
 48
 49// getChan returns the channel for the given ID.
 50func (c *chanList) getChan(id uint32) *channel {
 51	id -= c.offset
 52
 53	c.Lock()
 54	defer c.Unlock()
 55	if id < uint32(len(c.chans)) {
 56		return c.chans[id]
 57	}
 58	return nil
 59}
 60
 61func (c *chanList) remove(id uint32) {
 62	id -= c.offset
 63	c.Lock()
 64	if id < uint32(len(c.chans)) {
 65		c.chans[id] = nil
 66	}
 67	c.Unlock()
 68}
 69
 70// dropAll forgets all channels it knows, returning them in a slice.
 71func (c *chanList) dropAll() []*channel {
 72	c.Lock()
 73	defer c.Unlock()
 74	var r []*channel
 75
 76	for _, ch := range c.chans {
 77		if ch == nil {
 78			continue
 79		}
 80		r = append(r, ch)
 81	}
 82	c.chans = nil
 83	return r
 84}
 85
 86// mux represents the state for the SSH connection protocol, which
 87// multiplexes many channels onto a single packet transport.
 88type mux struct {
 89	conn     packetConn
 90	chanList chanList
 91
 92	incomingChannels chan NewChannel
 93
 94	globalSentMu     sync.Mutex
 95	globalResponses  chan interface{}
 96	incomingRequests chan *Request
 97
 98	errCond *sync.Cond
 99	err     error
100}
101
// globalOff is a process-wide counter consulted only when debugMux is set:
// newMux atomically increments it and uses the result as the new mux's
// chanList offset, giving each mux a distinct ID range.
var globalOff uint32
105
106func (m *mux) Wait() error {
107	m.errCond.L.Lock()
108	defer m.errCond.L.Unlock()
109	for m.err == nil {
110		m.errCond.Wait()
111	}
112	return m.err
113}
114
115// newMux returns a mux that runs over the given connection.
116func newMux(p packetConn) *mux {
117	m := &mux{
118		conn:             p,
119		incomingChannels: make(chan NewChannel, chanSize),
120		globalResponses:  make(chan interface{}, 1),
121		incomingRequests: make(chan *Request, chanSize),
122		errCond:          newCond(),
123	}
124	if debugMux {
125		m.chanList.offset = atomic.AddUint32(&globalOff, 1)
126	}
127
128	go m.loop()
129	return m
130}
131
132func (m *mux) sendMessage(msg interface{}) error {
133	p := Marshal(msg)
134	if debugMux {
135		log.Printf("send global(%d): %#v", m.chanList.offset, msg)
136	}
137	return m.conn.writePacket(p)
138}
139
140func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
141	if wantReply {
142		m.globalSentMu.Lock()
143		defer m.globalSentMu.Unlock()
144	}
145
146	if err := m.sendMessage(globalRequestMsg{
147		Type:      name,
148		WantReply: wantReply,
149		Data:      payload,
150	}); err != nil {
151		return false, nil, err
152	}
153
154	if !wantReply {
155		return false, nil, nil
156	}
157
158	msg, ok := <-m.globalResponses
159	if !ok {
160		return false, nil, io.EOF
161	}
162	switch msg := msg.(type) {
163	case *globalRequestFailureMsg:
164		return false, msg.Data, nil
165	case *globalRequestSuccessMsg:
166		return true, msg.Data, nil
167	default:
168		return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
169	}
170}
171
172// ackRequest must be called after processing a global request that
173// has WantReply set.
174func (m *mux) ackRequest(ok bool, data []byte) error {
175	if ok {
176		return m.sendMessage(globalRequestSuccessMsg{Data: data})
177	}
178	return m.sendMessage(globalRequestFailureMsg{Data: data})
179}
180
181func (m *mux) Close() error {
182	return m.conn.Close()
183}
184
185// loop runs the connection machine. It will process packets until an
186// error is encountered. To synchronize on loop exit, use mux.Wait.
187func (m *mux) loop() {
188	var err error
189	for err == nil {
190		err = m.onePacket()
191	}
192
193	for _, ch := range m.chanList.dropAll() {
194		ch.close()
195	}
196
197	close(m.incomingChannels)
198	close(m.incomingRequests)
199	close(m.globalResponses)
200
201	m.conn.Close()
202
203	m.errCond.L.Lock()
204	m.err = err
205	m.errCond.Broadcast()
206	m.errCond.L.Unlock()
207
208	if debugMux {
209		log.Println("loop exit", err)
210	}
211}
212
213// onePacket reads and processes one packet.
214func (m *mux) onePacket() error {
215	packet, err := m.conn.readPacket()
216	if err != nil {
217		return err
218	}
219
220	if debugMux {
221		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
222			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
223		} else {
224			p, _ := decode(packet)
225			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
226		}
227	}
228
229	switch packet[0] {
230	case msgChannelOpen:
231		return m.handleChannelOpen(packet)
232	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
233		return m.handleGlobalPacket(packet)
234	case msgPing:
235		var msg pingMsg
236		if err := Unmarshal(packet, &msg); err != nil {
237			return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err)
238		}
239		return m.sendMessage(pongMsg(msg))
240	}
241
242	// assume a channel packet.
243	if len(packet) < 5 {
244		return parseError(packet[0])
245	}
246	id := binary.BigEndian.Uint32(packet[1:])
247	ch := m.chanList.getChan(id)
248	if ch == nil {
249		return m.handleUnknownChannelPacket(id, packet)
250	}
251
252	return ch.handlePacket(packet)
253}
254
255func (m *mux) handleGlobalPacket(packet []byte) error {
256	msg, err := decode(packet)
257	if err != nil {
258		return err
259	}
260
261	switch msg := msg.(type) {
262	case *globalRequestMsg:
263		m.incomingRequests <- &Request{
264			Type:      msg.Type,
265			WantReply: msg.WantReply,
266			Payload:   msg.Data,
267			mux:       m,
268		}
269	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
270		m.globalResponses <- msg
271	default:
272		panic(fmt.Sprintf("not a global message %#v", msg))
273	}
274
275	return nil
276}
277
278// handleChannelOpen schedules a channel to be Accept()ed.
279func (m *mux) handleChannelOpen(packet []byte) error {
280	var msg channelOpenMsg
281	if err := Unmarshal(packet, &msg); err != nil {
282		return err
283	}
284
285	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
286		failMsg := channelOpenFailureMsg{
287			PeersID:  msg.PeersID,
288			Reason:   ConnectionFailed,
289			Message:  "invalid request",
290			Language: "en_US.UTF-8",
291		}
292		return m.sendMessage(failMsg)
293	}
294
295	c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
296	c.remoteId = msg.PeersID
297	c.maxRemotePayload = msg.MaxPacketSize
298	c.remoteWin.add(msg.PeersWindow)
299	m.incomingChannels <- c
300	return nil
301}
302
303func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
304	ch, err := m.openChannel(chanType, extra)
305	if err != nil {
306		return nil, nil, err
307	}
308
309	return ch, ch.incomingRequests, nil
310}
311
312func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
313	ch := m.newChannel(chanType, channelOutbound, extra)
314
315	ch.maxIncomingPayload = channelMaxPacket
316
317	open := channelOpenMsg{
318		ChanType:         chanType,
319		PeersWindow:      ch.myWindow,
320		MaxPacketSize:    ch.maxIncomingPayload,
321		TypeSpecificData: extra,
322		PeersID:          ch.localId,
323	}
324	if err := m.sendMessage(open); err != nil {
325		return nil, err
326	}
327
328	switch msg := (<-ch.msg).(type) {
329	case *channelOpenConfirmMsg:
330		return ch, nil
331	case *channelOpenFailureMsg:
332		return nil, &OpenChannelError{msg.Reason, msg.Message}
333	default:
334		return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
335	}
336}
337
338func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error {
339	msg, err := decode(packet)
340	if err != nil {
341		return err
342	}
343
344	switch msg := msg.(type) {
345	// RFC 4254 section 5.4 says unrecognized channel requests should
346	// receive a failure response.
347	case *channelRequestMsg:
348		if msg.WantReply {
349			return m.sendMessage(channelRequestFailureMsg{
350				PeersID: msg.PeersID,
351			})
352		}
353		return nil
354	default:
355		return fmt.Errorf("ssh: invalid channel %d", id)
356	}
357}