main
1//go:build windows
2
3package socket
4
5import (
6 "errors"
7 "fmt"
8 "net"
9 "sync"
10 "syscall"
11 "unsafe"
12
13 "github.com/Microsoft/go-winio/pkg/guid"
14 "golang.org/x/sys/windows"
15)
16
17//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
18
19//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
20//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
21//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
22
// socketError is the failure sentinel named by the `failretval==socketError` clauses of the
// //sys directives in this file: all 32 low bits set, i.e. Winsock's SOCKET_ERROR (-1) as it
// appears in a 32-bit return value widened to uintptr.
const socketError uintptr = 0xffffffff
24
// Sentinel errors exported by this package. They are plain values, so callers should match
// them with [errors.Is] rather than by comparing messages.
//
// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?
var (
	// ErrBufferSize carries the message "buffer size".
	// NOTE(review): presumably reported for a wrongly sized address buffer; the code that
	// returns it is outside this section.
	ErrBufferSize = errors.New("buffer size")

	// ErrAddrFamily carries the message "address family".
	// NOTE(review): presumably reported for an unexpected address family; confirm at the
	// call sites.
	ErrAddrFamily = errors.New("address family")

	// ErrInvalidPointer carries the message "invalid pointer".
	ErrInvalidPointer = errors.New("invalid pointer")

	// ErrSocketClosed wraps [net.ErrClosed], so errors.Is(err, net.ErrClosed) also matches it.
	ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
)
33
34// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
35
36// GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
37// If rsa is not large enough, the [windows.WSAEFAULT] is returned.
38func GetSockName(s windows.Handle, rsa RawSockaddr) error {
39 ptr, l, err := rsa.Sockaddr()
40 if err != nil {
41 return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
42 }
43
44 // although getsockname returns WSAEFAULT if the buffer is too small, it does not set
45 // &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
46 return getsockname(s, ptr, &l)
47}
48
49// GetPeerName returns the remote address the socket is connected to.
50//
51// See [GetSockName] for more information.
52func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
53 ptr, l, err := rsa.Sockaddr()
54 if err != nil {
55 return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
56 }
57
58 return getpeername(s, ptr, &l)
59}
60
61func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
62 ptr, l, err := rsa.Sockaddr()
63 if err != nil {
64 return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
65 }
66
67 return bind(s, ptr, l)
68}
69
// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of
// their sockaddr interface, so they cannot be used with HvsockAddr.
72// Replicate functionality here from
73// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
74
75// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
76// runtime via a WSAIoctl call:
77// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
78
79type runtimeFunc struct {
80 id guid.GUID
81 once sync.Once
82 addr uintptr
83 err error
84}
85
86func (f *runtimeFunc) Load() error {
87 f.once.Do(func() {
88 var s windows.Handle
89 s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
90 if f.err != nil {
91 return
92 }
93 defer windows.CloseHandle(s) //nolint:errcheck
94
95 var n uint32
96 f.err = windows.WSAIoctl(s,
97 windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
98 (*byte)(unsafe.Pointer(&f.id)),
99 uint32(unsafe.Sizeof(f.id)),
100 (*byte)(unsafe.Pointer(&f.addr)),
101 uint32(unsafe.Sizeof(f.addr)),
102 &n,
103 nil, // overlapped
104 0, // completionRoutine
105 )
106 })
107 return f.err
108}
109
110var (
111 // todo: add `AcceptEx` and `GetAcceptExSockaddrs`
112 WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
113 Data1: 0x25a207b9,
114 Data2: 0xddf3,
115 Data3: 0x4660,
116 Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
117 }
118
119 connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
120)
121
122func ConnectEx(
123 fd windows.Handle,
124 rsa RawSockaddr,
125 sendBuf *byte,
126 sendDataLen uint32,
127 bytesSent *uint32,
128 overlapped *windows.Overlapped,
129) error {
130 if err := connectExFunc.Load(); err != nil {
131 return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
132 }
133 ptr, n, err := rsa.Sockaddr()
134 if err != nil {
135 return err
136 }
137 return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
138}
139
140// BOOL LpfnConnectex(
141// [in] SOCKET s,
142// [in] const sockaddr *name,
143// [in] int namelen,
144// [in, optional] PVOID lpSendBuffer,
145// [in] DWORD dwSendDataLength,
146// [out] LPDWORD lpdwBytesSent,
147// [in] LPOVERLAPPED lpOverlapped
148// )
149
150func connectEx(
151 s windows.Handle,
152 name unsafe.Pointer,
153 namelen int32,
154 sendBuf *byte,
155 sendDataLen uint32,
156 bytesSent *uint32,
157 overlapped *windows.Overlapped,
158) (err error) {
159 r1, _, e1 := syscall.SyscallN(connectExFunc.addr,
160 uintptr(s),
161 uintptr(name),
162 uintptr(namelen),
163 uintptr(unsafe.Pointer(sendBuf)),
164 uintptr(sendDataLen),
165 uintptr(unsafe.Pointer(bytesSent)),
166 uintptr(unsafe.Pointer(overlapped)),
167 )
168
169 if r1 == 0 {
170 if e1 != 0 {
171 err = error(e1)
172 } else {
173 err = syscall.EINVAL
174 }
175 }
176 return err
177}