main
Raw Download raw file
  1//go:build windows
  2
  3package socket
  4
  5import (
  6	"errors"
  7	"fmt"
  8	"net"
  9	"sync"
 10	"syscall"
 11	"unsafe"
 12
 13	"github.com/Microsoft/go-winio/pkg/guid"
 14	"golang.org/x/sys/windows"
 15)
 16
 17//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
 18
 19//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
 20//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
 21//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
 22
 23const socketError = uintptr(^uint32(0))
 24
 25var (
 26	// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?
 27
 28	ErrBufferSize     = errors.New("buffer size")
 29	ErrAddrFamily     = errors.New("address family")
 30	ErrInvalidPointer = errors.New("invalid pointer")
 31	ErrSocketClosed   = fmt.Errorf("socket closed: %w", net.ErrClosed)
 32)
 33
 34// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
 35
 36// GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
 37// If rsa is not large enough, the [windows.WSAEFAULT] is returned.
 38func GetSockName(s windows.Handle, rsa RawSockaddr) error {
 39	ptr, l, err := rsa.Sockaddr()
 40	if err != nil {
 41		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
 42	}
 43
 44	// although getsockname returns WSAEFAULT if the buffer is too small, it does not set
 45	// &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
 46	return getsockname(s, ptr, &l)
 47}
 48
 49// GetPeerName returns the remote address the socket is connected to.
 50//
 51// See [GetSockName] for more information.
 52func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
 53	ptr, l, err := rsa.Sockaddr()
 54	if err != nil {
 55		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
 56	}
 57
 58	return getpeername(s, ptr, &l)
 59}
 60
 61func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
 62	ptr, l, err := rsa.Sockaddr()
 63	if err != nil {
 64		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
 65	}
 66
 67	return bind(s, ptr, l)
 68}
 69
// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of
// their sockaddr interface, so they cannot be used with HvsockAddr.
 72// Replicate functionality here from
 73// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
 74
 75// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
 76// runtime via a WSAIoctl call:
 77// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
 78
 79type runtimeFunc struct {
 80	id   guid.GUID
 81	once sync.Once
 82	addr uintptr
 83	err  error
 84}
 85
 86func (f *runtimeFunc) Load() error {
 87	f.once.Do(func() {
 88		var s windows.Handle
 89		s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
 90		if f.err != nil {
 91			return
 92		}
 93		defer windows.CloseHandle(s) //nolint:errcheck
 94
 95		var n uint32
 96		f.err = windows.WSAIoctl(s,
 97			windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
 98			(*byte)(unsafe.Pointer(&f.id)),
 99			uint32(unsafe.Sizeof(f.id)),
100			(*byte)(unsafe.Pointer(&f.addr)),
101			uint32(unsafe.Sizeof(f.addr)),
102			&n,
103			nil, // overlapped
104			0,   // completionRoutine
105		)
106	})
107	return f.err
108}
109
110var (
111	// todo: add `AcceptEx` and `GetAcceptExSockaddrs`
112	WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
113		Data1: 0x25a207b9,
114		Data2: 0xddf3,
115		Data3: 0x4660,
116		Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
117	}
118
119	connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
120)
121
122func ConnectEx(
123	fd windows.Handle,
124	rsa RawSockaddr,
125	sendBuf *byte,
126	sendDataLen uint32,
127	bytesSent *uint32,
128	overlapped *windows.Overlapped,
129) error {
130	if err := connectExFunc.Load(); err != nil {
131		return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
132	}
133	ptr, n, err := rsa.Sockaddr()
134	if err != nil {
135		return err
136	}
137	return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
138}
139
140// BOOL LpfnConnectex(
141//   [in]           SOCKET s,
142//   [in]           const sockaddr *name,
143//   [in]           int namelen,
144//   [in, optional] PVOID lpSendBuffer,
145//   [in]           DWORD dwSendDataLength,
146//   [out]          LPDWORD lpdwBytesSent,
147//   [in]           LPOVERLAPPED lpOverlapped
148// )
149
150func connectEx(
151	s windows.Handle,
152	name unsafe.Pointer,
153	namelen int32,
154	sendBuf *byte,
155	sendDataLen uint32,
156	bytesSent *uint32,
157	overlapped *windows.Overlapped,
158) (err error) {
159	r1, _, e1 := syscall.SyscallN(connectExFunc.addr,
160		uintptr(s),
161		uintptr(name),
162		uintptr(namelen),
163		uintptr(unsafe.Pointer(sendBuf)),
164		uintptr(sendDataLen),
165		uintptr(unsafe.Pointer(bytesSent)),
166		uintptr(unsafe.Pointer(overlapped)),
167	)
168
169	if r1 == 0 {
170		if e1 != 0 {
171			err = error(e1)
172		} else {
173			err = syscall.EINVAL
174		}
175	}
176	return err
177}