diff --git a/windows/syscall_windows.go b/windows/syscall_windows.go index fb6cfd046..47dc57967 100644 --- a/windows/syscall_windows.go +++ b/windows/syscall_windows.go @@ -155,6 +155,8 @@ func NewCallbackCDecl(fn interface{}) uintptr { //sys GetModuleFileName(module Handle, filename *uint16, size uint32) (n uint32, err error) = kernel32.GetModuleFileNameW //sys GetModuleHandleEx(flags uint32, moduleName *uint16, module *Handle) (err error) = kernel32.GetModuleHandleExW //sys SetDefaultDllDirectories(directoryFlags uint32) (err error) +//sys AddDllDirectory(path *uint16) (cookie uintptr, err error) = kernel32.AddDllDirectory +//sys RemoveDllDirectory(cookie uintptr) (err error) = kernel32.RemoveDllDirectory //sys SetDllDirectory(path string) (err error) = kernel32.SetDllDirectoryW //sys GetVersion() (ver uint32, err error) //sys FormatMessage(flags uint32, msgsrc uintptr, msgid uint32, langid uint32, buf []uint16, args *byte) (n uint32, err error) = FormatMessageW diff --git a/windows/syscall_windows_test.go b/windows/syscall_windows_test.go index dcc706ded..665837907 100644 --- a/windows/syscall_windows_test.go +++ b/windows/syscall_windows_test.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "os" + "os/exec" "path/filepath" "runtime" "strconv" @@ -1222,3 +1223,55 @@ func TestGetStartupInfo(t *testing.T) { t.Fatalf("GetStartupInfo: got error %v, want nil", err) } } + +func TestAddRemoveDllDirectory(t *testing.T) { + if _, err := exec.LookPath("gcc"); err != nil { + t.Skip("skipping test: gcc is missing") + } + dllSrc := `#include +#include + +uintptr_t beep(void) { + return 5; +}` + tmpdir := t.TempDir() + srcname := "beep.c" + err := os.WriteFile(filepath.Join(tmpdir, srcname), []byte(dllSrc), 0) + if err != nil { + t.Fatal(err) + } + name := "beep.dll" + cmd := exec.Command("gcc", "-shared", "-s", "-Werror", "-o", name, srcname) + cmd.Dir = tmpdir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("failed to build dll: %v - %v", err, string(out)) + } + + if _, err := windows.LoadLibraryEx("beep.dll", 0, windows.LOAD_LIBRARY_SEARCH_USER_DIRS); err == nil { + t.Fatal("LoadLibraryEx unexpectedly found beep.dll") + } + + dllCookie, err := windows.AddDllDirectory(windows.StringToUTF16Ptr(tmpdir)) + if err != nil { + t.Fatalf("AddDllDirectory failed: %s", err) + } + + handle, err := windows.LoadLibraryEx("beep.dll", 0, windows.LOAD_LIBRARY_SEARCH_USER_DIRS) + if err != nil { + t.Fatalf("LoadLibraryEx failed: %s", err) + } + + if err := windows.FreeLibrary(handle); err != nil { + t.Fatalf("FreeLibrary failed: %s", err) + } + + if err := windows.RemoveDllDirectory(dllCookie); err != nil { + t.Fatalf("RemoveDllDirectory failed: %s", err) + } + + _, err = windows.LoadLibraryEx("beep.dll", 0, windows.LOAD_LIBRARY_SEARCH_USER_DIRS) + if err == nil { + t.Fatal("LoadLibraryEx unexpectedly found beep.dll") + } +} diff --git a/windows/zsyscall_windows.go b/windows/zsyscall_windows.go index db6282e00..146a1f019 100644 --- a/windows/zsyscall_windows.go +++ b/windows/zsyscall_windows.go @@ -184,6 +184,7 @@ var ( procGetAdaptersInfo = modiphlpapi.NewProc("GetAdaptersInfo") procGetBestInterfaceEx = modiphlpapi.NewProc("GetBestInterfaceEx") procGetIfEntry = modiphlpapi.NewProc("GetIfEntry") + procAddDllDirectory = modkernel32.NewProc("AddDllDirectory") procAssignProcessToJobObject = modkernel32.NewProc("AssignProcessToJobObject") procCancelIo = modkernel32.NewProc("CancelIo") procCancelIoEx = modkernel32.NewProc("CancelIoEx") @@ -330,6 +331,7 @@ var ( procReadProcessMemory = modkernel32.NewProc("ReadProcessMemory") procReleaseMutex = modkernel32.NewProc("ReleaseMutex") procRemoveDirectoryW = modkernel32.NewProc("RemoveDirectoryW") + procRemoveDllDirectory = modkernel32.NewProc("RemoveDllDirectory") procResetEvent = modkernel32.NewProc("ResetEvent") procResizePseudoConsole = modkernel32.NewProc("ResizePseudoConsole") procResumeThread = modkernel32.NewProc("ResumeThread") @@ -1605,6 +1607,15 @@ func GetIfEntry(pIfRow *MibIfRow) (errcode error) { return } +func AddDllDirectory(path *uint16) (cookie uintptr, err error) { + r0, _, e1 := syscall.Syscall(procAddDllDirectory.Addr(), 1, uintptr(unsafe.Pointer(path)), 0, 0) + cookie = uintptr(r0) + if cookie == 0 { + err = errnoErr(e1) + } + return +} + func AssignProcessToJobObject(job Handle, process Handle) (err error) { r1, _, e1 := syscall.Syscall(procAssignProcessToJobObject.Addr(), 2, uintptr(job), uintptr(process), 0) if r1 == 0 { @@ -2879,6 +2890,14 @@ func RemoveDirectory(path *uint16) (err error) { return } +func RemoveDllDirectory(cookie uintptr) (err error) { + r1, _, e1 := syscall.Syscall(procRemoveDllDirectory.Addr(), 1, uintptr(cookie), 0, 0) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + func ResetEvent(event Handle) (err error) { r1, _, e1 := syscall.Syscall(procResetEvent.Addr(), 1, uintptr(event), 0, 0) if r1 == 0 {