main
Raw Download raw file
  1//go:build windows
  2// +build windows
  3
  4package winio
  5
  6import (
  7	"bytes"
  8	"encoding/binary"
  9	"fmt"
 10	"runtime"
 11	"sync"
 12	"unicode/utf16"
 13
 14	"golang.org/x/sys/windows"
 15)
 16
 17//sys adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) [true] = advapi32.AdjustTokenPrivileges
 18//sys impersonateSelf(level uint32) (err error) = advapi32.ImpersonateSelf
 19//sys revertToSelf() (err error) = advapi32.RevertToSelf
 20//sys openThreadToken(thread windows.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) = advapi32.OpenThreadToken
 21//sys getCurrentThread() (h windows.Handle) = GetCurrentThread
 22//sys lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) = advapi32.LookupPrivilegeValueW
 23//sys lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) = advapi32.LookupPrivilegeNameW
 24//sys lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) = advapi32.LookupPrivilegeDisplayNameW
 25
 26const (
 27	//revive:disable-next-line:var-naming ALL_CAPS
 28	SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED
 29
 30	//revive:disable-next-line:var-naming ALL_CAPS
 31	ERROR_NOT_ALL_ASSIGNED windows.Errno = windows.ERROR_NOT_ALL_ASSIGNED
 32
 33	SeBackupPrivilege   = "SeBackupPrivilege"
 34	SeRestorePrivilege  = "SeRestorePrivilege"
 35	SeSecurityPrivilege = "SeSecurityPrivilege"
 36)
 37
 38var (
 39	privNames     = make(map[string]uint64)
 40	privNameMutex sync.Mutex
 41)
 42
 43// PrivilegeError represents an error enabling privileges.
 44type PrivilegeError struct {
 45	privileges []uint64
 46}
 47
 48func (e *PrivilegeError) Error() string {
 49	s := "Could not enable privilege "
 50	if len(e.privileges) > 1 {
 51		s = "Could not enable privileges "
 52	}
 53	for i, p := range e.privileges {
 54		if i != 0 {
 55			s += ", "
 56		}
 57		s += `"`
 58		s += getPrivilegeName(p)
 59		s += `"`
 60	}
 61	return s
 62}
 63
 64// RunWithPrivilege enables a single privilege for a function call.
 65func RunWithPrivilege(name string, fn func() error) error {
 66	return RunWithPrivileges([]string{name}, fn)
 67}
 68
 69// RunWithPrivileges enables privileges for a function call.
 70func RunWithPrivileges(names []string, fn func() error) error {
 71	privileges, err := mapPrivileges(names)
 72	if err != nil {
 73		return err
 74	}
 75	runtime.LockOSThread()
 76	defer runtime.UnlockOSThread()
 77	token, err := newThreadToken()
 78	if err != nil {
 79		return err
 80	}
 81	defer releaseThreadToken(token)
 82	err = adjustPrivileges(token, privileges, SE_PRIVILEGE_ENABLED)
 83	if err != nil {
 84		return err
 85	}
 86	return fn()
 87}
 88
 89func mapPrivileges(names []string) ([]uint64, error) {
 90	privileges := make([]uint64, 0, len(names))
 91	privNameMutex.Lock()
 92	defer privNameMutex.Unlock()
 93	for _, name := range names {
 94		p, ok := privNames[name]
 95		if !ok {
 96			err := lookupPrivilegeValue("", name, &p)
 97			if err != nil {
 98				return nil, err
 99			}
100			privNames[name] = p
101		}
102		privileges = append(privileges, p)
103	}
104	return privileges, nil
105}
106
107// EnableProcessPrivileges enables privileges globally for the process.
108func EnableProcessPrivileges(names []string) error {
109	return enableDisableProcessPrivilege(names, SE_PRIVILEGE_ENABLED)
110}
111
112// DisableProcessPrivileges disables privileges globally for the process.
113func DisableProcessPrivileges(names []string) error {
114	return enableDisableProcessPrivilege(names, 0)
115}
116
117func enableDisableProcessPrivilege(names []string, action uint32) error {
118	privileges, err := mapPrivileges(names)
119	if err != nil {
120		return err
121	}
122
123	p := windows.CurrentProcess()
124	var token windows.Token
125	err = windows.OpenProcessToken(p, windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token)
126	if err != nil {
127		return err
128	}
129
130	defer token.Close()
131	return adjustPrivileges(token, privileges, action)
132}
133
134func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) error {
135	var b bytes.Buffer
136	_ = binary.Write(&b, binary.LittleEndian, uint32(len(privileges)))
137	for _, p := range privileges {
138		_ = binary.Write(&b, binary.LittleEndian, p)
139		_ = binary.Write(&b, binary.LittleEndian, action)
140	}
141	prevState := make([]byte, b.Len())
142	reqSize := uint32(0)
143	success, err := adjustTokenPrivileges(token, false, &b.Bytes()[0], uint32(len(prevState)), &prevState[0], &reqSize)
144	if !success {
145		return err
146	}
147	if err == ERROR_NOT_ALL_ASSIGNED { //nolint:errorlint // err is Errno
148		return &PrivilegeError{privileges}
149	}
150	return nil
151}
152
153func getPrivilegeName(luid uint64) string {
154	var nameBuffer [256]uint16
155	bufSize := uint32(len(nameBuffer))
156	err := lookupPrivilegeName("", &luid, &nameBuffer[0], &bufSize)
157	if err != nil {
158		return fmt.Sprintf("<unknown privilege %d>", luid)
159	}
160
161	var displayNameBuffer [256]uint16
162	displayBufSize := uint32(len(displayNameBuffer))
163	var langID uint32
164	err = lookupPrivilegeDisplayName("", &nameBuffer[0], &displayNameBuffer[0], &displayBufSize, &langID)
165	if err != nil {
166		return fmt.Sprintf("<unknown privilege %s>", string(utf16.Decode(nameBuffer[:bufSize])))
167	}
168
169	return string(utf16.Decode(displayNameBuffer[:displayBufSize]))
170}
171
172func newThreadToken() (windows.Token, error) {
173	err := impersonateSelf(windows.SecurityImpersonation)
174	if err != nil {
175		return 0, err
176	}
177
178	var token windows.Token
179	err = openThreadToken(getCurrentThread(), windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, false, &token)
180	if err != nil {
181		rerr := revertToSelf()
182		if rerr != nil {
183			panic(rerr)
184		}
185		return 0, err
186	}
187	return token, nil
188}
189
190func releaseThreadToken(h windows.Token) {
191	err := revertToSelf()
192	if err != nil {
193		panic(err)
194	}
195	h.Close()
196}