main
Raw Download raw file
  1//go:build windows
  2// +build windows
  3
  4package winio
  5
  6import (
  7	"context"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"net"
 12	"os"
 13	"time"
 14	"unsafe"
 15
 16	"golang.org/x/sys/windows"
 17
 18	"github.com/Microsoft/go-winio/internal/socket"
 19	"github.com/Microsoft/go-winio/pkg/guid"
 20)
 21
// afHVSock is the address family number used for Hyper-V sockets. It is passed to the socket
// call when creating a socket and stored in (and checked against) rawHvsockAddr.Family.
const afHVSock = 34 // AF_HYPERV
 23
// Well-known service and VM IDs.
 25// https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards
 26
 27// HvsockGUIDWildcard is the wildcard VmId for accepting connections from all partitions.
 28func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000
 29	return guid.GUID{}
 30}
 31
 32// HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions.
 33func HvsockGUIDBroadcast() guid.GUID { // ffffffff-ffff-ffff-ffff-ffffffffffff
 34	return guid.GUID{
 35		Data1: 0xffffffff,
 36		Data2: 0xffff,
 37		Data3: 0xffff,
 38		Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
 39	}
 40}
 41
 42// HvsockGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector.
 43func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838
 44	return guid.GUID{
 45		Data1: 0xe0e16197,
 46		Data2: 0xdd56,
 47		Data3: 0x4a10,
 48		Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38},
 49	}
 50}
 51
 52// HvsockGUIDSiloHost is the address of a silo's host partition:
 53//   - The silo host of a hosted silo is the utility VM.
 54//   - The silo host of a silo on a physical host is the physical host.
 55func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568
 56	return guid.GUID{
 57		Data1: 0x36bd0c5c,
 58		Data2: 0x7276,
 59		Data3: 0x4223,
 60		Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68},
 61	}
 62}
 63
 64// HvsockGUIDChildren is the wildcard VmId for accepting connections from the connector's child partitions.
 65func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd
 66	return guid.GUID{
 67		Data1: 0x90db8b89,
 68		Data2: 0xd35,
 69		Data3: 0x4f79,
 70		Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd},
 71	}
 72}
 73
 74// HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition.
 75// Listening on this VmId accepts connection from:
 76//   - Inside silos: silo host partition.
 77//   - Inside hosted silo: host of the VM.
 78//   - Inside VM: VM host.
 79//   - Physical host: Not supported.
 80func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878
 81	return guid.GUID{
 82		Data1: 0xa42e7cda,
 83		Data2: 0xd03f,
 84		Data3: 0x480c,
 85		Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78},
 86	}
 87}
 88
 89// hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol.
 90func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3
 91	return guid.GUID{
 92		Data2: 0xfacb,
 93		Data3: 0x11e6,
 94		Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3},
 95	}
 96}
 97
