From b96fc173762b7ae09fc8b50ac9d10a18016fdfd4 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Fri, 5 Mar 2021 15:06:08 -0700 Subject: [PATCH 1/6] mod: bump x/sys Signed-off-by: Jason A. Donenfeld --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 0aa27d850..6e691307c 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,6 @@ go 1.16 require ( golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 - golang.org/x/net v0.0.0-20210224082022-3d97a244fca7 - golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7 + golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 + golang.org/x/sys v0.0.0-20210305215415-5cdee2b1b5a0 ) diff --git a/go.sum b/go.sum index 1ccf774ca..733a8f977 100644 --- a/go.sum +++ b/go.sum @@ -2,13 +2,13 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g= golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20210224082022-3d97a244fca7 h1:OgUuv8lsRpBibGNbSizVwKWlysjaNzmC9gYMhPVfqFM= -golang.org/x/net v0.0.0-20210224082022-3d97a244fca7/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7 h1:pk3Y+QnSKjMLfO/HIqzn/Zvv3/IHjRPhwblrmUuodzw= -golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210305215415-5cdee2b1b5a0 h1:MOJR6AyRlIYMexU2acorBot1aPks0cBDOyUA4hFlBhE= +golang.org/x/sys v0.0.0-20210305215415-5cdee2b1b5a0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= From 1bc997de1762accee54c19e2b9410be4961bb0c0 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sat, 6 Mar 2021 09:20:46 -0700 Subject: [PATCH 2/6] conn: linux: unexport mutex Signed-off-by: Jason A. Donenfeld --- conn/bind_linux.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/conn/bind_linux.go b/conn/bind_linux.go index 419980951..70ea609d6 100644 --- a/conn/bind_linux.go +++ b/conn/bind_linux.go @@ -27,7 +27,7 @@ type ipv6Source struct { } type LinuxSocketEndpoint struct { - sync.Mutex + mu sync.Mutex dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte src [unsafe.Sizeof(ipv6Source{})]byte isV6 bool @@ -450,9 +450,9 @@ func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error { }, } - end.Lock() + end.mu.Lock() _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.Unlock() + end.mu.Unlock() if err == nil { return nil @@ -463,9 +463,9 @@ func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error { if err == unix.EINVAL { end.ClearSrc() cmsg.pktinfo = unix.Inet4Pktinfo{} - end.Lock() + end.mu.Lock() _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.Unlock() + end.mu.Unlock() } return err @@ -494,9 +494,9 @@ func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error { cmsg.pktinfo.Ifindex = 0 } - end.Lock() + end.mu.Lock() _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.Unlock() + end.mu.Unlock() if err == nil { return nil @@ -507,9 +507,9 @@ func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error { if err == unix.EINVAL { end.ClearSrc() cmsg.pktinfo = unix.Inet6Pktinfo{} - end.Lock() + end.mu.Lock() _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.Unlock() + end.mu.Unlock() } return err From 4b7c180b78457dbb479d71973b9ec7b307bc8097 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Wed, 3 Mar 2021 14:38:26 +0100 Subject: [PATCH 3/6] memmod: do not use IsBadReadPtr It should be enough to check for the trailing zero name. Signed-off-by: Jason A. Donenfeld --- tun/wintun/memmod/memmod_windows.go | 2 +- tun/wintun/memmod/mksyscall.go | 8 ----- tun/wintun/memmod/syscall_windows.go | 2 -- tun/wintun/memmod/zsyscall_windows.go | 50 --------------------------- 4 files changed, 1 insertion(+), 61 deletions(-) delete mode 100644 tun/wintun/memmod/mksyscall.go delete mode 100644 tun/wintun/memmod/zsyscall_windows.go diff --git a/tun/wintun/memmod/memmod_windows.go b/tun/wintun/memmod/memmod_windows.go index a9514c400..c75de5ade 100644 --- a/tun/wintun/memmod/memmod_windows.go +++ b/tun/wintun/memmod/memmod_windows.go @@ -312,7 +312,7 @@ func (module *Module) buildImportTable() error { module.modules = make([]windows.Handle, 0, 16) importDesc := (*IMAGE_IMPORT_DESCRIPTOR)(a2p(module.codeBase + uintptr(directory.VirtualAddress))) - for !isBadReadPtr(uintptr(unsafe.Pointer(importDesc)), unsafe.Sizeof(*importDesc)) && importDesc.Name != 0 { + for importDesc.Name != 0 { handle, err := windows.LoadLibraryEx(windows.BytePtrToString((*byte)(a2p(module.codeBase+uintptr(importDesc.Name)))), 0, windows.LOAD_LIBRARY_SEARCH_SYSTEM32) if err != nil { return fmt.Errorf("Error loading module: %w", err) diff --git a/tun/wintun/memmod/mksyscall.go b/tun/wintun/memmod/mksyscall.go deleted file mode 100644 index a78f613e9..000000000 --- a/tun/wintun/memmod/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package memmod - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go diff --git a/tun/wintun/memmod/syscall_windows.go b/tun/wintun/memmod/syscall_windows.go index 11715c03b..31dd0b5d9 100644 --- a/tun/wintun/memmod/syscall_windows.go +++ b/tun/wintun/memmod/syscall_windows.go @@ -324,8 +324,6 @@ const ( DLL_PROCESS_DETACH = 0 ) -//sys isBadReadPtr(addr uintptr, ucb uintptr) (ret bool) = kernel32.IsBadReadPtr - type SYSTEM_INFO struct { ProcessorArchitecture uint16 Reserved uint16 diff --git a/tun/wintun/memmod/zsyscall_windows.go b/tun/wintun/memmod/zsyscall_windows.go deleted file mode 100644 index 6a5b76f59..000000000 --- a/tun/wintun/memmod/zsyscall_windows.go +++ /dev/null @@ -1,50 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package memmod - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) - errERROR_EINVAL error = syscall.EINVAL -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return errERROR_EINVAL - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - - procIsBadReadPtr = modkernel32.NewProc("IsBadReadPtr") -) - -func isBadReadPtr(addr uintptr, ucb uintptr) (ret bool) { - r0, _, _ := syscall.Syscall(procIsBadReadPtr.Addr(), 2, uintptr(addr), uintptr(ucb), 0) - ret = r0 != 0 - return -} From c1b1fd4b60b91f23675b784d0d59a87b2c685577 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Wed, 3 Mar 2021 15:05:19 +0100 Subject: [PATCH 4/6] memmod: use resource functions from x/sys Signed-off-by: Jason A. Donenfeld --- tun/wintun/dll_fromrsrc_windows.go | 5 +- tun/wintun/resource/mksyscall.go | 8 -- tun/wintun/resource/resource_windows.go | 143 ------------------------ tun/wintun/resource/zsyscall_windows.go | 112 ------------------- 4 files changed, 2 insertions(+), 266 deletions(-) delete mode 100644 tun/wintun/resource/mksyscall.go delete mode 100644 tun/wintun/resource/resource_windows.go delete mode 100644 tun/wintun/resource/zsyscall_windows.go diff --git a/tun/wintun/dll_fromrsrc_windows.go b/tun/wintun/dll_fromrsrc_windows.go index d107ba98b..dc70486fd 100644 --- a/tun/wintun/dll_fromrsrc_windows.go +++ b/tun/wintun/dll_fromrsrc_windows.go @@ -16,7 +16,6 @@ import ( "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/tun/wintun/memmod" - "golang.zx2c4.com/wireguard/tun/wintun/resource" ) type lazyDLL struct { @@ -37,11 +36,11 @@ func (d *lazyDLL) Load() error { } const ourModule windows.Handle = 0 - resInfo, err := resource.FindByName(ourModule, d.Name, resource.RT_RCDATA) + resInfo, err := windows.FindResource(ourModule, d.Name, windows.RT_RCDATA) if err != nil { return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err) } - data, err := resource.Load(ourModule, resInfo) + data, err := windows.LoadResourceData(ourModule, resInfo) if err != nil { return fmt.Errorf("Unable to load resource: %w", err) } diff --git a/tun/wintun/resource/mksyscall.go b/tun/wintun/resource/mksyscall.go deleted file mode 100644 index 2013f24f2..000000000 --- a/tun/wintun/resource/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package resource - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go resource_windows.go diff --git a/tun/wintun/resource/resource_windows.go b/tun/wintun/resource/resource_windows.go deleted file mode 100644 index f63751830..000000000 --- a/tun/wintun/resource/resource_windows.go +++ /dev/null @@ -1,143 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package resource - -import ( - "errors" - "fmt" - "unsafe" - - "golang.org/x/sys/windows" -) - -func MAKEINTRESOURCE(i uint16) *uint16 { - return (*uint16)(unsafe.Pointer(uintptr(i))) -} - -// Predefined Resource Types -var ( - VS_VERSION_INFO uint16 = 1 - - RT_CURSOR = MAKEINTRESOURCE(1) - RT_BITMAP = MAKEINTRESOURCE(2) - RT_ICON = MAKEINTRESOURCE(3) - RT_MENU = MAKEINTRESOURCE(4) - RT_DIALOG = MAKEINTRESOURCE(5) - RT_STRING = MAKEINTRESOURCE(6) - RT_FONTDIR = MAKEINTRESOURCE(7) - RT_FONT = MAKEINTRESOURCE(8) - RT_ACCELERATOR = MAKEINTRESOURCE(9) - RT_RCDATA = MAKEINTRESOURCE(10) - RT_MESSAGETABLE = MAKEINTRESOURCE(11) - RT_GROUP_CURSOR = MAKEINTRESOURCE(12) - RT_GROUP_ICON = MAKEINTRESOURCE(14) - RT_VERSION = MAKEINTRESOURCE(16) - RT_DLGINCLUDE = MAKEINTRESOURCE(17) - RT_PLUGPLAY = MAKEINTRESOURCE(19) - RT_VXD = MAKEINTRESOURCE(20) - RT_ANICURSOR = MAKEINTRESOURCE(21) - RT_ANIICON = MAKEINTRESOURCE(22) - RT_HTML = MAKEINTRESOURCE(23) - RT_MANIFEST = MAKEINTRESOURCE(24) - CREATEPROCESS_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(1) - ISOLATIONAWARE_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(2) - ISOLATIONAWARE_NOSTATICIMPORT_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(3) - ISOLATIONPOLICY_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(4) - ISOLATIONPOLICY_BROWSER_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(5) - MINIMUM_RESERVED_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(1 /*inclusive*/) - MAXIMUM_RESERVED_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(16 /*inclusive*/) -) - -//sys findResource(module windows.Handle, name *uint16, resType *uint16) (resInfo windows.Handle, err error) = kernel32.FindResourceW - -func FindByID(module windows.Handle, id uint16, resType *uint16) (resInfo windows.Handle, err error) { - return findResource(module, MAKEINTRESOURCE(id), resType) -} - -func FindByName(module windows.Handle, name string, resType *uint16) (resInfo windows.Handle, err error) { - var name16 *uint16 - name16, err = windows.UTF16PtrFromString(name) - if err != nil { - return - } - resInfo, err = findResource(module, name16, resType) - return -} - -//sys sizeofResource(module windows.Handle, resInfo windows.Handle) (size uint32, err error) = kernel32.SizeofResource -//sys loadResource(module windows.Handle, resInfo windows.Handle) (resData windows.Handle, err error) = kernel32.LoadResource -//sys lockResource(resData windows.Handle) (addr uintptr, err error) = kernel32.LockResource - -func Load(module, resInfo windows.Handle) (data []byte, err error) { - size, err := sizeofResource(module, resInfo) - if err != nil { - err = fmt.Errorf("Unable to size resource: %w", err) - return - } - resData, err := loadResource(module, resInfo) - if err != nil { - err = fmt.Errorf("Unable to load resource: %w", err) - return - } - ptr, err := lockResource(resData) - if err != nil { - err = fmt.Errorf("Unable to lock resource: %w", err) - return - } - unsafeSlice(unsafe.Pointer(&data), unsafe.Pointer(ptr), int(size)) - return -} - -type VS_FIXEDFILEINFO struct { - Signature uint32 - StrucVersion uint32 - FileVersionMS uint32 - FileVersionLS uint32 - ProductVersionMS uint32 - ProductVersionLS uint32 - FileFlagsMask uint32 - FileFlags uint32 - FileOS uint32 - FileType uint32 - FileSubtype uint32 - FileDateMS uint32 - FileDateLS uint32 -} - -//sys verQueryValue(block *byte, section *uint16, value **byte, size *uint32) (err error) = version.VerQueryValueW - -func VerQueryRootValue(block []byte) (ffi *VS_FIXEDFILEINFO, err error) { - var data *byte - var size uint32 - err = verQueryValue(&block[0], windows.StringToUTF16Ptr("\\"), &data, &size) - if err != nil { - return - } - if uintptr(size) < unsafe.Sizeof(VS_FIXEDFILEINFO{}) { - err = errors.New("Incomplete VS_FIXEDFILEINFO") - return - } - ffi = (*VS_FIXEDFILEINFO)(unsafe.Pointer(data)) - return -} - -// unsafeSlice updates the slice slicePtr to be a slice -// referencing the provided data with its length & capacity set to -// lenCap. -// -// TODO: when Go 1.16 or Go 1.17 is the minimum supported version, -// update callers to use unsafe.Slice instead of this. -func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) { - type sliceHeader struct { - Data unsafe.Pointer - Len int - Cap int - } - h := (*sliceHeader)(slicePtr) - h.Data = data - h.Len = lenCap - h.Cap = lenCap -} diff --git a/tun/wintun/resource/zsyscall_windows.go b/tun/wintun/resource/zsyscall_windows.go deleted file mode 100644 index e4c4bf1c3..000000000 --- a/tun/wintun/resource/zsyscall_windows.go +++ /dev/null @@ -1,112 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package resource - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - modversion = windows.NewLazySystemDLL("version.dll") - - procFindResourceW = modkernel32.NewProc("FindResourceW") - procSizeofResource = modkernel32.NewProc("SizeofResource") - procLoadResource = modkernel32.NewProc("LoadResource") - procLockResource = modkernel32.NewProc("LockResource") - procVerQueryValueW = modversion.NewProc("VerQueryValueW") -) - -func findResource(module windows.Handle, name *uint16, resType *uint16) (resInfo windows.Handle, err error) { - r0, _, e1 := syscall.Syscall(procFindResourceW.Addr(), 3, uintptr(module), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(resType))) - resInfo = windows.Handle(r0) - if resInfo == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func sizeofResource(module windows.Handle, resInfo windows.Handle) (size uint32, err error) { - r0, _, e1 := syscall.Syscall(procSizeofResource.Addr(), 2, uintptr(module), uintptr(resInfo), 0) - size = uint32(r0) - if size == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func loadResource(module windows.Handle, resInfo windows.Handle) (resData windows.Handle, err error) { - r0, _, e1 := syscall.Syscall(procLoadResource.Addr(), 2, uintptr(module), uintptr(resInfo), 0) - resData = windows.Handle(r0) - if resData == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func lockResource(resData windows.Handle) (addr uintptr, err error) { - r0, _, e1 := syscall.Syscall(procLockResource.Addr(), 1, uintptr(resData), 0, 0) - addr = uintptr(r0) - if addr == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func verQueryValue(block *byte, section *uint16, value **byte, size *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procVerQueryValueW.Addr(), 4, uintptr(unsafe.Pointer(block)), uintptr(unsafe.Pointer(section)), uintptr(unsafe.Pointer(value)), uintptr(unsafe.Pointer(size)), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} From 7e3b8371a1bf277224df9e17d4075f2ee51b1be6 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Wed, 3 Mar 2021 12:26:59 +0100 Subject: [PATCH 5/6] winpipe: move syscalls into x/sys Signed-off-by: Jason A. Donenfeld --- ipc/uapi_windows.go | 4 +- ipc/winpipe/file.go | 121 ++---- ipc/winpipe/mksyscall.go | 9 - ipc/winpipe/pipe.go | 509 ------------------------- ipc/winpipe/winpipe.go | 474 +++++++++++++++++++++++ ipc/winpipe/winpipe_test.go | 656 ++++++++++++++++++++++++++++++++ ipc/winpipe/zsyscall_windows.go | 238 ------------ 7 files changed, 1174 insertions(+), 837 deletions(-) delete mode 100644 ipc/winpipe/mksyscall.go delete mode 100644 ipc/winpipe/pipe.go create mode 100644 ipc/winpipe/winpipe.go create mode 100644 ipc/winpipe/winpipe_test.go delete mode 100644 ipc/winpipe/zsyscall_windows.go diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go index 164b7cb1b..3e2709cee 100644 --- a/ipc/uapi_windows.go +++ b/ipc/uapi_windows.go @@ -62,10 +62,10 @@ func init() { } func UAPIListen(name string) (net.Listener, error) { - config := winpipe.PipeConfig{ + config := winpipe.ListenConfig{ SecurityDescriptor: UAPISecurityDescriptor, } - listener, err := winpipe.ListenPipe(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config) + listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config) if err != nil { return nil, err } diff --git a/ipc/winpipe/file.go b/ipc/winpipe/file.go index f3b768f63..0c9abb140 100644 --- a/ipc/winpipe/file.go +++ b/ipc/winpipe/file.go @@ -5,54 +5,21 @@ * Copyright (C) 2005 Microsoft * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. */ + package winpipe import ( - "errors" "io" + "os" "runtime" "sync" "sync/atomic" "time" + "unsafe" "golang.org/x/sys/windows" ) -//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx -//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort -//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus -//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes -//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult - -type atomicBool int32 - -func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } -func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } -func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } -func (b *atomicBool) swap(new bool) bool { - var newInt int32 - if new { - newInt = 1 - } - return atomic.SwapInt32((*int32)(b), newInt) == 1 -} - -const ( - cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1 - cFILE_SKIP_SET_EVENT_ON_HANDLE = 2 -) - -var ( - ErrFileClosed = errors.New("file has already been closed") - ErrTimeout = &timeoutError{} -) - -type timeoutError struct{} - -func (e *timeoutError) Error() string { return "i/o timeout" } -func (e *timeoutError) Timeout() bool { return true } -func (e *timeoutError) Temporary() bool { return true } - type timeoutChan chan struct{} var ioInitOnce sync.Once @@ -71,7 +38,7 @@ type ioOperation struct { } func initIo() { - h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff) + h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) if err != nil { panic(err) } @@ -79,13 +46,13 @@ func initIo() { go ioCompletionProcessor(h) } -// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. +// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. // It takes ownership of this handle and will close it if it is garbage collected. -type win32File struct { +type file struct { handle windows.Handle wg sync.WaitGroup wgLock sync.RWMutex - closing atomicBool + closing uint32 // used as atomic boolean socket bool readDeadline deadlineHandler writeDeadline deadlineHandler @@ -96,18 +63,18 @@ type deadlineHandler struct { channel timeoutChan channelLock sync.RWMutex timer *time.Timer - timedout atomicBool + timedout uint32 // used as atomic boolean } -// makeWin32File makes a new win32File from an existing file handle -func makeWin32File(h windows.Handle) (*win32File, error) { - f := &win32File{handle: h} +// makeFile makes a new file from an existing file handle +func makeFile(h windows.Handle) (*file, error) { + f := &file{handle: h} ioInitOnce.Do(initIo) - _, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff) + _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0) if err != nil { return nil, err } - err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE) + err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) if err != nil { return nil, err } @@ -116,18 +83,14 @@ func makeWin32File(h windows.Handle) (*win32File, error) { return f, nil } -func MakeOpenFile(h windows.Handle) (io.ReadWriteCloser, error) { - return makeWin32File(h) -} - // closeHandle closes the resources associated with a Win32 handle -func (f *win32File) closeHandle() { +func (f *file) closeHandle() { f.wgLock.Lock() // Atomically set that we are closing, releasing the resources only once. - if !f.closing.swap(true) { + if atomic.SwapUint32(&f.closing, 1) == 0 { f.wgLock.Unlock() // cancel all IO and wait for it to complete - cancelIoEx(f.handle, nil) + windows.CancelIoEx(f.handle, nil) f.wg.Wait() // at this point, no new IO can start windows.Close(f.handle) @@ -137,19 +100,19 @@ func (f *win32File) closeHandle() { } } -// Close closes a win32File. -func (f *win32File) Close() error { +// Close closes a file. +func (f *file) Close() error { f.closeHandle() return nil } // prepareIo prepares for a new IO operation. // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. -func (f *win32File) prepareIo() (*ioOperation, error) { +func (f *file) prepareIo() (*ioOperation, error) { f.wgLock.RLock() - if f.closing.isSet() { + if atomic.LoadUint32(&f.closing) == 1 { f.wgLock.RUnlock() - return nil, ErrFileClosed + return nil, os.ErrClosed } f.wg.Add(1) f.wgLock.RUnlock() @@ -164,7 +127,7 @@ func ioCompletionProcessor(h windows.Handle) { var bytes uint32 var key uintptr var op *ioOperation - err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE) + err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE) if op == nil { panic(err) } @@ -174,13 +137,13 @@ func ioCompletionProcessor(h windows.Handle) { // asyncIo processes the return value from ReadFile or WriteFile, blocking until // the operation has actually completed. -func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { +func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { if err != windows.ERROR_IO_PENDING { return int(bytes), err } - if f.closing.isSet() { - cancelIoEx(f.handle, &c.o) + if atomic.LoadUint32(&f.closing) == 1 { + windows.CancelIoEx(f.handle, &c.o) } var timeout timeoutChan @@ -195,20 +158,20 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er case r = <-c.ch: err = r.err if err == windows.ERROR_OPERATION_ABORTED { - if f.closing.isSet() { - err = ErrFileClosed + if atomic.LoadUint32(&f.closing) == 1 { + err = os.ErrClosed } } else if err != nil && f.socket { // err is from Win32. Query the overlapped structure to get the winsock error. var bytes, flags uint32 - err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) + err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) } case <-timeout: - cancelIoEx(f.handle, &c.o) + windows.CancelIoEx(f.handle, &c.o) r = <-c.ch err = r.err if err == windows.ERROR_OPERATION_ABORTED { - err = ErrTimeout + err = os.ErrDeadlineExceeded } } @@ -220,15 +183,15 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er } // Read reads from a file handle. -func (f *win32File) Read(b []byte) (int, error) { +func (f *file) Read(b []byte) (int, error) { c, err := f.prepareIo() if err != nil { return 0, err } defer f.wg.Done() - if f.readDeadline.timedout.isSet() { - return 0, ErrTimeout + if atomic.LoadUint32(&f.readDeadline.timedout) == 1 { + return 0, os.ErrDeadlineExceeded } var bytes uint32 @@ -247,15 +210,15 @@ func (f *win32File) Read(b []byte) (int, error) { } // Write writes to a file handle. -func (f *win32File) Write(b []byte) (int, error) { +func (f *file) Write(b []byte) (int, error) { c, err := f.prepareIo() if err != nil { return 0, err } defer f.wg.Done() - if f.writeDeadline.timedout.isSet() { - return 0, ErrTimeout + if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 { + return 0, os.ErrDeadlineExceeded } var bytes uint32 @@ -265,19 +228,19 @@ func (f *win32File) Write(b []byte) (int, error) { return n, err } -func (f *win32File) SetReadDeadline(deadline time.Time) error { +func (f *file) SetReadDeadline(deadline time.Time) error { return f.readDeadline.set(deadline) } -func (f *win32File) SetWriteDeadline(deadline time.Time) error { +func (f *file) SetWriteDeadline(deadline time.Time) error { return f.writeDeadline.set(deadline) } -func (f *win32File) Flush() error { +func (f *file) Flush() error { return windows.FlushFileBuffers(f.handle) } -func (f *win32File) Fd() uintptr { +func (f *file) Fd() uintptr { return uintptr(f.handle) } @@ -291,7 +254,7 @@ func (d *deadlineHandler) set(deadline time.Time) error { } d.timer = nil } - d.timedout.setFalse() + atomic.StoreUint32(&d.timedout, 0) select { case <-d.channel: @@ -306,7 +269,7 @@ func (d *deadlineHandler) set(deadline time.Time) error { } timeoutIO := func() { - d.timedout.setTrue() + atomic.StoreUint32(&d.timedout, 1) close(d.channel) } diff --git a/ipc/winpipe/mksyscall.go b/ipc/winpipe/mksyscall.go deleted file mode 100644 index a87e9298e..000000000 --- a/ipc/winpipe/mksyscall.go +++ /dev/null @@ -1,9 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package winpipe - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go pipe.go file.go diff --git a/ipc/winpipe/pipe.go b/ipc/winpipe/pipe.go deleted file mode 100644 index e609274f5..000000000 --- a/ipc/winpipe/pipe.go +++ /dev/null @@ -1,509 +0,0 @@ -// +build windows - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package winpipe - -import ( - "context" - "errors" - "fmt" - "io" - "net" - "os" - "runtime" - "time" - "unsafe" - - "golang.org/x/sys/windows" -) - -//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe -//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW -//sys createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW -//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo -//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW -//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc -//sys ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile -//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb -//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U -//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl - -type ioStatusBlock struct { - Status, Information uintptr -} - -type objectAttributes struct { - Length uintptr - RootDirectory uintptr - ObjectName *unicodeString - Attributes uintptr - SecurityDescriptor *windows.SECURITY_DESCRIPTOR - SecurityQoS uintptr -} - -type unicodeString struct { - Length uint16 - MaximumLength uint16 - Buffer uintptr -} - -type ntstatus int32 - -func (status ntstatus) Err() error { - if status >= 0 { - return nil - } - return rtlNtStatusToDosError(status) -} - -const ( - cSECURITY_SQOS_PRESENT = 0x100000 - cSECURITY_ANONYMOUS = 0 - - cPIPE_TYPE_MESSAGE = 4 - - cPIPE_READMODE_MESSAGE = 2 - - cFILE_OPEN = 1 - cFILE_CREATE = 2 - - cFILE_PIPE_MESSAGE_TYPE = 1 - cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2 -) - -var ( - // ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed. - // This error should match net.errClosing since docker takes a dependency on its text. - ErrPipeListenerClosed = errors.New("use of closed network connection") - - errPipeWriteClosed = errors.New("pipe has been closed for write") -) - -type win32Pipe struct { - *win32File - path string -} - -type win32MessageBytePipe struct { - win32Pipe - writeClosed bool - readEOF bool -} - -type pipeAddress string - -func (f *win32Pipe) LocalAddr() net.Addr { - return pipeAddress(f.path) -} - -func (f *win32Pipe) RemoteAddr() net.Addr { - return pipeAddress(f.path) -} - -func (f *win32Pipe) SetDeadline(t time.Time) error { - f.SetReadDeadline(t) - f.SetWriteDeadline(t) - return nil -} - -// CloseWrite closes the write side of a message pipe in byte mode. -func (f *win32MessageBytePipe) CloseWrite() error { - if f.writeClosed { - return errPipeWriteClosed - } - err := f.win32File.Flush() - if err != nil { - return err - } - _, err = f.win32File.Write(nil) - if err != nil { - return err - } - f.writeClosed = true - return nil -} - -// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since -// they are used to implement CloseWrite(). -func (f *win32MessageBytePipe) Write(b []byte) (int, error) { - if f.writeClosed { - return 0, errPipeWriteClosed - } - if len(b) == 0 { - return 0, nil - } - return f.win32File.Write(b) -} - -// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message -// mode pipe will return io.EOF, as will all subsequent reads. -func (f *win32MessageBytePipe) Read(b []byte) (int, error) { - if f.readEOF { - return 0, io.EOF - } - n, err := f.win32File.Read(b) - if err == io.EOF { - // If this was the result of a zero-byte read, then - // it is possible that the read was due to a zero-size - // message. Since we are simulating CloseWrite with a - // zero-byte message, ensure that all future Read() calls - // also return EOF. - f.readEOF = true - } else if err == windows.ERROR_MORE_DATA { - // ERROR_MORE_DATA indicates that the pipe's read mode is message mode - // and the message still has more bytes. Treat this as a success, since - // this package presents all named pipes as byte streams. - err = nil - } - return n, err -} - -func (s pipeAddress) Network() string { - return "pipe" -} - -func (s pipeAddress) String() string { - return string(s) -} - -// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout. -func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) { - for { - select { - case <-ctx.Done(): - return windows.Handle(0), ctx.Err() - default: - h, err := createFile(*path, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0) - if err == nil { - return h, nil - } - if err != windows.ERROR_PIPE_BUSY { - return h, &os.PathError{Err: err, Op: "open", Path: *path} - } - // Wait 10 msec and try again. This is a rather simplistic - // view, as we always try each 10 milliseconds. - time.Sleep(time.Millisecond * 10) - } - } -} - -// DialPipe connects to a named pipe by path, timing out if the connection -// takes longer than the specified duration. If timeout is nil, then we use -// a default timeout of 2 seconds. (We do not use WaitNamedPipe.) -func DialPipe(path string, timeout *time.Duration, expectedOwner *windows.SID) (net.Conn, error) { - var absTimeout time.Time - if timeout != nil { - absTimeout = time.Now().Add(*timeout) - } else { - absTimeout = time.Now().Add(time.Second * 2) - } - ctx, _ := context.WithDeadline(context.Background(), absTimeout) - conn, err := DialPipeContext(ctx, path, expectedOwner) - if err == context.DeadlineExceeded { - return nil, ErrTimeout - } - return conn, err -} - -// DialPipeContext attempts to connect to a named pipe by `path` until `ctx` -// cancellation or timeout. -func DialPipeContext(ctx context.Context, path string, expectedOwner *windows.SID) (net.Conn, error) { - var err error - var h windows.Handle - h, err = tryDialPipe(ctx, &path) - if err != nil { - return nil, err - } - - if expectedOwner != nil { - sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) - if err != nil { - windows.Close(h) - return nil, err - } - realOwner, _, err := sd.Owner() - if err != nil { - windows.Close(h) - return nil, err - } - if !realOwner.Equals(expectedOwner) { - windows.Close(h) - return nil, windows.ERROR_ACCESS_DENIED - } - } - - var flags uint32 - err = getNamedPipeInfo(h, &flags, nil, nil, nil) - if err != nil { - windows.Close(h) - return nil, err - } - - f, err := makeWin32File(h) - if err != nil { - windows.Close(h) - return nil, err - } - - // If the pipe is in message mode, return a message byte pipe, which - // supports CloseWrite(). - if flags&cPIPE_TYPE_MESSAGE != 0 { - return &win32MessageBytePipe{ - win32Pipe: win32Pipe{win32File: f, path: path}, - }, nil - } - return &win32Pipe{win32File: f, path: path}, nil -} - -type acceptResponse struct { - f *win32File - err error -} - -type win32PipeListener struct { - firstHandle windows.Handle - path string - config PipeConfig - acceptCh chan (chan acceptResponse) - closeCh chan int - doneCh chan int -} - -func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *PipeConfig, first bool) (windows.Handle, error) { - path16, err := windows.UTF16FromString(path) - if err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - - var oa objectAttributes - oa.Length = unsafe.Sizeof(oa) - - var ntPath unicodeString - if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - defer windows.LocalFree(windows.Handle(ntPath.Buffer)) - oa.ObjectName = &ntPath - - // The security descriptor is only needed for the first pipe. - if first { - if sd != nil { - oa.SecurityDescriptor = sd - } else { - // Construct the default named pipe security descriptor. - var dacl uintptr - if err := rtlDefaultNpAcl(&dacl).Err(); err != nil { - return 0, fmt.Errorf("getting default named pipe ACL: %s", err) - } - defer windows.LocalFree(windows.Handle(dacl)) - sd, err := windows.NewSecurityDescriptor() - if err != nil { - return 0, fmt.Errorf("creating new security descriptor: %s", err) - } - if err = sd.SetDACL((*windows.ACL)(unsafe.Pointer(dacl)), true, false); err != nil { - return 0, fmt.Errorf("assigning dacl: %s", err) - } - sd, err = sd.ToSelfRelative() - if err != nil { - return 0, fmt.Errorf("converting to self-relative: %s", err) - } - oa.SecurityDescriptor = sd - } - } - - typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS) - if c.MessageMode { - typ |= cFILE_PIPE_MESSAGE_TYPE - } - - disposition := uint32(cFILE_OPEN) - access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) - if first { - disposition = cFILE_CREATE - // By not asking for read or write access, the named pipe file system - // will put this pipe into an initially disconnected state, blocking - // client connections until the next call with first == false. - access = windows.SYNCHRONIZE - } - - timeout := int64(-50 * 10000) // 50ms - - var ( - h windows.Handle - iosb ioStatusBlock - ) - err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err() - if err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - - runtime.KeepAlive(ntPath) - return h, nil -} - -func (l *win32PipeListener) makeServerPipe() (*win32File, error) { - h, err := makeServerPipeHandle(l.path, nil, &l.config, false) - if err != nil { - return nil, err - } - f, err := makeWin32File(h) - if err != nil { - windows.Close(h) - return nil, err - } - return f, nil -} - -func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) { - p, err := l.makeServerPipe() - if err != nil { - return nil, err - } - - // Wait for the client to connect. - ch := make(chan error) - go func(p *win32File) { - ch <- connectPipe(p) - }(p) - - select { - case err = <-ch: - if err != nil { - p.Close() - p = nil - } - case <-l.closeCh: - // Abort the connect request by closing the handle. - p.Close() - p = nil - err = <-ch - if err == nil || err == ErrFileClosed { - err = ErrPipeListenerClosed - } - } - return p, err -} - -func (l *win32PipeListener) listenerRoutine() { - closed := false - for !closed { - select { - case <-l.closeCh: - closed = true - case responseCh := <-l.acceptCh: - var ( - p *win32File - err error - ) - for { - p, err = l.makeConnectedServerPipe() - // If the connection was immediately closed by the client, try - // again. - if err != windows.ERROR_NO_DATA { - break - } - } - responseCh <- acceptResponse{p, err} - closed = err == ErrPipeListenerClosed - } - } - windows.Close(l.firstHandle) - l.firstHandle = 0 - // Notify Close() and Accept() callers that the handle has been closed. - close(l.doneCh) -} - -// PipeConfig contain configuration for the pipe listener. -type PipeConfig struct { - // SecurityDescriptor contains a Windows security descriptor. - SecurityDescriptor *windows.SECURITY_DESCRIPTOR - - // MessageMode determines whether the pipe is in byte or message mode. In either - // case the pipe is read in byte mode by default. The only practical difference in - // this implementation is that CloseWrite() is only supported for message mode pipes; - // CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only - // transferred to the reader (and returned as io.EOF in this implementation) - // when the pipe is in message mode. - MessageMode bool - - // InputBufferSize specifies the size the input buffer, in bytes. - InputBufferSize int32 - - // OutputBufferSize specifies the size the input buffer, in bytes. - OutputBufferSize int32 -} - -// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe. -// The pipe must not already exist. -func ListenPipe(path string, c *PipeConfig) (net.Listener, error) { - if c == nil { - c = &PipeConfig{} - } - h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) - if err != nil { - return nil, err - } - l := &win32PipeListener{ - firstHandle: h, - path: path, - config: *c, - acceptCh: make(chan (chan acceptResponse)), - closeCh: make(chan int), - doneCh: make(chan int), - } - go l.listenerRoutine() - return l, nil -} - -func connectPipe(p *win32File) error { - c, err := p.prepareIo() - if err != nil { - return err - } - defer p.wg.Done() - - err = connectNamedPipe(p.handle, &c.o) - _, err = p.asyncIo(c, nil, 0, err) - if err != nil && err != windows.ERROR_PIPE_CONNECTED { - return err - } - return nil -} - -func (l *win32PipeListener) Accept() (net.Conn, error) { - ch := make(chan acceptResponse) - select { - case l.acceptCh <- ch: - response := <-ch - err := response.err - if err != nil { - return nil, err - } - if l.config.MessageMode { - return &win32MessageBytePipe{ - win32Pipe: win32Pipe{win32File: response.f, path: l.path}, - }, nil - } - return &win32Pipe{win32File: response.f, path: l.path}, nil - case <-l.doneCh: - return nil, ErrPipeListenerClosed - } -} - -func (l *win32PipeListener) Close() error { - select { - case l.closeCh <- 1: - <-l.doneCh - case <-l.doneCh: - } - return nil -} - -func (l *win32PipeListener) Addr() net.Addr { - return pipeAddress(l.path) -} diff --git a/ipc/winpipe/winpipe.go b/ipc/winpipe/winpipe.go new file mode 100644 index 000000000..f02f3d884 --- /dev/null +++ b/ipc/winpipe/winpipe.go @@ -0,0 +1,474 @@ +// +build windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2005 Microsoft + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +// Package winpipe implements a net.Conn and net.Listener around Windows named pipes. +package winpipe + +import ( + "context" + "io" + "net" + "os" + "runtime" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +type pipe struct { + *file + path string +} + +type messageBytePipe struct { + pipe + writeClosed bool + readEOF bool +} + +type pipeAddress string + +func (f *pipe) LocalAddr() net.Addr { + return pipeAddress(f.path) +} + +func (f *pipe) RemoteAddr() net.Addr { + return pipeAddress(f.path) +} + +func (f *pipe) SetDeadline(t time.Time) error { + f.SetReadDeadline(t) + f.SetWriteDeadline(t) + return nil +} + +// CloseWrite closes the write side of a message pipe in byte mode. +func (f *messageBytePipe) CloseWrite() error { + if f.writeClosed { + return io.ErrClosedPipe + } + err := f.file.Flush() + if err != nil { + return err + } + _, err = f.file.Write(nil) + if err != nil { + return err + } + f.writeClosed = true + return nil +} + +// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since +// they are used to implement CloseWrite. +func (f *messageBytePipe) Write(b []byte) (int, error) { + if f.writeClosed { + return 0, io.ErrClosedPipe + } + if len(b) == 0 { + return 0, nil + } + return f.file.Write(b) +} + +// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message +// mode pipe will return io.EOF, as will all subsequent reads. +func (f *messageBytePipe) Read(b []byte) (int, error) { + if f.readEOF { + return 0, io.EOF + } + n, err := f.file.Read(b) + if err == io.EOF { + // If this was the result of a zero-byte read, then + // it is possible that the read was due to a zero-size + // message. Since we are simulating CloseWrite with a + // zero-byte message, ensure that all future Read calls + // also return EOF. + f.readEOF = true + } else if err == windows.ERROR_MORE_DATA { + // ERROR_MORE_DATA indicates that the pipe's read mode is message mode + // and the message still has more bytes. Treat this as a success, since + // this package presents all named pipes as byte streams. + err = nil + } + return n, err +} + +func (f *pipe) Handle() windows.Handle { + return f.handle +} + +func (s pipeAddress) Network() string { + return "pipe" +} + +func (s pipeAddress) String() string { + return string(s) +} + +// tryDialPipe attempts to dial the specified pipe until cancellation or timeout. +func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) { + for { + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + path16, err := windows.UTF16PtrFromString(*path) + if err != nil { + return 0, err + } + h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) + if err == nil { + return h, nil + } + if err != windows.ERROR_PIPE_BUSY { + return h, &os.PathError{Err: err, Op: "open", Path: *path} + } + // Wait 10 msec and try again. This is a rather simplistic + // view, as we always try each 10 milliseconds. + time.Sleep(10 * time.Millisecond) + } + } +} + +// DialConfig exposes various options for use in Dial and DialContext. +type DialConfig struct { + ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. +} + +// Dial connects to the specified named pipe by path, timing out if the connection +// takes longer than the specified duration. If timeout is nil, then we use +// a default timeout of 2 seconds. +func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) { + var absTimeout time.Time + if timeout != nil { + absTimeout = time.Now().Add(*timeout) + } else { + absTimeout = time.Now().Add(2 * time.Second) + } + ctx, _ := context.WithDeadline(context.Background(), absTimeout) + conn, err := DialContext(ctx, path, config) + if err == context.DeadlineExceeded { + return nil, os.ErrDeadlineExceeded + } + return conn, err +} + +// DialContext attempts to connect to the specified named pipe by path +// cancellation or timeout. +func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) { + if config == nil { + config = &DialConfig{} + } + var err error + var h windows.Handle + h, err = tryDialPipe(ctx, &path) + if err != nil { + return nil, err + } + + if config.ExpectedOwner != nil { + sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) + if err != nil { + windows.Close(h) + return nil, err + } + realOwner, _, err := sd.Owner() + if err != nil { + windows.Close(h) + return nil, err + } + if !realOwner.Equals(config.ExpectedOwner) { + windows.Close(h) + return nil, windows.ERROR_ACCESS_DENIED + } + } + + var flags uint32 + err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil) + if err != nil { + windows.Close(h) + return nil, err + } + + f, err := makeFile(h) + if err != nil { + windows.Close(h) + return nil, err + } + + // If the pipe is in message mode, return a message byte pipe, which + // supports CloseWrite. + if flags&windows.PIPE_TYPE_MESSAGE != 0 { + return &messageBytePipe{ + pipe: pipe{file: f, path: path}, + }, nil + } + return &pipe{file: f, path: path}, nil +} + +type acceptResponse struct { + f *file + err error +} + +type pipeListener struct { + firstHandle windows.Handle + path string + config ListenConfig + acceptCh chan (chan acceptResponse) + closeCh chan int + doneCh chan int +} + +func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) { + path16, err := windows.UTF16PtrFromString(path) + if err != nil { + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + + var oa windows.OBJECT_ATTRIBUTES + oa.Length = uint32(unsafe.Sizeof(oa)) + + var ntPath windows.NTUnicodeString + if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + err = ntstatus.Errno() + } + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer))) + oa.ObjectName = &ntPath + + // The security descriptor is only needed for the first pipe. + if first { + if sd != nil { + oa.SecurityDescriptor = sd + } else { + // Construct the default named pipe security descriptor. + var acl *windows.ACL + if err := windows.RtlDefaultNpAcl(&acl); err != nil { + return 0, err + } + defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) + sd, err := windows.NewSecurityDescriptor() + if err != nil { + return 0, err + } + if err = sd.SetDACL(acl, true, false); err != nil { + return 0, err + } + oa.SecurityDescriptor = sd + } + } + + typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS) + if c.MessageMode { + typ |= windows.FILE_PIPE_MESSAGE_TYPE + } + + disposition := uint32(windows.FILE_OPEN) + access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) + if first { + disposition = windows.FILE_CREATE + // By not asking for read or write access, the named pipe file system + // will put this pipe into an initially disconnected state, blocking + // client connections until the next call with first == false. + access = windows.SYNCHRONIZE + } + + timeout := int64(-50 * 10000) // 50ms + + var ( + h windows.Handle + iosb windows.IO_STATUS_BLOCK + ) + err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout) + if err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + err = ntstatus.Errno() + } + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + + runtime.KeepAlive(ntPath) + return h, nil +} + +func (l *pipeListener) makeServerPipe() (*file, error) { + h, err := makeServerPipeHandle(l.path, nil, &l.config, false) + if err != nil { + return nil, err + } + f, err := makeFile(h) + if err != nil { + windows.Close(h) + return nil, err + } + return f, nil +} + +func (l *pipeListener) makeConnectedServerPipe() (*file, error) { + p, err := l.makeServerPipe() + if err != nil { + return nil, err + } + + // Wait for the client to connect. + ch := make(chan error) + go func(p *file) { + ch <- connectPipe(p) + }(p) + + select { + case err = <-ch: + if err != nil { + p.Close() + p = nil + } + case <-l.closeCh: + // Abort the connect request by closing the handle. + p.Close() + p = nil + err = <-ch + if err == nil || err == os.ErrClosed { + err = net.ErrClosed + } + } + return p, err +} + +func (l *pipeListener) listenerRoutine() { + closed := false + for !closed { + select { + case <-l.closeCh: + closed = true + case responseCh := <-l.acceptCh: + var ( + p *file + err error + ) + for { + p, err = l.makeConnectedServerPipe() + // If the connection was immediately closed by the client, try + // again. + if err != windows.ERROR_NO_DATA { + break + } + } + responseCh <- acceptResponse{p, err} + closed = err == net.ErrClosed + } + } + windows.Close(l.firstHandle) + l.firstHandle = 0 + // Notify Close and Accept callers that the handle has been closed. + close(l.doneCh) +} + +// ListenConfig contains configuration for the pipe listener. +type ListenConfig struct { + // SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used. + SecurityDescriptor *windows.SECURITY_DESCRIPTOR + + // MessageMode determines whether the pipe is in byte or message mode. In either + // case the pipe is read in byte mode by default. The only practical difference in + // this implementation is that CloseWrite is only supported for message mode pipes; + // CloseWrite is implemented as a zero-byte write, but zero-byte writes are only + // transferred to the reader (and returned as io.EOF in this implementation) + // when the pipe is in message mode. + MessageMode bool + + // InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed. + InputBufferSize int32 + + // OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed. + OutputBufferSize int32 +} + +// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe. +// The pipe must not already exist. +func Listen(path string, c *ListenConfig) (net.Listener, error) { + if c == nil { + c = &ListenConfig{} + } + h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) + if err != nil { + return nil, err + } + l := &pipeListener{ + firstHandle: h, + path: path, + config: *c, + acceptCh: make(chan (chan acceptResponse)), + closeCh: make(chan int), + doneCh: make(chan int), + } + // The first connection is swallowed on Windows 7 & 8, so synthesize it. + if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { + path16, err := windows.UTF16PtrFromString(path) + if err == nil { + h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) + if err == nil { + windows.CloseHandle(h) + } + } + } + go l.listenerRoutine() + return l, nil +} + +func connectPipe(p *file) error { + c, err := p.prepareIo() + if err != nil { + return err + } + defer p.wg.Done() + + err = windows.ConnectNamedPipe(p.handle, &c.o) + _, err = p.asyncIo(c, nil, 0, err) + if err != nil && err != windows.ERROR_PIPE_CONNECTED { + return err + } + return nil +} + +func (l *pipeListener) Accept() (net.Conn, error) { + ch := make(chan acceptResponse) + select { + case l.acceptCh <- ch: + response := <-ch + err := response.err + if err != nil { + return nil, err + } + if l.config.MessageMode { + return &messageBytePipe{ + pipe: pipe{file: response.f, path: l.path}, + }, nil + } + return &pipe{file: response.f, path: l.path}, nil + case <-l.doneCh: + return nil, net.ErrClosed + } +} + +func (l *pipeListener) Close() error { + select { + case l.closeCh <- 1: + <-l.doneCh + case <-l.doneCh: + } + return nil +} + +func (l *pipeListener) Addr() net.Addr { + return pipeAddress(l.path) +} diff --git a/ipc/winpipe/winpipe_test.go b/ipc/winpipe/winpipe_test.go new file mode 100644 index 000000000..34ebeb18e --- /dev/null +++ b/ipc/winpipe/winpipe_test.go @@ -0,0 +1,656 @@ +// +build windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2005 Microsoft + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package winpipe_test + +import ( + "bufio" + "bytes" + "context" + "io" + "net" + "os" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/ipc/winpipe" +) + +func randomPipePath() string { + guid, err := windows.GenerateGUID() + if err != nil { + panic(err) + } + return `\\.\pipe\go-winpipe-test-` + guid.String() +} + +func TestPingPong(t *testing.T) { + const ( + ping = 42 + pong = 24 + ) + pipePath := randomPipePath() + listener, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatalf("unable to listen on pipe: %v", err) + } + defer listener.Close() + go func() { + incoming, err := listener.Accept() + if err != nil { + t.Fatalf("unable to accept pipe connection: %v", err) + } + defer incoming.Close() + var data [1]byte + _, err = incoming.Read(data[:]) + if err != nil { + t.Fatalf("unable to read ping from pipe: %v", err) + } + if data[0] != ping { + t.Fatalf("expected ping, got %d", data[0]) + } + data[0] = pong + _, err = incoming.Write(data[:]) + if err != nil { + t.Fatalf("unable to write pong to pipe: %v", err) + } + }() + client, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatalf("unable to dial pipe: %v", err) + } + defer client.Close() + var data [1]byte + data[0] = ping + _, err = client.Write(data[:]) + if err != nil { + t.Fatalf("unable to write ping to pipe: %v", err) + } + _, err = client.Read(data[:]) + if err != nil { + t.Fatalf("unable to read pong from pipe: %v", err) + } + if data[0] != pong { + t.Fatalf("expected pong, got %d", data[0]) + } +} + +func TestDialUnknownFailsImmediately(t *testing.T) { + _, err := winpipe.Dial(randomPipePath(), nil, nil) + if err.(*os.PathError).Err != syscall.ENOENT { + t.Fatalf("expected ENOENT got %v", err) + } +} + +func TestDialListenerTimesOut(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + d := 10 * time.Millisecond + _, err = winpipe.Dial(pipePath, &d, nil) + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func TestDialContextListenerTimesOut(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + d := 10 * time.Millisecond + ctx, _ := context.WithTimeout(context.Background(), d) + _, err = winpipe.DialContext(ctx, pipePath, nil) + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + +func TestDialListenerGetsCancelled(t *testing.T) { + pipePath := randomPipePath() + ctx, cancel := context.WithCancel(context.Background()) + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + ch := make(chan error) + defer l.Close() + go func(ctx context.Context, ch chan error) { + _, err := winpipe.DialContext(ctx, pipePath, nil) + ch <- err + }(ctx, ch) + time.Sleep(time.Millisecond * 30) + cancel() + err = <-ch + if err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { + pipePath := randomPipePath() + sd, _ := windows.SecurityDescriptorFromString("D:P(A;;0x1200FF;;;WD)") + c := winpipe.ListenConfig{ + SecurityDescriptor: sd, + } + l, err := winpipe.Listen(pipePath, &c) + if err != nil { + t.Fatal(err) + } + defer l.Close() + _, err = winpipe.Dial(pipePath, nil, nil) + if err.(*os.PathError).Err != syscall.ERROR_ACCESS_DENIED { + t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) + } +} + +func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, cfg) + if err != nil { + return + } + defer l.Close() + + type response struct { + c net.Conn + err error + } + ch := make(chan response) + go func() { + c, err := l.Accept() + ch <- response{c, err} + }() + + c, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + return + } + + r := <-ch + if err = r.err; err != nil { + c.Close() + return + } + + client = c + server = r.c + return +} + +func TestReadTimeout(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + + c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + + buf := make([]byte, 10) + _, err = c.Read(buf) + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func server(l net.Listener, ch chan int) { + c, err := l.Accept() + if err != nil { + panic(err) + } + rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) + s, err := rw.ReadString('\n') + if err != nil { + panic(err) + } + _, err = rw.WriteString("got " + s) + if err != nil { + panic(err) + } + err = rw.Flush() + if err != nil { + panic(err) + } + c.Close() + ch <- 1 +} + +func TestFullListenDialReadWrite(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + ch := make(chan int) + go server(l, ch) + + c, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) + _, err = rw.WriteString("hello world\n") + if err != nil { + t.Fatal(err) + } + err = rw.Flush() + if err != nil { + t.Fatal(err) + } + + s, err := rw.ReadString('\n') + if err != nil { + t.Fatal(err) + } + ms := "got hello world\n" + if s != ms { + t.Errorf("expected '%s', got '%s'", ms, s) + } + + <-ch +} + +func TestCloseAbortsListen(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + + ch := make(chan error) + go func() { + _, err := l.Accept() + ch <- err + }() + + time.Sleep(30 * time.Millisecond) + l.Close() + + err = <-ch + if err != net.ErrClosed { + t.Fatalf("expected net.ErrClosed, got %v", err) + } +} + +func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) { + b := make([]byte, 10) + w.Close() + n, err := r.Read(b) + if n > 0 { + t.Errorf("unexpected byte count %d", n) + } + if err != io.EOF { + t.Errorf("expected EOF: %v", err) + } +} + +func TestCloseClientEOFServer(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + ensureEOFOnClose(t, c, s) +} + +func TestCloseServerEOFClient(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + ensureEOFOnClose(t, s, c) +} + +func TestCloseWriteEOF(t *testing.T) { + cfg := &winpipe.ListenConfig{ + MessageMode: true, + } + c, s, err := getConnection(cfg) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + + type closeWriter interface { + CloseWrite() error + } + + err = c.(closeWriter).CloseWrite() + if err != nil { + t.Fatal(err) + } + + b := make([]byte, 10) + _, err = s.Read(b) + if err != io.EOF { + t.Fatal(err) + } +} + +func TestAcceptAfterCloseFails(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + l.Close() + _, err = l.Accept() + if err != net.ErrClosed { + t.Fatalf("expected net.ErrClosed, got %v", err) + } +} + +func TestDialTimesOutByDefault(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + _, err = winpipe.Dial(pipePath, nil, nil) + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func TestTimeoutPendingRead(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + serverDone := make(chan struct{}) + + go func() { + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + s.Close() + close(serverDone) + }() + + client, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + clientErr := make(chan error) + go func() { + buf := make([]byte, 10) + _, err = client.Read(buf) + clientErr <- err + }() + + time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline + client.SetReadDeadline(time.Unix(1, 0)) + + select { + case err = <-clientErr: + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timed out while waiting for read to cancel") + <-clientErr + } + <-serverDone +} + +func TestTimeoutPendingWrite(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + serverDone := make(chan struct{}) + + go func() { + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + s.Close() + close(serverDone) + }() + + client, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + clientErr := make(chan error) + go func() { + _, err = client.Write([]byte("this should timeout")) + clientErr <- err + }() + + time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline + client.SetWriteDeadline(time.Unix(1, 0)) + + select { + case err = <-clientErr: + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timed out while waiting for write to cancel") + <-clientErr + } + <-serverDone +} + +type CloseWriter interface { + CloseWrite() error +} + +func TestEchoWithMessaging(t *testing.T) { + c := winpipe.ListenConfig{ + MessageMode: true, // Use message mode so that CloseWrite() is supported + InputBufferSize: 65536, // Use 64KB buffers to improve performance + OutputBufferSize: 65536, + } + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, &c) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + listenerDone := make(chan bool) + clientDone := make(chan bool) + go func() { + // server echo + conn, e := l.Accept() + if e != nil { + t.Fatal(e) + } + defer conn.Close() + + time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent + io.Copy(conn, conn) + conn.(CloseWriter).CloseWrite() + close(listenerDone) + }() + timeout := 1 * time.Second + client, err := winpipe.Dial(pipePath, &timeout, nil) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + go func() { + // client read back + bytes := make([]byte, 2) + n, e := client.Read(bytes) + if e != nil { + t.Fatal(e) + } + if n != 2 { + t.Fatalf("expected 2 bytes, got %v", n) + } + close(clientDone) + }() + + payload := make([]byte, 2) + payload[0] = 0 + payload[1] = 1 + + n, err := client.Write(payload) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("expected 2 bytes, got %v", n) + } + client.(CloseWriter).CloseWrite() + <-listenerDone + <-clientDone +} + +func TestConnectRace(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + go func() { + for { + s, err := l.Accept() + if err == net.ErrClosed { + return + } + + if err != nil { + t.Fatal(err) + } + s.Close() + } + }() + + for i := 0; i < 1000; i++ { + c, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + c.Close() + } +} + +func TestMessageReadMode(t *testing.T) { + if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { + t.Skipf("Skipping on Windows %d", maj) + } + var wg sync.WaitGroup + defer wg.Wait() + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true}) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + msg := ([]byte)("hello world") + + wg.Add(1) + go func() { + defer wg.Done() + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + _, err = s.Write(msg) + if err != nil { + t.Fatal(err) + } + s.Close() + }() + + c, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + mode := uint32(windows.PIPE_READMODE_MESSAGE) + err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil) + if err != nil { + t.Fatal(err) + } + + ch := make([]byte, 1) + var vmsg []byte + for { + n, err := c.Read(ch) + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatalf("expected 1, got %d", n) + } + vmsg = append(vmsg, ch[0]) + } + if !bytes.Equal(msg, vmsg) { + t.Fatalf("expected %s, got %s", msg, vmsg) + } +} + +func TestListenConnectRace(t *testing.T) { + if testing.Short() { + t.Skip("Skipping long race test") + } + pipePath := randomPipePath() + for i := 0; i < 50 && !t.Failed(); i++ { + var wg sync.WaitGroup + wg.Add(1) + go func() { + c, err := winpipe.Dial(pipePath, nil, nil) + if err == nil { + c.Close() + } + wg.Done() + }() + s, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Error(i, err) + } else { + s.Close() + } + wg.Wait() + } +} diff --git a/ipc/winpipe/zsyscall_windows.go b/ipc/winpipe/zsyscall_windows.go deleted file mode 100644 index 995432975..000000000 --- a/ipc/winpipe/zsyscall_windows.go +++ /dev/null @@ -1,238 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package winpipe - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - modntdll = windows.NewLazySystemDLL("ntdll.dll") - modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") - - procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe") - procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW") - procCreateFileW = modkernel32.NewProc("CreateFileW") - procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo") - procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW") - procLocalAlloc = modkernel32.NewProc("LocalAlloc") - procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile") - procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb") - procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U") - procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl") - procCancelIoEx = modkernel32.NewProc("CancelIoEx") - procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort") - procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus") - procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes") - procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult") -) - -func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) { - r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) { - var _p0 *uint16 - _p0, err = syscall.UTF16PtrFromString(name) - if err != nil { - return - } - return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa) -} - -func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0) - handle = windows.Handle(r0) - if handle == windows.InvalidHandle { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) { - var _p0 *uint16 - _p0, err = syscall.UTF16PtrFromString(name) - if err != nil { - return - } - return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile) -} - -func _createFile(name *uint16, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0) - handle = windows.Handle(r0) - if handle == windows.InvalidHandle { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) { - r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func localAlloc(uFlags uint32, length uint32) (ptr uintptr) { - r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0) - ptr = uintptr(r0) - return -} - -func ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) { - r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0) - status = ntstatus(r0) - return -} - -func rtlNtStatusToDosError(status ntstatus) (winerr error) { - r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0) - if r0 != 0 { - winerr = syscall.Errno(r0) - } - return -} - -func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) { - r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0) - status = ntstatus(r0) - return -} - -func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) { - r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0) - status = ntstatus(r0) - return -} - -func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) { - r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) { - r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0) - newport = windows.Handle(r0) - if newport == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) { - r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) { - var _p0 uint32 - if wait { - _p0 = 1 - } else { - _p0 = 0 - } - r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} From 4b36949f72661e56f5672318b4158c51904a5b6c Mon Sep 17 00:00:00 2001 From: Ben Burkert Date: Mon, 1 Mar 2021 12:45:20 -0500 Subject: [PATCH 6/6] tun: replace the DNS client in netstack with net.Resolver Use the net.Resolver DNS client to send domain lookups to the server specified as a CreateNetTUN parameter. The net package's DNS client handles DNS request and response parsing when PreferGo is true. Like the previous DNS client it replaces, the net.Resolver instances also sends DNS queries over the WireGuard connection. Tested on and with support from Fly.io. Signed-off-by: Ben Burkert --- tun/netstack/tun.go | 524 +++++--------------------------------------- 1 file changed, 58 insertions(+), 466 deletions(-) diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 4846e2f06..87a76072a 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -7,11 +7,8 @@ package netstack import ( "context" - "crypto/rand" - "encoding/binary" "errors" "fmt" - "io" "net" "os" "strconv" @@ -20,7 +17,6 @@ import ( "golang.zx2c4.com/wireguard/tun" - "golang.org/x/net/dns/dnsmessage" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -39,6 +35,7 @@ type netTun struct { incomingPacket chan buffer.VectorisedView mtu int dnsServers []net.IP + resolver *net.Resolver hasV4, hasV6 bool } type endpoint netTun @@ -129,6 +126,11 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) } + dev.resolver = &net.Resolver{ + PreferGo: true, + Dial: (*Net)(dev).dialDNS, + } + dev.events <- tun.EventUp return dev, (*Net)(dev), nil } @@ -247,458 +249,14 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { } var ( - errNoSuchHost = errors.New("no such host") - errLameReferral = errors.New("lame referral") - errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") - errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") - errServerMisbehaving = errors.New("server misbehaving") - errInvalidDNSResponse = errors.New("invalid DNS response") - errNoAnswerFromDNSServer = errors.New("no answer from DNS server") - errServerTemporarilyMisbehaving = errors.New("server misbehaving") - errCanceled = errors.New("operation was canceled") - errTimeout = errors.New("i/o timeout") - errNumericPort = errors.New("port must be numeric") - errNoSuitableAddress = errors.New("no suitable address found") - errMissingAddress = errors.New("missing address") + errCanceled = errors.New("operation was canceled") + errTimeout = errors.New("i/o timeout") + errNumericPort = errors.New("port must be numeric") + errNoSuitableAddress = errors.New("no suitable address found") + errMissingAddress = errors.New("missing address") ) -func (net *Net) LookupHost(host string) (addrs []string, err error) { - return net.LookupContextHost(context.Background(), host) -} - -func isDomainName(s string) bool { - l := len(s) - if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { - return false - } - last := byte('.') - nonNumeric := false - partlen := 0 - for i := 0; i < len(s); i++ { - c := s[i] - switch { - default: - return false - case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': - nonNumeric = true - partlen++ - case '0' <= c && c <= '9': - partlen++ - case c == '-': - if last == '.' { - return false - } - partlen++ - nonNumeric = true - case c == '.': - if last == '.' || last == '-' { - return false - } - if partlen > 63 || partlen == 0 { - return false - } - partlen = 0 - } - last = c - } - if last == '-' || partlen > 63 { - return false - } - return nonNumeric -} - -func randU16() uint16 { - var b [2]byte - _, err := rand.Read(b[:]) - if err != nil { - panic(err) - } - return binary.LittleEndian.Uint16(b[:]) -} - -func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { - id = randU16() - b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) - b.EnableCompression() - if err := b.StartQuestions(); err != nil { - return 0, nil, nil, err - } - if err := b.Question(q); err != nil { - return 0, nil, nil, err - } - tcpReq, err = b.Finish() - udpReq = tcpReq[2:] - l := len(tcpReq) - 2 - tcpReq[0] = byte(l >> 8) - tcpReq[1] = byte(l) - return id, udpReq, tcpReq, err -} - -func equalASCIIName(x, y dnsmessage.Name) bool { - if x.Length != y.Length { - return false - } - for i := 0; i < int(x.Length); i++ { - a := x.Data[i] - b := y.Data[i] - if 'A' <= a && a <= 'Z' { - a += 0x20 - } - if 'A' <= b && b <= 'Z' { - b += 0x20 - } - if a != b { - return false - } - } - return true -} - -func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { - if !respHdr.Response { - return false - } - if reqID != respHdr.ID { - return false - } - if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { - return false - } - return true -} - -func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { - if _, err := c.Write(b); err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - b = make([]byte, 512) - for { - n, err := c.Read(b) - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - var p dnsmessage.Parser - h, err := p.Start(b[:n]) - if err != nil { - continue - } - q, err := p.Question() - if err != nil || !checkResponse(id, query, h, q) { - continue - } - return p, h, nil - } -} - -func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { - if _, err := c.Write(b); err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - b = make([]byte, 1280) - if _, err := io.ReadFull(c, b[:2]); err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - l := int(b[0])<<8 | int(b[1]) - if l > len(b) { - b = make([]byte, l) - } - n, err := io.ReadFull(c, b[:l]) - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - var p dnsmessage.Parser - h, err := p.Start(b[:n]) - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage - } - q, err := p.Question() - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage - } - if !checkResponse(id, query, h, q) { - return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse - } - return p, h, nil -} - -func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { - q.Class = dnsmessage.ClassINET - id, udpReq, tcpReq, err := newRequest(q) - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage - } - - for _, useUDP := range []bool{true, false} { - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) - defer cancel() - - var c net.Conn - var err error - if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53}) - } else { - c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53}) - } - - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - if d, ok := ctx.Deadline(); ok && !d.IsZero() { - c.SetDeadline(d) - } - var p dnsmessage.Parser - var h dnsmessage.Header - if useUDP { - p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) - } else { - p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) - } - c.Close() - if err != nil { - if err == context.Canceled { - err = errCanceled - } else if err == context.DeadlineExceeded { - err = errTimeout - } - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { - return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse - } - if h.Truncated { - continue - } - return p, h, nil - } - return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer -} - -func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { - if h.RCode == dnsmessage.RCodeNameError { - return errNoSuchHost - } - _, err := p.AnswerHeader() - if err != nil && err != dnsmessage.ErrSectionDone { - return errCannotUnmarshalDNSMessage - } - if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { - return errLameReferral - } - if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { - if h.RCode == dnsmessage.RCodeServerFailure { - return errServerTemporarilyMisbehaving - } - return errServerMisbehaving - } - return nil -} - -func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { - for { - h, err := p.AnswerHeader() - if err == dnsmessage.ErrSectionDone { - return errNoSuchHost - } - if err != nil { - return errCannotUnmarshalDNSMessage - } - if h.Type == qtype { - return nil - } - if err := p.SkipAnswer(); err != nil { - return errCannotUnmarshalDNSMessage - } - } -} - -func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { - var lastErr error - - n, err := dnsmessage.NewName(name) - if err != nil { - return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage - } - q := dnsmessage.Question{ - Name: n, - Type: qtype, - Class: dnsmessage.ClassINET, - } - - for i := 0; i < 2; i++ { - for _, server := range tnet.dnsServers { - p, h, err := tnet.exchange(ctx, server, q, time.Second*5) - if err != nil { - dnsErr := &net.DNSError{ - Err: err.Error(), - Name: name, - Server: server.String(), - } - if nerr, ok := err.(net.Error); ok && nerr.Timeout() { - dnsErr.IsTimeout = true - } - if _, ok := err.(*net.OpError); ok { - dnsErr.IsTemporary = true - } - lastErr = dnsErr - continue - } - - if err := checkHeader(&p, h); err != nil { - dnsErr := &net.DNSError{ - Err: err.Error(), - Name: name, - Server: server.String(), - } - if err == errServerTemporarilyMisbehaving { - dnsErr.IsTemporary = true - } - if err == errNoSuchHost { - dnsErr.IsNotFound = true - return p, server.String(), dnsErr - } - lastErr = dnsErr - continue - } - - err = skipToAnswer(&p, qtype) - if err == nil { - return p, server.String(), nil - } - lastErr = &net.DNSError{ - Err: err.Error(), - Name: name, - Server: server.String(), - } - if err == errNoSuchHost { - lastErr.(*net.DNSError).IsNotFound = true - return p, server.String(), lastErr - } - } - } - return dnsmessage.Parser{}, "", lastErr -} - -func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) { - if host == "" || (!tnet.hasV6 && !tnet.hasV4) { - return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} - } - zlen := len(host) - if strings.IndexByte(host, ':') != -1 { - if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { - zlen = zidx - } - } - if ip := net.ParseIP(host[:zlen]); ip != nil { - return []string{host[:zlen]}, nil - } - - if !isDomainName(host) { - return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} - } - type result struct { - p dnsmessage.Parser - server string - error - } - var addrsV4, addrsV6 []net.IP - lanes := 0 - if tnet.hasV4 { - lanes++ - } - if tnet.hasV6 { - lanes++ - } - lane := make(chan result, lanes) - var lastErr error - if tnet.hasV4 { - go func() { - p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA) - lane <- result{p, server, err} - }() - } - if tnet.hasV6 { - go func() { - p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA) - lane <- result{p, server, err} - }() - } - for l := 0; l < lanes; l++ { - result := <-lane - if result.error != nil { - if lastErr == nil { - lastErr = result.error - } - continue - } - - loop: - for { - h, err := result.p.AnswerHeader() - if err != nil && err != dnsmessage.ErrSectionDone { - lastErr = &net.DNSError{ - Err: errCannotMarshalDNSMessage.Error(), - Name: host, - Server: result.server, - } - } - if err != nil { - break - } - switch h.Type { - case dnsmessage.TypeA: - a, err := result.p.AResource() - if err != nil { - lastErr = &net.DNSError{ - Err: errCannotMarshalDNSMessage.Error(), - Name: host, - Server: result.server, - } - break loop - } - addrsV4 = append(addrsV4, net.IP(a.A[:])) - - case dnsmessage.TypeAAAA: - aaaa, err := result.p.AAAAResource() - if err != nil { - lastErr = &net.DNSError{ - Err: errCannotMarshalDNSMessage.Error(), - Name: host, - Server: result.server, - } - break loop - } - addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:])) - - default: - if err := result.p.SkipAnswer(); err != nil { - lastErr = &net.DNSError{ - Err: errCannotMarshalDNSMessage.Error(), - Name: host, - Server: result.server, - } - break loop - } - continue - } - } - } - // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled - var addrs []net.IP - if tnet.hasV6 { - addrs = append(addrsV6, addrsV4...) - } else { - addrs = append(addrsV4, addrsV6...) - } - - if len(addrs) == 0 && lastErr != nil { - return nil, lastErr - } - saddrs := make([]string, 0, len(addrs)) - for _, ip := range addrs { - saddrs = append(saddrs, ip.String()) - } - return saddrs, nil -} +func (net *Net) Resolver() *net.Resolver { return net.resolver } func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { if deadline.IsZero() { @@ -748,20 +306,25 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. if err != nil || port < 0 || port > 65535 { return nil, &net.OpError{Op: "dial", Err: errNumericPort} } - allAddr, err := tnet.LookupContextHost(ctx, host) - if err != nil { - return nil, &net.OpError{Op: "dial", Err: err} - } + var addrs []net.IP - for _, addr := range allAddr { - if strings.IndexByte(addr, ':') != -1 && acceptV6 { - addrs = append(addrs, net.ParseIP(addr)) - } else if strings.IndexByte(addr, '.') != -1 && acceptV4 { - addrs = append(addrs, net.ParseIP(addr)) + if addr := net.ParseIP(host); addr != nil { + addrs = []net.IP{addr} + } else { + allAddr, err := tnet.resolver.LookupHost(ctx, host) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + for _, addr := range allAddr { + if strings.IndexByte(addr, ':') != -1 && acceptV6 { + addrs = append(addrs, net.ParseIP(addr)) + } else if strings.IndexByte(addr, '.') != -1 && acceptV4 { + addrs = append(addrs, net.ParseIP(addr)) + } + } + if len(addrs) == 0 && len(allAddr) != 0 { + return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} } - } - if len(addrs) == 0 && len(allAddr) != 0 { - return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} } var firstErr error @@ -816,3 +379,32 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. func (tnet *Net) Dial(network, address string) (net.Conn, error) { return tnet.DialContext(context.Background(), network, address) } + +func (tnet *Net) dialDNS(ctx context.Context, network, address string) (net.Conn, error) { + if len(tnet.dnsServers) == 0 { + return tnet.DialContext(ctx, network, address) + } + + dnsIPs := make(map[string]struct{}, len(tnet.dnsServers)) + for _, dnsServer := range tnet.dnsServers { + ipAddress := dnsServer.String() + if host, _, err := net.SplitHostPort(address); err != nil && host == ipAddress { + return tnet.DialContext(ctx, network, address) + } + dnsIPs[ipAddress] = struct{}{} + } + + var lastErr error + for ipAddress := range dnsIPs { + conn, err := tnet.DialContext(ctx, network, net.JoinHostPort(ipAddress, "53")) + if err != nil { + if nerr, ok := err.(*net.OpError); ok { + lastErr = nerr + continue + } + return nil, err + } + return conn, nil + } + return nil, lastErr +}