//go:build windows package main import ( "fmt" "os" "strings" "syscall" "unsafe" "golang.org/x/sys/windows" ) // IsAdmin returns true when the current process token has elevation. // Wraps GetTokenInformation(TokenElevation). func IsAdmin() bool { var token windows.Token if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil { return false } defer token.Close() var elevation uint32 var sz uint32 err := windows.GetTokenInformation( token, windows.TokenElevation, (*byte)(unsafe.Pointer(&elevation)), uint32(unsafe.Sizeof(elevation)), &sz, ) if err != nil { return false } return elevation != 0 } // CmdNeedsAdmin reports whether the given CLI args land in a code path // that requires a WinDivert handle (and therefore admin). The default // (no args = GUI mode) needs admin; explicit subcommands like check, // version, update do not. func CmdNeedsAdmin(args []string) bool { if len(args) == 0 { return true // bare drover.exe → GUI/engine } switch args[0] { case "check", "version", "--version", "-v", "update", "--help", "-h", "help": return false default: return true } } // ReElevate re-launches the current executable with the given args via // ShellExecuteW("runas", ...). On success the caller should os.Exit(0) // immediately. Returns nil even when the user cancels UAC — the caller // can't distinguish; we just exit cleanly afterward. func ReElevate(args []string) error { exe, err := os.Executable() if err != nil { return err } verb, err := syscall.UTF16PtrFromString("runas") if err != nil { return fmt.Errorf("encode verb: %w", err) } exePtr, err := syscall.UTF16PtrFromString(exe) if err != nil { return fmt.Errorf("encode exe: %w", err) } var paramsPtr *uint16 if len(args) > 0 { // Quote each arg in case of spaces, and escape internal quotes. quoted := make([]string, len(args)) for i, a := range args { // Escape any internal quotes with backslash (MSVC argv convention). escaped := strings.ReplaceAll(a, "\"", "\\\"") quoted[i] = `"` + escaped + `"` } joined := "" for i, q := range quoted { if i > 0 { joined += " " } joined += q } paramsPtr, err = syscall.UTF16PtrFromString(joined) if err != nil { return fmt.Errorf("encode params: %w", err) } } cwd, err := os.Getwd() if err != nil { return fmt.Errorf("get cwd: %w", err) } cwdPtr, err := syscall.UTF16PtrFromString(cwd) if err != nil { return fmt.Errorf("encode cwd: %w", err) } // SW_NORMAL = 1 return windows.ShellExecute(0, verb, exePtr, paramsPtr, cwdPtr, 1) }