main
Raw Download raw file
  1// Copyright 2011 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package ssh
  6
  7import (
  8	"context"
  9	"errors"
 10	"fmt"
 11	"io"
 12	"math/rand"
 13	"net"
 14	"strconv"
 15	"strings"
 16	"sync"
 17	"time"
 18)
 19
 20// Listen requests the remote peer open a listening socket on
 21// addr. Incoming connections will be available by calling Accept on
 22// the returned net.Listener. The listener must be serviced, or the
 23// SSH connection may hang.
 24// N must be "tcp", "tcp4", "tcp6", or "unix".
 25func (c *Client) Listen(n, addr string) (net.Listener, error) {
 26	switch n {
 27	case "tcp", "tcp4", "tcp6":
 28		laddr, err := net.ResolveTCPAddr(n, addr)
 29		if err != nil {
 30			return nil, err
 31		}
 32		return c.ListenTCP(laddr)
 33	case "unix":
 34		return c.ListenUnix(addr)
 35	default:
 36		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
 37	}
 38}
 39
 40// Automatic port allocation is broken with OpenSSH before 6.0. See
 41// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017.  In
 42// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
 43// rather than the actual port number. This means you can never open
 44// two different listeners with auto allocated ports. We work around
 45// this by trying explicit ports until we succeed.
 46
 47const openSSHPrefix = "OpenSSH_"
 48
 49var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))
 50
 51// isBrokenOpenSSHVersion returns true if the given version string
 52// specifies a version of OpenSSH that is known to have a bug in port
 53// forwarding.
 54func isBrokenOpenSSHVersion(versionStr string) bool {
 55	i := strings.Index(versionStr, openSSHPrefix)
 56	if i < 0 {
 57		return false
 58	}
 59	i += len(openSSHPrefix)
 60	j := i
 61	for ; j < len(versionStr); j++ {
 62		if versionStr[j] < '0' || versionStr[j] > '9' {
 63			break
 64		}
 65	}
 66	version, _ := strconv.Atoi(versionStr[i:j])
 67	return version < 6
 68}
 69
 70// autoPortListenWorkaround simulates automatic port allocation by
 71// trying random ports repeatedly.
 72func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
 73	var sshListener net.Listener
 74	var err error
 75	const tries = 10
 76	for i := 0; i < tries; i++ {
 77		addr := *laddr
 78		addr.Port = 1024 + portRandomizer.Intn(60000)
 79		sshListener, err = c.ListenTCP(&addr)
 80		if err == nil {
 81			laddr.Port = addr.Port
 82			return sshListener, err
 83		}
 84	}
 85	return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err)
 86}
 87
