main
Raw Download raw file
  1// Copyright 2014 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package agent
  6
  7import (
  8	"errors"
  9	"io"
 10	"net"
 11	"sync"
 12
 13	"golang.org/x/crypto/ssh"
 14)
 15
 16// RequestAgentForwarding sets up agent forwarding for the session.
 17// ForwardToAgent or ForwardToRemote should be called to route
 18// the authentication requests.
 19func RequestAgentForwarding(session *ssh.Session) error {
 20	ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil)
 21	if err != nil {
 22		return err
 23	}
 24	if !ok {
 25		return errors.New("forwarding request denied")
 26	}
 27	return nil
 28}
 29
 30// ForwardToAgent routes authentication requests to the given keyring.
 31func ForwardToAgent(client *ssh.Client, keyring Agent) error {
 32	channels := client.HandleChannelOpen(channelType)
 33	if channels == nil {
 34		return errors.New("agent: already have handler for " + channelType)
 35	}
 36
 37	go func() {
 38		for ch := range channels {
 39			channel, reqs, err := ch.Accept()
 40			if err != nil {
 41				continue
 42			}
 43			go ssh.DiscardRequests(reqs)
 44			go func() {
 45				ServeAgent(keyring, channel)
 46				channel.Close()
 47			}()
 48		}
 49	}()
 50	return nil
 51}
 52
 53const channelType = "auth-agent@openssh.com"
 54
 55// ForwardToRemote routes authentication requests to the ssh-agent
 56// process serving on the given unix socket.
 57func ForwardToRemote(client *ssh.Client, addr string) error {
 58	channels := client.HandleChannelOpen(channelType)
 59	if channels == nil {
 60		return errors.New("agent: already have handler for " + channelType)
 61	}
 62	conn, err := net.Dial("unix", addr)
 63	if err != nil {
 64		return err
 65	}
 66	conn.Close()
 67
 68	go func() {
 69		for ch := range channels {
 70			channel, reqs, err := ch.Accept()
 71			if err != nil {
 72				continue
 73			}
 74			go ssh.DiscardRequests(reqs)
 75			go forwardUnixSocket(channel, addr)
 76		}
 77	}()
 78	return nil
 79}
 80
 81func forwardUnixSocket(channel ssh.Channel, addr string) {
 82	conn, err := net.Dial("unix", addr)
 83	if err != nil {
 84		return
 85	}
 86
 87	var wg sync.WaitGroup
 88	wg.Add(2)
 89	go func() {
 90		io.Copy(conn, channel)
 91		conn.(*net.UnixConn).CloseWrite()
 92		wg.Done()
 93	}()
 94	go func() {
 95		io.Copy(channel, conn)
 96		channel.CloseWrite()
 97		wg.Done()
 98	}()
 99
100	wg.Wait()
101	conn.Close()
102	channel.Close()
103}