// An HvsockAddr is an address for a AF_HYPERV socket.
//
// Its pointer type provides Network and String methods, and it is returned as a [net.Addr]
// by the listener and connection types in this file.
type HvsockAddr struct {
	// VMID identifies the partition (VM); see the HvsockGUID* helpers for well-known values.
	VMID guid.GUID
	// ServiceID identifies the service; see VsockServiceID for AF_VSOCK port mapping.
	ServiceID guid.GUID
}
103
// rawHvsockAddr is the in-memory socket address form of an HvsockAddr.
//
// Sockaddr hands out a pointer to this struct and FromBytes copies raw bytes directly over it,
// so the field order and sizes are part of its contract and must not be changed.
type rawHvsockAddr struct {
	// Family is the address family; set to afHVSock by (*HvsockAddr).raw and validated by FromBytes.
	Family uint16
	// unnamed 16-bit field following Family; never read or written by this file.
	_ uint16
	// VMID is the partition GUID.
	VMID guid.GUID
	// ServiceID is the service GUID.
	ServiceID guid.GUID
}
110
111var _ socket.RawSockaddr = &rawHvsockAddr{}
112
113// Network returns the address's network name, "hvsock".
114func (*HvsockAddr) Network() string {
115	return "hvsock"
116}
117
118func (addr *HvsockAddr) String() string {
119	return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
120}
121
122// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
123func VsockServiceID(port uint32) guid.GUID {
124	g := hvsockVsockServiceTemplate() // make a copy
125	g.Data1 = port
126	return g
127}
128
129func (addr *HvsockAddr) raw() rawHvsockAddr {
130	return rawHvsockAddr{
131		Family:    afHVSock,
132		VMID:      addr.VMID,
133		ServiceID: addr.ServiceID,
134	}
135}
136
137func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
138	addr.VMID = raw.VMID
139	addr.ServiceID = raw.ServiceID
140}
141
142// Sockaddr returns a pointer to and the size of this struct.
143//
144// Implements the [socket.RawSockaddr] interface, and allows use in
145// [socket.Bind] and [socket.ConnectEx].
146func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) {
147	return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil
148}
149
150// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()`.
151func (r *rawHvsockAddr) FromBytes(b []byte) error {
152	n := int(unsafe.Sizeof(rawHvsockAddr{}))
153
154	if len(b) < n {
155		return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize)
156	}
157
158	copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n])
159	if r.Family != afHVSock {
160		return fmt.Errorf("got %d, want %d: %w", r.Family, afHVSock, socket.ErrAddrFamily)
161	}
162
163	return nil
164}
165
// HvsockListener is a socket listener for the AF_HYPERV address family.
// Create one with ListenHvsock.
type HvsockListener struct {
	// sock is the bound, listening socket.
	sock *win32File
	// addr is the address the listener was created with; returned by Addr and used in errors.
	addr HvsockAddr
}
171
172var _ net.Listener = &HvsockListener{}
173
// HvsockConn is a connected socket of the AF_HYPERV address family.
// Connections are produced by (*HvsockListener).Accept and (*HvsockDialer).Dial.
type HvsockConn struct {
	// sock is the underlying connected socket.
	sock *win32File
	// local and remote are the connection's endpoint addresses, as reported by
	// LocalAddr and RemoteAddr and included in operation errors.
	local, remote HvsockAddr
}
179
180var _ net.Conn = &HvsockConn{}
181
182func newHVSocket() (*win32File, error) {
183	fd, err := windows.Socket(afHVSock, windows.SOCK_STREAM, 1)
184	if err != nil {
185		return nil, os.NewSyscallError("socket", err)
186	}
187	f, err := makeWin32File(fd)
188	if err != nil {
189		windows.Close(fd)
190		return nil, err
191	}
192	f.socket = true
193	return f, nil
194}
195
196// ListenHvsock listens for connections on the specified hvsock address.
197func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
198	l := &HvsockListener{addr: *addr}
199
200	var sock *win32File
201	sock, err = newHVSocket()
202	if err != nil {
203		return nil, l.opErr("listen", err)
204	}
205	defer func() {
206		if err != nil {
207			_ = sock.Close()
208		}
209	}()
210
211	sa := addr.raw()
212	err = socket.Bind(sock.handle, &sa)
213	if err != nil {
214		return nil, l.opErr("listen", os.NewSyscallError("socket", err))
215	}
216	err = windows.Listen(sock.handle, 16)
217	if err != nil {
218		return nil, l.opErr("listen", os.NewSyscallError("listen", err))
219	}
220	return &HvsockListener{sock: sock, addr: *addr}, nil
221}
222
223func (l *HvsockListener) opErr(op string, err error) error {
224	return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
225}
226
227// Addr returns the listener's network address.
228func (l *HvsockListener) Addr() net.Addr {
229	return &l.addr
230}
231
// Accept waits for the next connection and returns it.
//
// A fresh socket is created for the incoming connection and handed to AcceptEx together with
// the listening socket. On success the returned net.Conn is a *HvsockConn owning that socket;
// on any failure the new socket is closed and a *net.OpError with Op "accept" is returned.
func (l *HvsockListener) Accept() (_ net.Conn, err error) {
	sock, err := newHVSocket()
	if err != nil {
		return nil, l.opErr("accept", err)
	}
	// Close the new socket on every exit path except success, where sock is set to nil.
	defer func() {
		if sock != nil {
			sock.Close()
		}
	}()
	c, err := l.sock.prepareIO()
	if err != nil {
		return nil, l.opErr("accept", err)
	}
	// paired with the successful prepareIO call above
	defer l.sock.wg.Done()

	// AcceptEx, per documentation, requires an extra 16 bytes per address.
	//
	// https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex
	const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
	// room for two addresses: local at offset 0, remote at offset addrlen
	var addrbuf [addrlen * 2]byte

	var bytes uint32
	err = windows.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /* rxdatalen */, addrlen, addrlen, &bytes, &c.o)
	// no deadline is passed to asyncIO (nil)
	if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil {
		return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
	}

	conn := &HvsockConn{
		sock: sock,
	}
	// The local address returned in the AcceptEx buffer is the same as the listener socket's
	// address. However, the service GUID reported by GetSockName is different from the listener's
	// socket, and is sometimes the same as the local address of the socket that dialed the
	// address, with the service GUID.Data1 incremented, but other times is different.
	// todo: does the local address matter? is the listener's address or the actual address appropriate?
	conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
	conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))

	// initialize the accepted socket and update its properties with those of the listening socket
	if err = windows.Setsockopt(sock.handle,
		windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT,
		(*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil {
		// reported via the connection's opErr, so the error carries the local and remote addresses
		return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err))
	}

	// ownership of the socket has moved to conn; disarm the deferred close
	sock = nil
	return conn, nil
}
282
283// Close closes the listener, causing any pending Accept calls to fail.
284func (l *HvsockListener) Close() error {
285	return l.sock.Close()
286}
287
// HvsockDialer configures and dials a Hyper-V Socket (ie, [HvsockConn]).
//
// The zero value is usable: it dials once, with no deadline beyond the one carried by the
// context passed to Dial.
type HvsockDialer struct {
	// Deadline is the time the Dial operation must connect before erroring.
	// The zero value means no deadline.
	Deadline time.Time

	// Retries is the number of additional connects to try if the connection times out, is refused,
	// or the host is unreachable
	Retries uint

	// RetryWait is the time to wait after a connection error to retry.
	// A zero value retries immediately.
	RetryWait time.Duration

	// rt is created lazily by redialWait and reused (via Reset) on later waits.
	// NOTE(review): rt is read and written without synchronization, so concurrent Dial calls on
	// one dialer with a non-zero RetryWait look unsafe — confirm intended usage.
	rt *time.Timer // redial wait timer
}
302
303// Dial the Hyper-V socket at addr.
304//
305// See [HvsockDialer.Dial] for more information.
306func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
307	return (&HvsockDialer{}).Dial(ctx, addr)
308}
309
310// Dial attempts to connect to the Hyper-V socket at addr, and returns a connection if successful.
311// Will attempt (HvsockDialer).Retries if dialing fails, waiting (HvsockDialer).RetryWait between
312// retries.
313//
314// Dialing can be cancelled either by providing (HvsockDialer).Deadline, or cancelling ctx.
315func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
316	op := "dial"
317	// create the conn early to use opErr()
318	conn = &HvsockConn{
319		remote: *addr,
320	}
321
322	if !d.Deadline.IsZero() {
323		var cancel context.CancelFunc
324		ctx, cancel = context.WithDeadline(ctx, d.Deadline)
325		defer cancel()
326	}
327
328	// preemptive timeout/cancellation check
329	if err = ctx.Err(); err != nil {
330		return nil, conn.opErr(op, err)
331	}
332
333	sock, err := newHVSocket()
334	if err != nil {
335		return nil, conn.opErr(op, err)
336	}
337	defer func() {
338		if sock != nil {
339			sock.Close()
340		}
341	}()
342
343	sa := addr.raw()
344	err = socket.Bind(sock.handle, &sa)
345	if err != nil {
346		return nil, conn.opErr(op, os.NewSyscallError("bind", err))
347	}
348
349	c, err := sock.prepareIO()
350	if err != nil {
351		return nil, conn.opErr(op, err)
352	}
353	defer sock.wg.Done()
354	var bytes uint32
355	for i := uint(0); i <= d.Retries; i++ {
356		err = socket.ConnectEx(
357			sock.handle,
358			&sa,
359			nil, // sendBuf
360			0,   // sendDataLen
361			&bytes,
362			(*windows.Overlapped)(unsafe.Pointer(&c.o)))
363		_, err = sock.asyncIO(c, nil, bytes, err)
364		if i < d.Retries && canRedial(err) {
365			if err = d.redialWait(ctx); err == nil {
366				continue
367			}
368		}
369		break
370	}
371	if err != nil {
372		return nil, conn.opErr(op, os.NewSyscallError("connectex", err))
373	}
374
375	// update the connection properties, so shutdown can be used
376	if err = windows.Setsockopt(
377		sock.handle,
378		windows.SOL_SOCKET,
379		windows.SO_UPDATE_CONNECT_CONTEXT,
380		nil, // optvalue
381		0,   // optlen
382	); err != nil {
383		return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err))
384	}
385
386	// get the local name
387	var sal rawHvsockAddr
388	err = socket.GetSockName(sock.handle, &sal)
389	if err != nil {
390		return nil, conn.opErr(op, os.NewSyscallError("getsockname", err))
391	}
392	conn.local.fromRaw(&sal)
393
394	// one last check for timeout, since asyncIO doesn't check the context
395	if err = ctx.Err(); err != nil {
396		return nil, conn.opErr(op, err)
397	}
398
399	conn.sock = sock
400	sock = nil
401
402	return conn, nil
403}
404
405// redialWait waits before attempting to redial, resetting the timer as appropriate.
406func (d *HvsockDialer) redialWait(ctx context.Context) (err error) {
407	if d.RetryWait == 0 {
408		return nil
409	}
410
411	if d.rt == nil {
412		d.rt = time.NewTimer(d.RetryWait)
413	} else {
414		// should already be stopped and drained
415		d.rt.Reset(d.RetryWait)
416	}
417
418	select {
419	case <-ctx.Done():
420	case <-d.rt.C:
421		return nil
422	}
423
424	// stop and drain the timer
425	if !d.rt.Stop() {
426		<-d.rt.C
427	}
428	return ctx.Err()
429}
430
431// assumes error is a plain, unwrapped windows.Errno provided by direct syscall.
432func canRedial(err error) bool {
433	//nolint:errorlint // guaranteed to be an Errno
434	switch err {
435	case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT,
436		windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL:
437		return true
438	default:
439		return false
440	}
441}
442
443func (conn *HvsockConn) opErr(op string, err error) error {
444	// translate from "file closed" to "socket closed"
445	if errors.Is(err, ErrFileClosed) {
446		err = socket.ErrSocketClosed
447	}
448	return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
449}
450
451func (conn *HvsockConn) Read(b []byte) (int, error) {
452	c, err := conn.sock.prepareIO()
453	if err != nil {
454		return 0, conn.opErr("read", err)
455	}
456	defer conn.sock.wg.Done()
457	buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))}
458	var flags, bytes uint32
459	err = windows.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
460	n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err)
461	if err != nil {
462		var eno windows.Errno
463		if errors.As(err, &eno) {
464			err = os.NewSyscallError("wsarecv", eno)
465		}
466		return 0, conn.opErr("read", err)
467	} else if n == 0 {
468		err = io.EOF
469	}
470	return n, err
471}
472
473func (conn *HvsockConn) Write(b []byte) (int, error) {
474	t := 0
475	for len(b) != 0 {
476		n, err := conn.write(b)
477		if err != nil {
478			return t + n, err
479		}
480		t += n
481		b = b[n:]
482	}
483	return t, nil
484}
485
486func (conn *HvsockConn) write(b []byte) (int, error) {
487	c, err := conn.sock.prepareIO()
488	if err != nil {
489		return 0, conn.opErr("write", err)
490	}
491	defer conn.sock.wg.Done()
492	buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))}
493	var bytes uint32
494	err = windows.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
495	n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err)
496	if err != nil {
497		var eno windows.Errno
498		if errors.As(err, &eno) {
499			err = os.NewSyscallError("wsasend", eno)
500		}
501		return 0, conn.opErr("write", err)
502	}
503	return n, err
504}
505
506// Close closes the socket connection, failing any pending read or write calls.
507func (conn *HvsockConn) Close() error {
508	return conn.sock.Close()
509}
510
511func (conn *HvsockConn) IsClosed() bool {
512	return conn.sock.IsClosed()
513}
514
515// shutdown disables sending or receiving on a socket.
516func (conn *HvsockConn) shutdown(how int) error {
517	if conn.IsClosed() {
518		return socket.ErrSocketClosed
519	}
520
521	err := windows.Shutdown(conn.sock.handle, how)
522	if err != nil {
523		// If the connection was closed, shutdowns fail with "not connected"
524		if errors.Is(err, windows.WSAENOTCONN) ||
525			errors.Is(err, windows.WSAESHUTDOWN) {
526			err = socket.ErrSocketClosed
527		}
528		return os.NewSyscallError("shutdown", err)
529	}
530	return nil
531}
532
533// CloseRead shuts down the read end of the socket, preventing future read operations.
534func (conn *HvsockConn) CloseRead() error {
535	err := conn.shutdown(windows.SHUT_RD)
536	if err != nil {
537		return conn.opErr("closeread", err)
538	}
539	return nil
540}
541
542// CloseWrite shuts down the write end of the socket, preventing future write operations and
543// notifying the other endpoint that no more data will be written.
544func (conn *HvsockConn) CloseWrite() error {
545	err := conn.shutdown(windows.SHUT_WR)
546	if err != nil {
547		return conn.opErr("closewrite", err)
548	}
549	return nil
550}
551
552// LocalAddr returns the local address of the connection.
553func (conn *HvsockConn) LocalAddr() net.Addr {
554	return &conn.local
555}
556
557// RemoteAddr returns the remote address of the connection.
558func (conn *HvsockConn) RemoteAddr() net.Addr {
559	return &conn.remote
560}
561
562// SetDeadline implements the net.Conn SetDeadline method.
563func (conn *HvsockConn) SetDeadline(t time.Time) error {
564	// todo: implement `SetDeadline` for `win32File`
565	if err := conn.SetReadDeadline(t); err != nil {
566		return fmt.Errorf("set read deadline: %w", err)
567	}
568	if err := conn.SetWriteDeadline(t); err != nil {
569		return fmt.Errorf("set write deadline: %w", err)
570	}
571	return nil
572}
573
574// SetReadDeadline implements the net.Conn SetReadDeadline method.
575func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
576	return conn.sock.SetReadDeadline(t)
577}
578
579// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
580func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
581	return conn.sock.SetWriteDeadline(t)
582}