main
Raw Download raw file
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
  4
  5package socks
  6
  7import (
  8	"context"
  9	"errors"
 10	"io"
 11	"net"
 12	"strconv"
 13	"time"
 14)
 15
 16var (
 17	noDeadline   = time.Time{}
 18	aLongTimeAgo = time.Unix(1, 0)
 19)
 20
 21func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
 22	host, port, err := splitHostPort(address)
 23	if err != nil {
 24		return nil, err
 25	}
 26	if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
 27		c.SetDeadline(deadline)
 28		defer c.SetDeadline(noDeadline)
 29	}
 30	if ctx != context.Background() {
 31		errCh := make(chan error, 1)
 32		done := make(chan struct{})
 33		defer func() {
 34			close(done)
 35			if ctxErr == nil {
 36				ctxErr = <-errCh
 37			}
 38		}()
 39		go func() {
 40			select {
 41			case <-ctx.Done():
 42				c.SetDeadline(aLongTimeAgo)
 43				errCh <- ctx.Err()
 44			case <-done:
 45				errCh <- nil
 46			}
 47		}()
 48	}
 49
 50	b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
 51	b = append(b, Version5)
 52	if len(d.AuthMethods) == 0 || d.Authenticate == nil {
 53		b = append(b, 1, byte(AuthMethodNotRequired))
 54	} else {
 55		ams := d.AuthMethods
 56		if len(ams) > 255 {
 57			return nil, errors.New("too many authentication methods")
 58		}
 59		b = append(b, byte(len(ams)))
 60		for _, am := range ams {
 61			b = append(b, byte(am))
 62		}
 63	}
 64	if _, ctxErr = c.Write(b); ctxErr != nil {
 65		return
 66	}
 67
 68	if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
 69		return
 70	}
 71	if b[0] != Version5 {
 72		return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
 73	}
 74	am := AuthMethod(b[1])
 75	if am == AuthMethodNoAcceptableMethods {
 76		return nil, errors.New("no acceptable authentication methods")
 77	}
 78	if d.Authenticate != nil {
 79		if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
 80			return
 81		}
 82	}
 83
 84	b = b[:0]
 85	b = append(b, Version5, byte(d.cmd), 0)
 86	if ip := net.ParseIP(host); ip != nil {
 87		if ip4 := ip.To4(); ip4 != nil {
 88			b = append(b, AddrTypeIPv4)
 89			b = append(b, ip4...)
 90		} else if ip6 := ip.To16(); ip6 != nil {
 91			b = append(b, AddrTypeIPv6)
 92			b = append(b, ip6...)
 93		} else {
 94			return nil, errors.New("unknown address type")
 95		}
 96	} else {
 97		if len(host) > 255 {
 98			return nil, errors.New("FQDN too long")
 99		}
100		b = append(b, AddrTypeFQDN)
101		b = append(b, byte(len(host)))
102		b = append(b, host...)
103	}
104	b = append(b, byte(port>>8), byte(port))
105	if _, ctxErr = c.Write(b); ctxErr != nil {
106		return
107	}
108
109	if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
110		return
111	}
112	if b[0] != Version5 {
113		return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
114	}
115	if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
116		return nil, errors.New("unknown error " + cmdErr.String())
117	}
118	if b[2] != 0 {
119		return nil, errors.New("non-zero reserved field")
120	}
121	l := 2
122	var a Addr
123	switch b[3] {
124	case AddrTypeIPv4:
125		l += net.IPv4len
126		a.IP = make(net.IP, net.IPv4len)
127	case AddrTypeIPv6:
128		l += net.IPv6len
129		a.IP = make(net.IP, net.IPv6len)
130	case AddrTypeFQDN:
131		if _, err := io.ReadFull(c, b[:1]); err != nil {
132			return nil, err
133		}
134		l += int(b[0])
135	default:
136		return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
137	}
138	if cap(b) < l {
139		b = make([]byte, l)
140	} else {
141		b = b[:l]
142	}
143	if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
144		return
145	}
146	if a.IP != nil {
147		copy(a.IP, b)
148	} else {
149		a.Name = string(b[:len(b)-2])
150	}
151	a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
152	return &a, nil
153}
154
// splitHostPort splits address, in "host:port" form, into its host and
// a numeric port. The port must be a decimal integer in the range
// 1..65535. On any failure it returns an empty host, a zero port and
// the error.
func splitHostPort(address string) (string, int, error) {
	h, p, err := net.SplitHostPort(address)
	if err != nil {
		return "", 0, err
	}
	n, err := strconv.Atoi(p)
	switch {
	case err != nil:
		return "", 0, err
	case n < 1 || n > 0xffff:
		return "", 0, errors.New("port number out of range " + p)
	}
	return h, n, nil
}