// channelForwardMsg is the payload of the "tcpip-forward" global request
// (RFC 4254 section 7.1). tcpListener.Close also uses it for
// "cancel-tcpip-forward". It holds the address to bind on the remote side
// followed by the port number. Marshal is assumed to encode fields in
// declaration order (the order RFC 4254 fixes), so do not reorder them.
type channelForwardMsg struct {
	addr  string // address to bind, as produced by laddr.IP.String()
	rport uint32 // port to bind; 0 asks the server to pick one
}
 93
 94// handleForwards starts goroutines handling forwarded connections.
 95// It's called on first use by (*Client).ListenTCP to not launch
 96// goroutines until needed.
 97func (c *Client) handleForwards() {
 98	go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
 99	go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
100}
101
102// ListenTCP requests the remote peer open a listening socket
103// on laddr. Incoming connections will be available by calling
104// Accept on the returned net.Listener.
105func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
106	c.handleForwardsOnce.Do(c.handleForwards)
107	if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
108		return c.autoPortListenWorkaround(laddr)
109	}
110
111	m := channelForwardMsg{
112		laddr.IP.String(),
113		uint32(laddr.Port),
114	}
115	// send message
116	ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
117	if err != nil {
118		return nil, err
119	}
120	if !ok {
121		return nil, errors.New("ssh: tcpip-forward request denied by peer")
122	}
123
124	// If the original port was 0, then the remote side will
125	// supply a real port number in the response.
126	if laddr.Port == 0 {
127		var p struct {
128			Port uint32
129		}
130		if err := Unmarshal(resp, &p); err != nil {
131			return nil, err
132		}
133		laddr.Port = int(p.Port)
134	}
135
136	// Register this forward, using the port number we obtained.
137	ch := c.forwards.add(laddr)
138
139	return &tcpListener{laddr, c, ch}, nil
140}
141
// forwardList stores the mapping between remote forward requests and the
// listeners that serve them. Every method locks the embedded Mutex before
// touching entries.
type forwardList struct {
	sync.Mutex
	entries []forwardEntry // registered forwards, in registration order
}
148
// forwardEntry represents an established mapping of a laddr on the
// remote SSH server to the channel that feeds a listener.
type forwardEntry struct {
	laddr net.Addr     // registered address; matched by Network() and String()
	c     chan forward // delivery channel; closed by remove or closeAll
}
155
// forward represents an incoming forwarded connection (either a
// "forwarded-tcpip" or a "forwarded-streamlocal@openssh.com" channel)
// waiting to be accepted by a listener.
//
// The address arguments to forwardList's add, remove and forward methods
// should be the address as specified in the original forward request.
type forward struct {
	newCh NewChannel // the ssh client channel underlying this forward
	raddr net.Addr   // the raddr of the incoming connection
}
163
164func (l *forwardList) add(addr net.Addr) chan forward {
165	l.Lock()
166	defer l.Unlock()
167	f := forwardEntry{
168		laddr: addr,
169		c:     make(chan forward, 1),
170	}
171	l.entries = append(l.entries, f)
172	return f.c
173}
174
// forwardedTCPPayload is the type-specific data of a "forwarded-tcpip"
// channel open; see RFC 4254, section 7.2. Unmarshal is assumed to decode
// fields in declaration order (the order RFC 4254 fixes), so do not
// reorder them.
type forwardedTCPPayload struct {
	Addr       string // address that was connected (the forwarded address)
	Port       uint32 // port that was connected
	OriginAddr string // originator IP address
	OriginPort uint32 // originator port
}
182
// parseTCPAddr builds a *net.TCPAddr from an address string and port taken
// from a "forwarded-tcpip" payload. The address must be a literal IP (no
// name resolution is attempted), and the port must be in 1..65535.
func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
	if port < 1 || port > 65535 {
		return nil, fmt.Errorf("ssh: port number out of range: %d", port)
	}
	if ip := net.ParseIP(addr); ip != nil {
		return &net.TCPAddr{IP: ip, Port: int(port)}, nil
	}
	return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
}
194
195func (l *forwardList) handleChannels(in <-chan NewChannel) {
196	for ch := range in {
197		var (
198			laddr net.Addr
199			raddr net.Addr
200			err   error
201		)
202		switch channelType := ch.ChannelType(); channelType {
203		case "forwarded-tcpip":
204			var payload forwardedTCPPayload
205			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
206				ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
207				continue
208			}
209
210			// RFC 4254 section 7.2 specifies that incoming
211			// addresses should list the address, in string
212			// format. It is implied that this should be an IP
213			// address, as it would be impossible to connect to it
214			// otherwise.
215			laddr, err = parseTCPAddr(payload.Addr, payload.Port)
216			if err != nil {
217				ch.Reject(ConnectionFailed, err.Error())
218				continue
219			}
220			raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
221			if err != nil {
222				ch.Reject(ConnectionFailed, err.Error())
223				continue
224			}
225
226		case "forwarded-streamlocal@openssh.com":
227			var payload forwardedStreamLocalPayload
228			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
229				ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
230				continue
231			}
232			laddr = &net.UnixAddr{
233				Name: payload.SocketPath,
234				Net:  "unix",
235			}
236			raddr = &net.UnixAddr{
237				Name: "@",
238				Net:  "unix",
239			}
240		default:
241			panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
242		}
243		if ok := l.forward(laddr, raddr, ch); !ok {
244			// Section 7.2, implementations MUST reject spurious incoming
245			// connections.
246			ch.Reject(Prohibited, "no forward for address")
247			continue
248		}
249
250	}
251}
252
253// remove removes the forward entry, and the channel feeding its
254// listener.
255func (l *forwardList) remove(addr net.Addr) {
256	l.Lock()
257	defer l.Unlock()
258	for i, f := range l.entries {
259		if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
260			l.entries = append(l.entries[:i], l.entries[i+1:]...)
261			close(f.c)
262			return
263		}
264	}
265}
266
267// closeAll closes and clears all forwards.
268func (l *forwardList) closeAll() {
269	l.Lock()
270	defer l.Unlock()
271	for _, f := range l.entries {
272		close(f.c)
273	}
274	l.entries = nil
275}
276
277func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
278	l.Lock()
279	defer l.Unlock()
280	for _, f := range l.entries {
281		if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
282			f.c <- forward{newCh: ch, raddr: raddr}
283			return true
284		}
285	}
286	return false
287}
288
289type tcpListener struct {
290	laddr *net.TCPAddr
291
292	conn *Client
293	in   <-chan forward
294}
295
296// Accept waits for and returns the next connection to the listener.
297func (l *tcpListener) Accept() (net.Conn, error) {
298	s, ok := <-l.in
299	if !ok {
300		return nil, io.EOF
301	}
302	ch, incoming, err := s.newCh.Accept()
303	if err != nil {
304		return nil, err
305	}
306	go DiscardRequests(incoming)
307
308	return &chanConn{
309		Channel: ch,
310		laddr:   l.laddr,
311		raddr:   s.raddr,
312	}, nil
313}
314
315// Close closes the listener.
316func (l *tcpListener) Close() error {
317	m := channelForwardMsg{
318		l.laddr.IP.String(),
319		uint32(l.laddr.Port),
320	}
321
322	// this also closes the listener.
323	l.conn.forwards.remove(l.laddr)
324	ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
325	if err == nil && !ok {
326		err = errors.New("ssh: cancel-tcpip-forward failed")
327	}
328	return err
329}
330
331// Addr returns the listener's network address.
332func (l *tcpListener) Addr() net.Addr {
333	return l.laddr
334}
335
336// DialContext initiates a connection to the addr from the remote host.
337//
338// The provided Context must be non-nil. If the context expires before the
339// connection is complete, an error is returned. Once successfully connected,
340// any expiration of the context will not affect the connection.
341//
342// See func Dial for additional information.
343func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
344	if err := ctx.Err(); err != nil {
345		return nil, err
346	}
347	type connErr struct {
348		conn net.Conn
349		err  error
350	}
351	ch := make(chan connErr)
352	go func() {
353		conn, err := c.Dial(n, addr)
354		select {
355		case ch <- connErr{conn, err}:
356		case <-ctx.Done():
357			if conn != nil {
358				conn.Close()
359			}
360		}
361	}()
362	select {
363	case res := <-ch:
364		return res.conn, res.err
365	case <-ctx.Done():
366		return nil, ctx.Err()
367	}
368}
369
370// Dial initiates a connection to the addr from the remote host.
371// The resulting connection has a zero LocalAddr() and RemoteAddr().
372func (c *Client) Dial(n, addr string) (net.Conn, error) {
373	var ch Channel
374	switch n {
375	case "tcp", "tcp4", "tcp6":
376		// Parse the address into host and numeric port.
377		host, portString, err := net.SplitHostPort(addr)
378		if err != nil {
379			return nil, err
380		}
381		port, err := strconv.ParseUint(portString, 10, 16)
382		if err != nil {
383			return nil, err
384		}
385		ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
386		if err != nil {
387			return nil, err
388		}
389		// Use a zero address for local and remote address.
390		zeroAddr := &net.TCPAddr{
391			IP:   net.IPv4zero,
392			Port: 0,
393		}
394		return &chanConn{
395			Channel: ch,
396			laddr:   zeroAddr,
397			raddr:   zeroAddr,
398		}, nil
399	case "unix":
400		var err error
401		ch, err = c.dialStreamLocal(addr)
402		if err != nil {
403			return nil, err
404		}
405		return &chanConn{
406			Channel: ch,
407			laddr: &net.UnixAddr{
408				Name: "@",
409				Net:  "unix",
410			},
411			raddr: &net.UnixAddr{
412				Name: addr,
413				Net:  "unix",
414			},
415		}, nil
416	default:
417		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
418	}
419}
420
421// DialTCP connects to the remote address raddr on the network net,
422// which must be "tcp", "tcp4", or "tcp6".  If laddr is not nil, it is used
423// as the local address for the connection.
424func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
425	if laddr == nil {
426		laddr = &net.TCPAddr{
427			IP:   net.IPv4zero,
428			Port: 0,
429		}
430	}
431	ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
432	if err != nil {
433		return nil, err
434	}
435	return &chanConn{
436		Channel: ch,
437		laddr:   laddr,
438		raddr:   raddr,
439	}, nil
440}
441
// channelOpenDirectMsg is the type-specific data of a "direct-tcpip"
// channel open (RFC 4254 section 7.2). It holds the host and port to
// connect to, followed by the originator's IP address and port. Marshal is
// assumed to encode fields in declaration order (the order RFC 4254
// fixes), so do not reorder them.
type channelOpenDirectMsg struct {
	raddr string // host to connect to
	rport uint32 // port to connect to
	laddr string // originator IP address
	lport uint32 // originator port
}
449
450func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
451	msg := channelOpenDirectMsg{
452		raddr: raddr,
453		rport: uint32(rport),
454		laddr: laddr,
455		lport: uint32(lport),
456	}
457	ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
458	if err != nil {
459		return nil, err
460	}
461	go DiscardRequests(in)
462	return ch, nil
463}
464
// tcpChan wraps a Channel.
// NOTE(review): nothing in this file uses tcpChan; chanConn is what Dial,
// DialTCP and Accept return. It may be kept for compatibility or used
// elsewhere in the package, so confirm before removing.
type tcpChan struct {
	Channel // the backing channel
}
468
469// chanConn fulfills the net.Conn interface without
470// the tcpChan having to hold laddr or raddr directly.
471type chanConn struct {
472	Channel
473	laddr, raddr net.Addr
474}
475
476// LocalAddr returns the local network address.
477func (t *chanConn) LocalAddr() net.Addr {
478	return t.laddr
479}
480
481// RemoteAddr returns the remote network address.
482func (t *chanConn) RemoteAddr() net.Addr {
483	return t.raddr
484}
485
486// SetDeadline sets the read and write deadlines associated
487// with the connection.
488func (t *chanConn) SetDeadline(deadline time.Time) error {
489	if err := t.SetReadDeadline(deadline); err != nil {
490		return err
491	}
492	return t.SetWriteDeadline(deadline)
493}
494
495// SetReadDeadline sets the read deadline.
496// A zero value for t means Read will not time out.
497// After the deadline, the error from Read will implement net.Error
498// with Timeout() == true.
499func (t *chanConn) SetReadDeadline(deadline time.Time) error {
500	// for compatibility with previous version,
501	// the error message contains "tcpChan"
502	return errors.New("ssh: tcpChan: deadline not supported")
503}
504
505// SetWriteDeadline exists to satisfy the net.Conn interface
506// but is not implemented by this type.  It always returns an error.
507func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
508	return errors.New("ssh: tcpChan: deadline not supported")
509}