main
Raw Download raw file
  1package main
  2
  3import (
  4	"encoding/binary"
  5	"fmt"
  6	"io"
  7	"net"
  8	"net/netip"
  9)
 10
 11const (
 12	VersionSocks5 = byte(0x05)
 13
 14	MethodNoAuthRequired     = byte(0x00)
 15	MethodNoAcceptableMethod = byte(0xFF)
 16
 17	CommandConnect = byte(0x01)
 18
 19	AddressTypeIPV4   = byte(0x01)
 20	AddressTypeDomain = byte(0x03)
 21	AddressTypeIPV6   = byte(0x04)
 22
 23	ReplySucceeded      = byte(0x00)
 24	ReplyGeneralFailure = byte(0x01)
 25
 26	Reserved = byte(0x00)
 27)
 28
 29var (
 30	// Useful byte arrays
 31	_zeroIPv4 = []byte(net.IPv4zero.To4())
 32	_zeroPort = []byte{0x00, 0x00}
 33)
 34
// AddrRequest represents a union type that holds either a netip.AddrPort
// or a string. Only one of the fields (AddrPort or Domain) should be non-nil
// at a time. This enables a parsed socks5 client request to return the
// destination address information for any of the three valid address types
// in a struct that is usable without interface{} or reflection.
type AddrRequest struct {
	// AddrPort is set by ParseClientRequest for IPv4 and IPv6 destinations.
	AddrPort *netip.AddrPort
	// Domain is set by ParseClientRequest for domain name destinations; it
	// holds the name and port joined as "domain:port".
	Domain   *string
}
 44
 45// reply creates a byte slice for a socks5 reply (RFC 1928 Section 6)
 46//
 47//	+----+-----+-------+------+----------+----------+
 48//	|VER | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
 49//	+----+-----+-------+------+----------+----------+
 50//	| 1  |  1  | X'00' |  1   | Variable |    2     |
 51//	+----+-----+-------+------+----------+----------+
 52func replyV4(rep byte) []byte {
 53	reply := make([]byte, 4, 10)
 54	reply[0] = VersionSocks5
 55	reply[1] = rep
 56	reply[2] = Reserved
 57	reply[3] = AddressTypeIPV4
 58	reply = append(reply, _zeroIPv4...)
 59	reply = append(reply, _zeroPort...)
 60	return reply
 61}
 62
 63// handleClientIdentifier
 64//
 65//	+----+----------+----------+
 66//	|VER | NMETHODS | METHODS  |
 67//	+----+----------+----------+
 68//	| 1  |    1     | 1 to 255 |
 69//	+----+----------+----------+
 70func handleClientIdentifier(conn io.ReadWriter) error {
 71
 72	// version and nmethods
 73	client := make([]byte, 2)
 74	n, err := conn.Read(client)
 75	if err != nil {
 76		return fmt.Errorf("failed to read client identifier: %w", err)
 77	}
 78	if n != 2 {
 79		return fmt.Errorf("failed to parse client identifier: bad length")
 80	}
 81
 82	// only socks5 supported
 83	ver := client[0]
 84	if ver != VersionSocks5 {
 85		_, err := conn.Write([]byte{VersionSocks5, MethodNoAcceptableMethod})
 86		if err != nil {
 87			return fmt.Errorf("failed to write to socket: %w", err)
 88		}
 89		return fmt.Errorf("version not supported version=%q", ver)
 90	}
 91
 92	// methods
 93	nmethods := int(client[1])
 94	methods := make([]byte, nmethods)
 95	n, err = conn.Read(methods)
 96	if err != nil {
 97		return fmt.Errorf("failed to read client methods: %w", err)
 98	}
 99	if n != nmethods {
100		return fmt.Errorf("failed to parse client methods: bad length")
101	}
102
103	// only no auth supported
104	match := false
105	for _, m := range methods {
106		if m == MethodNoAuthRequired {
107			match = true
108			break
109		}
110	}
111	if !match {
112		_, err := conn.Write([]byte{VersionSocks5, MethodNoAcceptableMethod})
113		if err != nil {
114			return fmt.Errorf("failed to write to socket: %w", err)
115		}
116		return fmt.Errorf("no method compatibility found")
117	}
118
119	_, err = conn.Write([]byte{VersionSocks5, MethodNoAuthRequired})
120	return nil
121}
122
123func ParseClientRequest(conn io.ReadWriter) (AddrRequest, error) {
124	var (
125		ar     AddrRequest
126		addr   netip.Addr
127		domain string
128	)
129
130	req := make([]byte, 4) // version and count
131	n, err := conn.Read(req)
132	if err != nil {
133		return ar, fmt.Errorf("failed to read client request: %w", err)
134	}
135	if n != 4 {
136		return ar, fmt.Errorf("failed to parse client request: bad length")
137	}
138
139	// socks5 and connect only
140	ver, cmd := req[0], req[1]
141	if ver != VersionSocks5 || cmd != CommandConnect {
142		_, err := conn.Write([]byte{VersionSocks5, ReplyGeneralFailure})
143		if err != nil {
144			return ar, fmt.Errorf("failed to write to socket: %w", err)
145		}
146		return ar, fmt.Errorf("unspported request ver=%q cmd=%q",
147			ver, cmd)
148	}
149
150	addrType := req[3]
151	switch addrType {
152
153	case AddressTypeIPV4:
154
155		ipRaw := [net.IPv4len]byte{}
156		n, err := conn.Read(ipRaw[:])
157		if err != nil {
158			return ar, fmt.Errorf("failed to read dst address: %w", err)
159		}
160		if n != net.IPv4len {
161			return ar, fmt.Errorf("failed to parse dst address: bad length")
162		}
163		addr = netip.AddrFrom4(ipRaw)
164
165	case AddressTypeIPV6:
166
167		ipRaw := [net.IPv6len]byte{}
168		n, err := conn.Read(ipRaw[:])
169		if err != nil {
170			return ar, fmt.Errorf("failed to read dst address: %w", err)
171		}
172		if n != net.IPv6len {
173			return ar, fmt.Errorf("failed to parse dst address: bad length")
174		}
175
176		addr = netip.AddrFrom16(ipRaw)
177
178	case AddressTypeDomain:
179
180		domainLen := make([]byte, 1)
181		n, err := conn.Read(domainLen)
182		if err != nil {
183			return ar, fmt.Errorf("failed to read dst domain length: %w", err)
184		}
185		if n != 1 {
186			return ar, fmt.Errorf("failed to parse dst domain length: bad length")
187		}
188		domainRaw := make([]byte, int(domainLen[0]))
189		n, err = conn.Read(domainRaw)
190		if err != nil {
191			return ar, fmt.Errorf("failed to read dst domain: %w", err)
192		}
193		if n != int(domainLen[0]) {
194			return ar, fmt.Errorf("failed to parse dst domain: bad length")
195		}
196		domain = string(domainRaw)
197
198	default:
199
200		return ar, fmt.Errorf("unknown address type")
201	}
202
203	portRaw := make([]byte, 2)
204	n, err = conn.Read(portRaw)
205	if err != nil {
206		return ar, fmt.Errorf("failed to read dst port: %w", err)
207	}
208	if n != 2 {
209		return ar, fmt.Errorf("failed to read dst port: bad length")
210	}
211	// destination port in network octet order
212	port := binary.BigEndian.Uint16(portRaw)
213
214	switch addrType {
215	case AddressTypeIPV4, AddressTypeIPV6:
216		addrPort := netip.AddrPortFrom(addr, port)
217		ar.AddrPort = &addrPort
218		return ar, nil
219	case AddressTypeDomain:
220		domainPort := fmt.Sprintf("%s:%d", domain, port)
221		ar.Domain = &domainPort
222		return ar, nil
223	default:
224		return ar, fmt.Errorf("unknown address type")
225	}
226}