main
1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package agent
6
7import (
8 "errors"
9 "io"
10 "net"
11 "sync"
12
13 "golang.org/x/crypto/ssh"
14)
15
16// RequestAgentForwarding sets up agent forwarding for the session.
17// ForwardToAgent or ForwardToRemote should be called to route
18// the authentication requests.
19func RequestAgentForwarding(session *ssh.Session) error {
20 ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil)
21 if err != nil {
22 return err
23 }
24 if !ok {
25 return errors.New("forwarding request denied")
26 }
27 return nil
28}
29
30// ForwardToAgent routes authentication requests to the given keyring.
31func ForwardToAgent(client *ssh.Client, keyring Agent) error {
32 channels := client.HandleChannelOpen(channelType)
33 if channels == nil {
34 return errors.New("agent: already have handler for " + channelType)
35 }
36
37 go func() {
38 for ch := range channels {
39 channel, reqs, err := ch.Accept()
40 if err != nil {
41 continue
42 }
43 go ssh.DiscardRequests(reqs)
44 go func() {
45 ServeAgent(keyring, channel)
46 channel.Close()
47 }()
48 }
49 }()
50 return nil
51}
52
// channelType is the SSH channel type for which ForwardToAgent and
// ForwardToRemote install channel-open handlers on the client.
const (
	channelType = "auth-agent@openssh.com"
)
54
55// ForwardToRemote routes authentication requests to the ssh-agent
56// process serving on the given unix socket.
57func ForwardToRemote(client *ssh.Client, addr string) error {
58 channels := client.HandleChannelOpen(channelType)
59 if channels == nil {
60 return errors.New("agent: already have handler for " + channelType)
61 }
62 conn, err := net.Dial("unix", addr)
63 if err != nil {
64 return err
65 }
66 conn.Close()
67
68 go func() {
69 for ch := range channels {
70 channel, reqs, err := ch.Accept()
71 if err != nil {
72 continue
73 }
74 go ssh.DiscardRequests(reqs)
75 go forwardUnixSocket(channel, addr)
76 }
77 }()
78 return nil
79}
80
81func forwardUnixSocket(channel ssh.Channel, addr string) {
82 conn, err := net.Dial("unix", addr)
83 if err != nil {
84 return
85 }
86
87 var wg sync.WaitGroup
88 wg.Add(2)
89 go func() {
90 io.Copy(conn, channel)
91 conn.(*net.UnixConn).CloseWrite()
92 wg.Done()
93 }()
94 go func() {
95 io.Copy(channel, conn)
96 channel.CloseWrite()
97 wg.Done()
98 }()
99
100 wg.Wait()
101 conn.Close()
102 channel.Close()
103}