diff --git a/conn/bind_linux.go b/conn/bind_linux.go index 419980951..70ea609d6 100644 --- a/conn/bind_linux.go +++ b/conn/bind_linux.go @@ -27,7 +27,7 @@ type ipv6Source struct { } type LinuxSocketEndpoint struct { - sync.Mutex + mu sync.Mutex dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte src [unsafe.Sizeof(ipv6Source{})]byte isV6 bool @@ -450,9 +450,9 @@ func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error { }, } - end.Lock() + end.mu.Lock() _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.Unlock() + end.mu.Unlock() if err == nil { return nil @@ -463,9 +463,9 @@ func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error { if err == unix.EINVAL { end.ClearSrc() cmsg.pktinfo = unix.Inet4Pktinfo{} - end.Lock() + end.mu.Lock() _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.Unlock() + end.mu.Unlock() } return err @@ -494,9 +494,9 @@ func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error { cmsg.pktinfo.Ifindex = 0 } - end.Lock() + end.mu.Lock() _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.Unlock() + end.mu.Unlock() if err == nil { return nil @@ -507,9 +507,9 @@ func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error { if err == unix.EINVAL { end.ClearSrc() cmsg.pktinfo = unix.Inet6Pktinfo{} - end.Lock() + end.mu.Lock() _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.Unlock() + end.mu.Unlock() } return err diff --git a/go.mod b/go.mod index 0aa27d850..6e691307c 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,6 @@ go 1.16 require ( golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 - golang.org/x/net v0.0.0-20210224082022-3d97a244fca7 - golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7 + golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 + golang.org/x/sys v0.0.0-20210305215415-5cdee2b1b5a0 ) diff --git a/go.sum b/go.sum index 1ccf774ca..733a8f977 100644 --- a/go.sum +++ b/go.sum @@ -2,13 +2,13 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g= golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20210224082022-3d97a244fca7 h1:OgUuv8lsRpBibGNbSizVwKWlysjaNzmC9gYMhPVfqFM= -golang.org/x/net v0.0.0-20210224082022-3d97a244fca7/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7 h1:pk3Y+QnSKjMLfO/HIqzn/Zvv3/IHjRPhwblrmUuodzw= -golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210305215415-5cdee2b1b5a0 h1:MOJR6AyRlIYMexU2acorBot1aPks0cBDOyUA4hFlBhE= +golang.org/x/sys v0.0.0-20210305215415-5cdee2b1b5a0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go index 164b7cb1b..3e2709cee 100644 --- a/ipc/uapi_windows.go +++ b/ipc/uapi_windows.go @@ -62,10 +62,10 @@ func init() { } func UAPIListen(name string) (net.Listener, error) { - config := winpipe.PipeConfig{ + config := winpipe.ListenConfig{ SecurityDescriptor: UAPISecurityDescriptor, } - listener, err := winpipe.ListenPipe(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config) + listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config) if err != nil { return nil, err } diff --git a/ipc/winpipe/file.go b/ipc/winpipe/file.go index f3b768f63..0c9abb140 100644 --- a/ipc/winpipe/file.go +++ b/ipc/winpipe/file.go @@ -5,54 +5,21 @@ * Copyright (C) 2005 Microsoft * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. */ + package winpipe import ( - "errors" "io" + "os" "runtime" "sync" "sync/atomic" "time" + "unsafe" "golang.org/x/sys/windows" ) -//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx -//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort -//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus -//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes -//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult - -type atomicBool int32 - -func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } -func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } -func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } -func (b *atomicBool) swap(new bool) bool { - var newInt int32 - if new { - newInt = 1 - } - return atomic.SwapInt32((*int32)(b), newInt) == 1 -} - -const ( - cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1 - cFILE_SKIP_SET_EVENT_ON_HANDLE = 2 -) - -var ( - ErrFileClosed = errors.New("file has already been closed") - ErrTimeout = &timeoutError{} -) - -type timeoutError struct{} - -func (e *timeoutError) Error() string { return "i/o timeout" } -func (e *timeoutError) Timeout() bool { return true } -func (e *timeoutError) Temporary() bool { return true } - type timeoutChan chan struct{} var ioInitOnce sync.Once @@ -71,7 +38,7 @@ type ioOperation struct { } func initIo() { - h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff) + h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) if err != nil { panic(err) } @@ -79,13 +46,13 @@ func initIo() { go ioCompletionProcessor(h) } -// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. +// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. // It takes ownership of this handle and will close it if it is garbage collected. -type win32File struct { +type file struct { handle windows.Handle wg sync.WaitGroup wgLock sync.RWMutex - closing atomicBool + closing uint32 // used as atomic boolean socket bool readDeadline deadlineHandler writeDeadline deadlineHandler @@ -96,18 +63,18 @@ type deadlineHandler struct { channel timeoutChan channelLock sync.RWMutex timer *time.Timer - timedout atomicBool + timedout uint32 // used as atomic boolean } -// makeWin32File makes a new win32File from an existing file handle -func makeWin32File(h windows.Handle) (*win32File, error) { - f := &win32File{handle: h} +// makeFile makes a new file from an existing file handle +func makeFile(h windows.Handle) (*file, error) { + f := &file{handle: h} ioInitOnce.Do(initIo) - _, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff) + _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0) if err != nil { return nil, err } - err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE) + err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) if err != nil { return nil, err } @@ -116,18 +83,14 @@ func makeWin32File(h windows.Handle) (*win32File, error) { return f, nil } -func MakeOpenFile(h windows.Handle) (io.ReadWriteCloser, error) { - return makeWin32File(h) -} - // closeHandle closes the resources associated with a Win32 handle -func (f *win32File) closeHandle() { +func (f *file) closeHandle() { f.wgLock.Lock() // Atomically set that we are closing, releasing the resources only once. - if !f.closing.swap(true) { + if atomic.SwapUint32(&f.closing, 1) == 0 { f.wgLock.Unlock() // cancel all IO and wait for it to complete - cancelIoEx(f.handle, nil) + windows.CancelIoEx(f.handle, nil) f.wg.Wait() // at this point, no new IO can start windows.Close(f.handle) @@ -137,19 +100,19 @@ func (f *win32File) closeHandle() { } } -// Close closes a win32File. -func (f *win32File) Close() error { +// Close closes a file. +func (f *file) Close() error { f.closeHandle() return nil } // prepareIo prepares for a new IO operation. // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. -func (f *win32File) prepareIo() (*ioOperation, error) { +func (f *file) prepareIo() (*ioOperation, error) { f.wgLock.RLock() - if f.closing.isSet() { + if atomic.LoadUint32(&f.closing) == 1 { f.wgLock.RUnlock() - return nil, ErrFileClosed + return nil, os.ErrClosed } f.wg.Add(1) f.wgLock.RUnlock() @@ -164,7 +127,7 @@ func ioCompletionProcessor(h windows.Handle) { var bytes uint32 var key uintptr var op *ioOperation - err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE) + err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE) if op == nil { panic(err) } @@ -174,13 +137,13 @@ func ioCompletionProcessor(h windows.Handle) { // asyncIo processes the return value from ReadFile or WriteFile, blocking until // the operation has actually completed. -func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { +func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { if err != windows.ERROR_IO_PENDING { return int(bytes), err } - if f.closing.isSet() { - cancelIoEx(f.handle, &c.o) + if atomic.LoadUint32(&f.closing) == 1 { + windows.CancelIoEx(f.handle, &c.o) } var timeout timeoutChan @@ -195,20 +158,20 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er case r = <-c.ch: err = r.err if err == windows.ERROR_OPERATION_ABORTED { - if f.closing.isSet() { - err = ErrFileClosed + if atomic.LoadUint32(&f.closing) == 1 { + err = os.ErrClosed } } else if err != nil && f.socket { // err is from Win32. Query the overlapped structure to get the winsock error. var bytes, flags uint32 - err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) + err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) } case <-timeout: - cancelIoEx(f.handle, &c.o) + windows.CancelIoEx(f.handle, &c.o) r = <-c.ch err = r.err if err == windows.ERROR_OPERATION_ABORTED { - err = ErrTimeout + err = os.ErrDeadlineExceeded } } @@ -220,15 +183,15 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er } // Read reads from a file handle. -func (f *win32File) Read(b []byte) (int, error) { +func (f *file) Read(b []byte) (int, error) { c, err := f.prepareIo() if err != nil { return 0, err } defer f.wg.Done() - if f.readDeadline.timedout.isSet() { - return 0, ErrTimeout + if atomic.LoadUint32(&f.readDeadline.timedout) == 1 { + return 0, os.ErrDeadlineExceeded } var bytes uint32 @@ -247,15 +210,15 @@ func (f *win32File) Read(b []byte) (int, error) { } // Write writes to a file handle. -func (f *win32File) Write(b []byte) (int, error) { +func (f *file) Write(b []byte) (int, error) { c, err := f.prepareIo() if err != nil { return 0, err } defer f.wg.Done() - if f.writeDeadline.timedout.isSet() { - return 0, ErrTimeout + if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 { + return 0, os.ErrDeadlineExceeded } var bytes uint32 @@ -265,19 +228,19 @@ func (f *win32File) Write(b []byte) (int, error) { return n, err } -func (f *win32File) SetReadDeadline(deadline time.Time) error { +func (f *file) SetReadDeadline(deadline time.Time) error { return f.readDeadline.set(deadline) } -func (f *win32File) SetWriteDeadline(deadline time.Time) error { +func (f *file) SetWriteDeadline(deadline time.Time) error { return f.writeDeadline.set(deadline) } -func (f *win32File) Flush() error { +func (f *file) Flush() error { return windows.FlushFileBuffers(f.handle) } -func (f *win32File) Fd() uintptr { +func (f *file) Fd() uintptr { return uintptr(f.handle) } @@ -291,7 +254,7 @@ func (d *deadlineHandler) set(deadline time.Time) error { } d.timer = nil } - d.timedout.setFalse() + atomic.StoreUint32(&d.timedout, 0) select { case <-d.channel: @@ -306,7 +269,7 @@ func (d *deadlineHandler) set(deadline time.Time) error { } timeoutIO := func() { - d.timedout.setTrue() + atomic.StoreUint32(&d.timedout, 1) close(d.channel) } diff --git a/ipc/winpipe/mksyscall.go b/ipc/winpipe/mksyscall.go deleted file mode 100644 index a87e9298e..000000000 --- a/ipc/winpipe/mksyscall.go +++ /dev/null @@ -1,9 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package winpipe - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go pipe.go file.go diff --git a/ipc/winpipe/pipe.go b/ipc/winpipe/pipe.go deleted file mode 100644 index e609274f5..000000000 --- a/ipc/winpipe/pipe.go +++ /dev/null @@ -1,509 +0,0 @@ -// +build windows - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package winpipe - -import ( - "context" - "errors" - "fmt" - "io" - "net" - "os" - "runtime" - "time" - "unsafe" - - "golang.org/x/sys/windows" -) - -//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe -//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW -//sys createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW -//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo -//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW -//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc -//sys ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile -//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb -//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U -//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl - -type ioStatusBlock struct { - Status, Information uintptr -} - -type objectAttributes struct { - Length uintptr - RootDirectory uintptr - ObjectName *unicodeString - Attributes uintptr - SecurityDescriptor *windows.SECURITY_DESCRIPTOR - SecurityQoS uintptr -} - -type unicodeString struct { - Length uint16 - MaximumLength uint16 - Buffer uintptr -} - -type ntstatus int32 - -func (status ntstatus) Err() error { - if status >= 0 { - return nil - } - return rtlNtStatusToDosError(status) -} - -const ( - cSECURITY_SQOS_PRESENT = 0x100000 - cSECURITY_ANONYMOUS = 0 - - cPIPE_TYPE_MESSAGE = 4 - - cPIPE_READMODE_MESSAGE = 2 - - cFILE_OPEN = 1 - cFILE_CREATE = 2 - - cFILE_PIPE_MESSAGE_TYPE = 1 - cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2 -) - -var ( - // ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed. - // This error should match net.errClosing since docker takes a dependency on its text. - ErrPipeListenerClosed = errors.New("use of closed network connection") - - errPipeWriteClosed = errors.New("pipe has been closed for write") -) - -type win32Pipe struct { - *win32File - path string -} - -type win32MessageBytePipe struct { - win32Pipe - writeClosed bool - readEOF bool -} - -type pipeAddress string - -func (f *win32Pipe) LocalAddr() net.Addr { - return pipeAddress(f.path) -} - -func (f *win32Pipe) RemoteAddr() net.Addr { - return pipeAddress(f.path) -} - -func (f *win32Pipe) SetDeadline(t time.Time) error { - f.SetReadDeadline(t) - f.SetWriteDeadline(t) - return nil -} - -// CloseWrite closes the write side of a message pipe in byte mode. -func (f *win32MessageBytePipe) CloseWrite() error { - if f.writeClosed { - return errPipeWriteClosed - } - err := f.win32File.Flush() - if err != nil { - return err - } - _, err = f.win32File.Write(nil) - if err != nil { - return err - } - f.writeClosed = true - return nil -} - -// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since -// they are used to implement CloseWrite(). -func (f *win32MessageBytePipe) Write(b []byte) (int, error) { - if f.writeClosed { - return 0, errPipeWriteClosed - } - if len(b) == 0 { - return 0, nil - } - return f.win32File.Write(b) -} - -// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message -// mode pipe will return io.EOF, as will all subsequent reads. -func (f *win32MessageBytePipe) Read(b []byte) (int, error) { - if f.readEOF { - return 0, io.EOF - } - n, err := f.win32File.Read(b) - if err == io.EOF { - // If this was the result of a zero-byte read, then - // it is possible that the read was due to a zero-size - // message. Since we are simulating CloseWrite with a - // zero-byte message, ensure that all future Read() calls - // also return EOF. - f.readEOF = true - } else if err == windows.ERROR_MORE_DATA { - // ERROR_MORE_DATA indicates that the pipe's read mode is message mode - // and the message still has more bytes. Treat this as a success, since - // this package presents all named pipes as byte streams. - err = nil - } - return n, err -} - -func (s pipeAddress) Network() string { - return "pipe" -} - -func (s pipeAddress) String() string { - return string(s) -} - -// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout. -func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) { - for { - select { - case <-ctx.Done(): - return windows.Handle(0), ctx.Err() - default: - h, err := createFile(*path, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0) - if err == nil { - return h, nil - } - if err != windows.ERROR_PIPE_BUSY { - return h, &os.PathError{Err: err, Op: "open", Path: *path} - } - // Wait 10 msec and try again. This is a rather simplistic - // view, as we always try each 10 milliseconds. - time.Sleep(time.Millisecond * 10) - } - } -} - -// DialPipe connects to a named pipe by path, timing out if the connection -// takes longer than the specified duration. If timeout is nil, then we use -// a default timeout of 2 seconds. (We do not use WaitNamedPipe.) -func DialPipe(path string, timeout *time.Duration, expectedOwner *windows.SID) (net.Conn, error) { - var absTimeout time.Time - if timeout != nil { - absTimeout = time.Now().Add(*timeout) - } else { - absTimeout = time.Now().Add(time.Second * 2) - } - ctx, _ := context.WithDeadline(context.Background(), absTimeout) - conn, err := DialPipeContext(ctx, path, expectedOwner) - if err == context.DeadlineExceeded { - return nil, ErrTimeout - } - return conn, err -} - -// DialPipeContext attempts to connect to a named pipe by `path` until `ctx` -// cancellation or timeout. -func DialPipeContext(ctx context.Context, path string, expectedOwner *windows.SID) (net.Conn, error) { - var err error - var h windows.Handle - h, err = tryDialPipe(ctx, &path) - if err != nil { - return nil, err - } - - if expectedOwner != nil { - sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) - if err != nil { - windows.Close(h) - return nil, err - } - realOwner, _, err := sd.Owner() - if err != nil { - windows.Close(h) - return nil, err - } - if !realOwner.Equals(expectedOwner) { - windows.Close(h) - return nil, windows.ERROR_ACCESS_DENIED - } - } - - var flags uint32 - err = getNamedPipeInfo(h, &flags, nil, nil, nil) - if err != nil { - windows.Close(h) - return nil, err - } - - f, err := makeWin32File(h) - if err != nil { - windows.Close(h) - return nil, err - } - - // If the pipe is in message mode, return a message byte pipe, which - // supports CloseWrite(). - if flags&cPIPE_TYPE_MESSAGE != 0 { - return &win32MessageBytePipe{ - win32Pipe: win32Pipe{win32File: f, path: path}, - }, nil - } - return &win32Pipe{win32File: f, path: path}, nil -} - -type acceptResponse struct { - f *win32File - err error -} - -type win32PipeListener struct { - firstHandle windows.Handle - path string - config PipeConfig - acceptCh chan (chan acceptResponse) - closeCh chan int - doneCh chan int -} - -func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *PipeConfig, first bool) (windows.Handle, error) { - path16, err := windows.UTF16FromString(path) - if err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - - var oa objectAttributes - oa.Length = unsafe.Sizeof(oa) - - var ntPath unicodeString - if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - defer windows.LocalFree(windows.Handle(ntPath.Buffer)) - oa.ObjectName = &ntPath - - // The security descriptor is only needed for the first pipe. - if first { - if sd != nil { - oa.SecurityDescriptor = sd - } else { - // Construct the default named pipe security descriptor. - var dacl uintptr - if err := rtlDefaultNpAcl(&dacl).Err(); err != nil { - return 0, fmt.Errorf("getting default named pipe ACL: %s", err) - } - defer windows.LocalFree(windows.Handle(dacl)) - sd, err := windows.NewSecurityDescriptor() - if err != nil { - return 0, fmt.Errorf("creating new security descriptor: %s", err) - } - if err = sd.SetDACL((*windows.ACL)(unsafe.Pointer(dacl)), true, false); err != nil { - return 0, fmt.Errorf("assigning dacl: %s", err) - } - sd, err = sd.ToSelfRelative() - if err != nil { - return 0, fmt.Errorf("converting to self-relative: %s", err) - } - oa.SecurityDescriptor = sd - } - } - - typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS) - if c.MessageMode { - typ |= cFILE_PIPE_MESSAGE_TYPE - } - - disposition := uint32(cFILE_OPEN) - access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) - if first { - disposition = cFILE_CREATE - // By not asking for read or write access, the named pipe file system - // will put this pipe into an initially disconnected state, blocking - // client connections until the next call with first == false. - access = windows.SYNCHRONIZE - } - - timeout := int64(-50 * 10000) // 50ms - - var ( - h windows.Handle - iosb ioStatusBlock - ) - err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err() - if err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - - runtime.KeepAlive(ntPath) - return h, nil -} - -func (l *win32PipeListener) makeServerPipe() (*win32File, error) { - h, err := makeServerPipeHandle(l.path, nil, &l.config, false) - if err != nil { - return nil, err - } - f, err := makeWin32File(h) - if err != nil { - windows.Close(h) - return nil, err - } - return f, nil -} - -func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) { - p, err := l.makeServerPipe() - if err != nil { - return nil, err - } - - // Wait for the client to connect. - ch := make(chan error) - go func(p *win32File) { - ch <- connectPipe(p) - }(p) - - select { - case err = <-ch: - if err != nil { - p.Close() - p = nil - } - case <-l.closeCh: - // Abort the connect request by closing the handle. - p.Close() - p = nil - err = <-ch - if err == nil || err == ErrFileClosed { - err = ErrPipeListenerClosed - } - } - return p, err -} - -func (l *win32PipeListener) listenerRoutine() { - closed := false - for !closed { - select { - case <-l.closeCh: - closed = true - case responseCh := <-l.acceptCh: - var ( - p *win32File - err error - ) - for { - p, err = l.makeConnectedServerPipe() - // If the connection was immediately closed by the client, try - // again. - if err != windows.ERROR_NO_DATA { - break - } - } - responseCh <- acceptResponse{p, err} - closed = err == ErrPipeListenerClosed - } - } - windows.Close(l.firstHandle) - l.firstHandle = 0 - // Notify Close() and Accept() callers that the handle has been closed. - close(l.doneCh) -} - -// PipeConfig contain configuration for the pipe listener. -type PipeConfig struct { - // SecurityDescriptor contains a Windows security descriptor. - SecurityDescriptor *windows.SECURITY_DESCRIPTOR - - // MessageMode determines whether the pipe is in byte or message mode. In either - // case the pipe is read in byte mode by default. The only practical difference in - // this implementation is that CloseWrite() is only supported for message mode pipes; - // CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only - // transferred to the reader (and returned as io.EOF in this implementation) - // when the pipe is in message mode. - MessageMode bool - - // InputBufferSize specifies the size the input buffer, in bytes. - InputBufferSize int32 - - // OutputBufferSize specifies the size the input buffer, in bytes. - OutputBufferSize int32 -} - -// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe. -// The pipe must not already exist. -func ListenPipe(path string, c *PipeConfig) (net.Listener, error) { - if c == nil { - c = &PipeConfig{} - } - h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) - if err != nil { - return nil, err - } - l := &win32PipeListener{ - firstHandle: h, - path: path, - config: *c, - acceptCh: make(chan (chan acceptResponse)), - closeCh: make(chan int), - doneCh: make(chan int), - } - go l.listenerRoutine() - return l, nil -} - -func connectPipe(p *win32File) error { - c, err := p.prepareIo() - if err != nil { - return err - } - defer p.wg.Done() - - err = connectNamedPipe(p.handle, &c.o) - _, err = p.asyncIo(c, nil, 0, err) - if err != nil && err != windows.ERROR_PIPE_CONNECTED { - return err - } - return nil -} - -func (l *win32PipeListener) Accept() (net.Conn, error) { - ch := make(chan acceptResponse) - select { - case l.acceptCh <- ch: - response := <-ch - err := response.err - if err != nil { - return nil, err - } - if l.config.MessageMode { - return &win32MessageBytePipe{ - win32Pipe: win32Pipe{win32File: response.f, path: l.path}, - }, nil - } - return &win32Pipe{win32File: response.f, path: l.path}, nil - case <-l.doneCh: - return nil, ErrPipeListenerClosed - } -} - -func (l *win32PipeListener) Close() error { - select { - case l.closeCh <- 1: - <-l.doneCh - case <-l.doneCh: - } - return nil -} - -func (l *win32PipeListener) Addr() net.Addr { - return pipeAddress(l.path) -} diff --git a/ipc/winpipe/winpipe.go b/ipc/winpipe/winpipe.go new file mode 100644 index 000000000..f02f3d884 --- /dev/null +++ b/ipc/winpipe/winpipe.go @@ -0,0 +1,474 @@ +// +build windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2005 Microsoft + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +// Package winpipe implements a net.Conn and net.Listener around Windows named pipes. +package winpipe + +import ( + "context" + "io" + "net" + "os" + "runtime" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +type pipe struct { + *file + path string +} + +type messageBytePipe struct { + pipe + writeClosed bool + readEOF bool +} + +type pipeAddress string + +func (f *pipe) LocalAddr() net.Addr { + return pipeAddress(f.path) +} + +func (f *pipe) RemoteAddr() net.Addr { + return pipeAddress(f.path) +} + +func (f *pipe) SetDeadline(t time.Time) error { + f.SetReadDeadline(t) + f.SetWriteDeadline(t) + return nil +} + +// CloseWrite closes the write side of a message pipe in byte mode. +func (f *messageBytePipe) CloseWrite() error { + if f.writeClosed { + return io.ErrClosedPipe + } + err := f.file.Flush() + if err != nil { + return err + } + _, err = f.file.Write(nil) + if err != nil { + return err + } + f.writeClosed = true + return nil +} + +// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since +// they are used to implement CloseWrite. +func (f *messageBytePipe) Write(b []byte) (int, error) { + if f.writeClosed { + return 0, io.ErrClosedPipe + } + if len(b) == 0 { + return 0, nil + } + return f.file.Write(b) +} + +// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message +// mode pipe will return io.EOF, as will all subsequent reads. +func (f *messageBytePipe) Read(b []byte) (int, error) { + if f.readEOF { + return 0, io.EOF + } + n, err := f.file.Read(b) + if err == io.EOF { + // If this was the result of a zero-byte read, then + // it is possible that the read was due to a zero-size + // message. Since we are simulating CloseWrite with a + // zero-byte message, ensure that all future Read calls + // also return EOF. + f.readEOF = true + } else if err == windows.ERROR_MORE_DATA { + // ERROR_MORE_DATA indicates that the pipe's read mode is message mode + // and the message still has more bytes. Treat this as a success, since + // this package presents all named pipes as byte streams. + err = nil + } + return n, err +} + +func (f *pipe) Handle() windows.Handle { + return f.handle +} + +func (s pipeAddress) Network() string { + return "pipe" +} + +func (s pipeAddress) String() string { + return string(s) +} + +// tryDialPipe attempts to dial the specified pipe until cancellation or timeout. +func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) { + for { + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + path16, err := windows.UTF16PtrFromString(*path) + if err != nil { + return 0, err + } + h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) + if err == nil { + return h, nil + } + if err != windows.ERROR_PIPE_BUSY { + return h, &os.PathError{Err: err, Op: "open", Path: *path} + } + // Wait 10 msec and try again. This is a rather simplistic + // view, as we always try each 10 milliseconds. + time.Sleep(10 * time.Millisecond) + } + } +} + +// DialConfig exposes various options for use in Dial and DialContext. +type DialConfig struct { + ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. +} + +// Dial connects to the specified named pipe by path, timing out if the connection +// takes longer than the specified duration. If timeout is nil, then we use +// a default timeout of 2 seconds. +func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) { + var absTimeout time.Time + if timeout != nil { + absTimeout = time.Now().Add(*timeout) + } else { + absTimeout = time.Now().Add(2 * time.Second) + } + ctx, _ := context.WithDeadline(context.Background(), absTimeout) + conn, err := DialContext(ctx, path, config) + if err == context.DeadlineExceeded { + return nil, os.ErrDeadlineExceeded + } + return conn, err +} + +// DialContext attempts to connect to the specified named pipe by path +// cancellation or timeout. +func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) { + if config == nil { + config = &DialConfig{} + } + var err error + var h windows.Handle + h, err = tryDialPipe(ctx, &path) + if err != nil { + return nil, err + } + + if config.ExpectedOwner != nil { + sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) + if err != nil { + windows.Close(h) + return nil, err + } + realOwner, _, err := sd.Owner() + if err != nil { + windows.Close(h) + return nil, err + } + if !realOwner.Equals(config.ExpectedOwner) { + windows.Close(h) + return nil, windows.ERROR_ACCESS_DENIED + } + } + + var flags uint32 + err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil) + if err != nil { + windows.Close(h) + return nil, err + } + + f, err := makeFile(h) + if err != nil { + windows.Close(h) + return nil, err + } + + // If the pipe is in message mode, return a message byte pipe, which + // supports CloseWrite. + if flags&windows.PIPE_TYPE_MESSAGE != 0 { + return &messageBytePipe{ + pipe: pipe{file: f, path: path}, + }, nil + } + return &pipe{file: f, path: path}, nil +} + +type acceptResponse struct { + f *file + err error +} + +type pipeListener struct { + firstHandle windows.Handle + path string + config ListenConfig + acceptCh chan (chan acceptResponse) + closeCh chan int + doneCh chan int +} + +func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) { + path16, err := windows.UTF16PtrFromString(path) + if err != nil { + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + + var oa windows.OBJECT_ATTRIBUTES + oa.Length = uint32(unsafe.Sizeof(oa)) + + var ntPath windows.NTUnicodeString + if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + err = ntstatus.Errno() + } + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer))) + oa.ObjectName = &ntPath + + // The security descriptor is only needed for the first pipe. + if first { + if sd != nil { + oa.SecurityDescriptor = sd + } else { + // Construct the default named pipe security descriptor. + var acl *windows.ACL + if err := windows.RtlDefaultNpAcl(&acl); err != nil { + return 0, err + } + defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) + sd, err := windows.NewSecurityDescriptor() + if err != nil { + return 0, err + } + if err = sd.SetDACL(acl, true, false); err != nil { + return 0, err + } + oa.SecurityDescriptor = sd + } + } + + typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS) + if c.MessageMode { + typ |= windows.FILE_PIPE_MESSAGE_TYPE + } + + disposition := uint32(windows.FILE_OPEN) + access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) + if first { + disposition = windows.FILE_CREATE + // By not asking for read or write access, the named pipe file system + // will put this pipe into an initially disconnected state, blocking + // client connections until the next call with first == false. + access = windows.SYNCHRONIZE + } + + timeout := int64(-50 * 10000) // 50ms + + var ( + h windows.Handle + iosb windows.IO_STATUS_BLOCK + ) + err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout) + if err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + err = ntstatus.Errno() + } + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + + runtime.KeepAlive(ntPath) + return h, nil +} + +func (l *pipeListener) makeServerPipe() (*file, error) { + h, err := makeServerPipeHandle(l.path, nil, &l.config, false) + if err != nil { + return nil, err + } + f, err := makeFile(h) + if err != nil { + windows.Close(h) + return nil, err + } + return f, nil +} + +func (l *pipeListener) makeConnectedServerPipe() (*file, error) { + p, err := l.makeServerPipe() + if err != nil { + return nil, err + } + + // Wait for the client to connect. + ch := make(chan error) + go func(p *file) { + ch <- connectPipe(p) + }(p) + + select { + case err = <-ch: + if err != nil { + p.Close() + p = nil + } + case <-l.closeCh: + // Abort the connect request by closing the handle. + p.Close() + p = nil + err = <-ch + if err == nil || err == os.ErrClosed { + err = net.ErrClosed + } + } + return p, err +} + +func (l *pipeListener) listenerRoutine() { + closed := false + for !closed { + select { + case <-l.closeCh: + closed = true + case responseCh := <-l.acceptCh: + var ( + p *file + err error + ) + for { + p, err = l.makeConnectedServerPipe() + // If the connection was immediately closed by the client, try + // again. + if err != windows.ERROR_NO_DATA { + break + } + } + responseCh <- acceptResponse{p, err} + closed = err == net.ErrClosed + } + } + windows.Close(l.firstHandle) + l.firstHandle = 0 + // Notify Close and Accept callers that the handle has been closed. + close(l.doneCh) +} + +// ListenConfig contains configuration for the pipe listener. +type ListenConfig struct { + // SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used. + SecurityDescriptor *windows.SECURITY_DESCRIPTOR + + // MessageMode determines whether the pipe is in byte or message mode. In either + // case the pipe is read in byte mode by default. The only practical difference in + // this implementation is that CloseWrite is only supported for message mode pipes; + // CloseWrite is implemented as a zero-byte write, but zero-byte writes are only + // transferred to the reader (and returned as io.EOF in this implementation) + // when the pipe is in message mode. + MessageMode bool + + // InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed. + InputBufferSize int32 + + // OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed. + OutputBufferSize int32 +} + +// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe. +// The pipe must not already exist. +func Listen(path string, c *ListenConfig) (net.Listener, error) { + if c == nil { + c = &ListenConfig{} + } + h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) + if err != nil { + return nil, err + } + l := &pipeListener{ + firstHandle: h, + path: path, + config: *c, + acceptCh: make(chan (chan acceptResponse)), + closeCh: make(chan int), + doneCh: make(chan int), + } + // The first connection is swallowed on Windows 7 & 8, so synthesize it. + if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { + path16, err := windows.UTF16PtrFromString(path) + if err == nil { + h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) + if err == nil { + windows.CloseHandle(h) + } + } + } + go l.listenerRoutine() + return l, nil +} + +func connectPipe(p *file) error { + c, err := p.prepareIo() + if err != nil { + return err + } + defer p.wg.Done() + + err = windows.ConnectNamedPipe(p.handle, &c.o) + _, err = p.asyncIo(c, nil, 0, err) + if err != nil && err != windows.ERROR_PIPE_CONNECTED { + return err + } + return nil +} + +func (l *pipeListener) Accept() (net.Conn, error) { + ch := make(chan acceptResponse) + select { + case l.acceptCh <- ch: + response := <-ch + err := response.err + if err != nil { + return nil, err + } + if l.config.MessageMode { + return &messageBytePipe{ + pipe: pipe{file: response.f, path: l.path}, + }, nil + } + return &pipe{file: response.f, path: l.path}, nil + case <-l.doneCh: + return nil, net.ErrClosed + } +} + +func (l *pipeListener) Close() error { + select { + case l.closeCh <- 1: + <-l.doneCh + case <-l.doneCh: + } + return nil +} + +func (l *pipeListener) Addr() net.Addr { + return pipeAddress(l.path) +} diff --git a/ipc/winpipe/winpipe_test.go b/ipc/winpipe/winpipe_test.go new file mode 100644 index 000000000..34ebeb18e --- /dev/null +++ b/ipc/winpipe/winpipe_test.go @@ -0,0 +1,656 @@ +// +build windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2005 Microsoft + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package winpipe_test + +import ( + "bufio" + "bytes" + "context" + "io" + "net" + "os" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/ipc/winpipe" +) + +func randomPipePath() string { + guid, err := windows.GenerateGUID() + if err != nil { + panic(err) + } + return `\\.\pipe\go-winpipe-test-` + guid.String() +} + +func TestPingPong(t *testing.T) { + const ( + ping = 42 + pong = 24 + ) + pipePath := randomPipePath() + listener, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatalf("unable to listen on pipe: %v", err) + } + defer listener.Close() + go func() { + incoming, err := listener.Accept() + if err != nil { + t.Fatalf("unable to accept pipe connection: %v", err) + } + defer incoming.Close() + var data [1]byte + _, err = incoming.Read(data[:]) + if err != nil { + t.Fatalf("unable to read ping from pipe: %v", err) + } + if data[0] != ping { + t.Fatalf("expected ping, got %d", data[0]) + } + data[0] = pong + _, err = incoming.Write(data[:]) + if err != nil { + t.Fatalf("unable to write pong to pipe: %v", err) + } + }() + client, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatalf("unable to dial pipe: %v", err) + } + defer client.Close() + var data [1]byte + data[0] = ping + _, err = client.Write(data[:]) + if err != nil { + t.Fatalf("unable to write ping to pipe: %v", err) + } + _, err = client.Read(data[:]) + if err != nil { + t.Fatalf("unable to read pong from pipe: %v", err) + } + if data[0] != pong { + t.Fatalf("expected pong, got %d", data[0]) + } +} + +func TestDialUnknownFailsImmediately(t *testing.T) { + _, err := winpipe.Dial(randomPipePath(), nil, nil) + if err.(*os.PathError).Err != syscall.ENOENT { + t.Fatalf("expected ENOENT got %v", err) + } +} + +func TestDialListenerTimesOut(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + d := 10 * time.Millisecond + _, err = winpipe.Dial(pipePath, &d, nil) + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func TestDialContextListenerTimesOut(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + d := 10 * time.Millisecond + ctx, _ := context.WithTimeout(context.Background(), d) + _, err = winpipe.DialContext(ctx, pipePath, nil) + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + +func TestDialListenerGetsCancelled(t *testing.T) { + pipePath := randomPipePath() + ctx, cancel := context.WithCancel(context.Background()) + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + ch := make(chan error) + defer l.Close() + go func(ctx context.Context, ch chan error) { + _, err := winpipe.DialContext(ctx, pipePath, nil) + ch <- err + }(ctx, ch) + time.Sleep(time.Millisecond * 30) + cancel() + err = <-ch + if err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { + pipePath := randomPipePath() + sd, _ := windows.SecurityDescriptorFromString("D:P(A;;0x1200FF;;;WD)") + c := winpipe.ListenConfig{ + SecurityDescriptor: sd, + } + l, err := winpipe.Listen(pipePath, &c) + if err != nil { + t.Fatal(err) + } + defer l.Close() + _, err = winpipe.Dial(pipePath, nil, nil) + if err.(*os.PathError).Err != syscall.ERROR_ACCESS_DENIED { + t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) + } +} + +func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, cfg) + if err != nil { + return + } + defer l.Close() + + type response struct { + c net.Conn + err error + } + ch := make(chan response) + go func() { + c, err := l.Accept() + ch <- response{c, err} + }() + + c, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + return + } + + r := <-ch + if err = r.err; err != nil { + c.Close() + return + } + + client = c + server = r.c + return +} + +func TestReadTimeout(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + + c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + + buf := make([]byte, 10) + _, err = c.Read(buf) + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func server(l net.Listener, ch chan int) { + c, err := l.Accept() + if err != nil { + panic(err) + } + rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) + s, err := rw.ReadString('\n') + if err != nil { + panic(err) + } + _, err = rw.WriteString("got " + s) + if err != nil { + panic(err) + } + err = rw.Flush() + if err != nil { + panic(err) + } + c.Close() + ch <- 1 +} + +func TestFullListenDialReadWrite(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + ch := make(chan int) + go server(l, ch) + + c, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) + _, err = rw.WriteString("hello world\n") + if err != nil { + t.Fatal(err) + } + err = rw.Flush() + if err != nil { + t.Fatal(err) + } + + s, err := rw.ReadString('\n') + if err != nil { + t.Fatal(err) + } + ms := "got hello world\n" + if s != ms { + t.Errorf("expected '%s', got '%s'", ms, s) + } + + <-ch +} + +func TestCloseAbortsListen(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + + ch := make(chan error) + go func() { + _, err := l.Accept() + ch <- err + }() + + time.Sleep(30 * time.Millisecond) + l.Close() + + err = <-ch + if err != net.ErrClosed { + t.Fatalf("expected net.ErrClosed, got %v", err) + } +} + +func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) { + b := make([]byte, 10) + w.Close() + n, err := r.Read(b) + if n > 0 { + t.Errorf("unexpected byte count %d", n) + } + if err != io.EOF { + t.Errorf("expected EOF: %v", err) + } +} + +func TestCloseClientEOFServer(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + ensureEOFOnClose(t, c, s) +} + +func TestCloseServerEOFClient(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + ensureEOFOnClose(t, s, c) +} + +func TestCloseWriteEOF(t *testing.T) { + cfg := &winpipe.ListenConfig{ + MessageMode: true, + } + c, s, err := getConnection(cfg) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + + type closeWriter interface { + CloseWrite() error + } + + err = c.(closeWriter).CloseWrite() + if err != nil { + t.Fatal(err) + } + + b := make([]byte, 10) + _, err = s.Read(b) + if err != io.EOF { + t.Fatal(err) + } +} + +func TestAcceptAfterCloseFails(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + l.Close() + _, err = l.Accept() + if err != net.ErrClosed { + t.Fatalf("expected net.ErrClosed, got %v", err) + } +} + +func TestDialTimesOutByDefault(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + _, err = winpipe.Dial(pipePath, nil, nil) + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func TestTimeoutPendingRead(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + serverDone := make(chan struct{}) + + go func() { + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + s.Close() + close(serverDone) + }() + + client, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + clientErr := make(chan error) + go func() { + buf := make([]byte, 10) + _, err = client.Read(buf) + clientErr <- err + }() + + time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline + client.SetReadDeadline(time.Unix(1, 0)) + + select { + case err = <-clientErr: + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timed out while waiting for read to cancel") + <-clientErr + } + <-serverDone +} + +func TestTimeoutPendingWrite(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + serverDone := make(chan struct{}) + + go func() { + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + s.Close() + close(serverDone) + }() + + client, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + clientErr := make(chan error) + go func() { + _, err = client.Write([]byte("this should timeout")) + clientErr <- err + }() + + time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline + client.SetWriteDeadline(time.Unix(1, 0)) + + select { + case err = <-clientErr: + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timed out while waiting for write to cancel") + <-clientErr + } + <-serverDone +} + +type CloseWriter interface { + CloseWrite() error +} + +func TestEchoWithMessaging(t *testing.T) { + c := winpipe.ListenConfig{ + MessageMode: true, // Use message mode so that CloseWrite() is supported + InputBufferSize: 65536, // Use 64KB buffers to improve performance + OutputBufferSize: 65536, + } + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, &c) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + listenerDone := make(chan bool) + clientDone := make(chan bool) + go func() { + // server echo + conn, e := l.Accept() + if e != nil { + t.Fatal(e) + } + defer conn.Close() + + time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent + io.Copy(conn, conn) + conn.(CloseWriter).CloseWrite() + close(listenerDone) + }() + timeout := 1 * time.Second + client, err := winpipe.Dial(pipePath, &timeout, nil) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + go func() { + // client read back + bytes := make([]byte, 2) + n, e := client.Read(bytes) + if e != nil { + t.Fatal(e) + } + if n != 2 { + t.Fatalf("expected 2 bytes, got %v", n) + } + close(clientDone) + }() + + payload := make([]byte, 2) + payload[0] = 0 + payload[1] = 1 + + n, err := client.Write(payload) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("expected 2 bytes, got %v", n) + } + client.(CloseWriter).CloseWrite() + <-listenerDone + <-clientDone +} + +func TestConnectRace(t *testing.T) { + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + go func() { + for { + s, err := l.Accept() + if err == net.ErrClosed { + return + } + + if err != nil { + t.Fatal(err) + } + s.Close() + } + }() + + for i := 0; i < 1000; i++ { + c, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + c.Close() + } +} + +func TestMessageReadMode(t *testing.T) { + if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { + t.Skipf("Skipping on Windows %d", maj) + } + var wg sync.WaitGroup + defer wg.Wait() + pipePath := randomPipePath() + l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true}) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + msg := ([]byte)("hello world") + + wg.Add(1) + go func() { + defer wg.Done() + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + _, err = s.Write(msg) + if err != nil { + t.Fatal(err) + } + s.Close() + }() + + c, err := winpipe.Dial(pipePath, nil, nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + mode := uint32(windows.PIPE_READMODE_MESSAGE) + err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil) + if err != nil { + t.Fatal(err) + } + + ch := make([]byte, 1) + var vmsg []byte + for { + n, err := c.Read(ch) + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatalf("expected 1, got %d", n) + } + vmsg = append(vmsg, ch[0]) + } + if !bytes.Equal(msg, vmsg) { + t.Fatalf("expected %s, got %s", msg, vmsg) + } +} + +func TestListenConnectRace(t *testing.T) { + if testing.Short() { + t.Skip("Skipping long race test") + } + pipePath := randomPipePath() + for i := 0; i < 50 && !t.Failed(); i++ { + var wg sync.WaitGroup + wg.Add(1) + go func() { + c, err := winpipe.Dial(pipePath, nil, nil) + if err == nil { + c.Close() + } + wg.Done() + }() + s, err := winpipe.Listen(pipePath, nil) + if err != nil { + t.Error(i, err) + } else { + s.Close() + } + wg.Wait() + } +} diff --git a/ipc/winpipe/zsyscall_windows.go b/ipc/winpipe/zsyscall_windows.go deleted file mode 100644 index 995432975..000000000 --- a/ipc/winpipe/zsyscall_windows.go +++ /dev/null @@ -1,238 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package winpipe - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - modntdll = windows.NewLazySystemDLL("ntdll.dll") - modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") - - procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe") - procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW") - procCreateFileW = modkernel32.NewProc("CreateFileW") - procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo") - procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW") - procLocalAlloc = modkernel32.NewProc("LocalAlloc") - procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile") - procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb") - procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U") - procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl") - procCancelIoEx = modkernel32.NewProc("CancelIoEx") - procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort") - procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus") - procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes") - procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult") -) - -func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) { - r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) { - var _p0 *uint16 - _p0, err = syscall.UTF16PtrFromString(name) - if err != nil { - return - } - return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa) -} - -func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0) - handle = windows.Handle(r0) - if handle == windows.InvalidHandle { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) { - var _p0 *uint16 - _p0, err = syscall.UTF16PtrFromString(name) - if err != nil { - return - } - return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile) -} - -func _createFile(name *uint16, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0) - handle = windows.Handle(r0) - if handle == windows.InvalidHandle { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) { - r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func localAlloc(uFlags uint32, length uint32) (ptr uintptr) { - r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0) - ptr = uintptr(r0) - return -} - -func ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) { - r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0) - status = ntstatus(r0) - return -} - -func rtlNtStatusToDosError(status ntstatus) (winerr error) { - r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0) - if r0 != 0 { - winerr = syscall.Errno(r0) - } - return -} - -func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) { - r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0) - status = ntstatus(r0) - return -} - -func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) { - r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0) - status = ntstatus(r0) - return -} - -func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) { - r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) { - r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0) - newport = windows.Handle(r0) - if newport == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) { - r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) { - var _p0 uint32 - if wait { - _p0 = 1 - } else { - _p0 = 0 - } - r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 4846e2f06..87a76072a 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -7,11 +7,8 @@ package netstack import ( "context" - "crypto/rand" - "encoding/binary" "errors" "fmt" - "io" "net" "os" "strconv" @@ -20,7 +17,6 @@ import ( "golang.zx2c4.com/wireguard/tun" - "golang.org/x/net/dns/dnsmessage" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -39,6 +35,7 @@ type netTun struct { incomingPacket chan buffer.VectorisedView mtu int dnsServers []net.IP + resolver *net.Resolver hasV4, hasV6 bool } type endpoint netTun @@ -129,6 +126,11 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) } + dev.resolver = &net.Resolver{ + PreferGo: true, + Dial: (*Net)(dev).dialDNS, + } + dev.events <- tun.EventUp return dev, (*Net)(dev), nil } @@ -247,458 +249,14 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { } var ( - errNoSuchHost = errors.New("no such host") - errLameReferral = errors.New("lame referral") - errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") - errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") - errServerMisbehaving = errors.New("server misbehaving") - errInvalidDNSResponse = errors.New("invalid DNS response") - errNoAnswerFromDNSServer = errors.New("no answer from DNS server") - errServerTemporarilyMisbehaving = errors.New("server misbehaving") - errCanceled = errors.New("operation was canceled") - errTimeout = errors.New("i/o timeout") - errNumericPort = errors.New("port must be numeric") - errNoSuitableAddress = errors.New("no suitable address found") - errMissingAddress = errors.New("missing address") + errCanceled = errors.New("operation was canceled") + errTimeout = errors.New("i/o timeout") + errNumericPort = errors.New("port must be numeric") + errNoSuitableAddress = errors.New("no suitable address found") + errMissingAddress = errors.New("missing address") ) -func (net *Net) LookupHost(host string) (addrs []string, err error) { - return net.LookupContextHost(context.Background(), host) -} - -func isDomainName(s string) bool { - l := len(s) - if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { - return false - } - last := byte('.') - nonNumeric := false - partlen := 0 - for i := 0; i < len(s); i++ { - c := s[i] - switch { - default: - return false - case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': - nonNumeric = true - partlen++ - case '0' <= c && c <= '9': - partlen++ - case c == '-': - if last == '.' { - return false - } - partlen++ - nonNumeric = true - case c == '.': - if last == '.' || last == '-' { - return false - } - if partlen > 63 || partlen == 0 { - return false - } - partlen = 0 - } - last = c - } - if last == '-' || partlen > 63 { - return false - } - return nonNumeric -} - -func randU16() uint16 { - var b [2]byte - _, err := rand.Read(b[:]) - if err != nil { - panic(err) - } - return binary.LittleEndian.Uint16(b[:]) -} - -func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { - id = randU16() - b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) - b.EnableCompression() - if err := b.StartQuestions(); err != nil { - return 0, nil, nil, err - } - if err := b.Question(q); err != nil { - return 0, nil, nil, err - } - tcpReq, err = b.Finish() - udpReq = tcpReq[2:] - l := len(tcpReq) - 2 - tcpReq[0] = byte(l >> 8) - tcpReq[1] = byte(l) - return id, udpReq, tcpReq, err -} - -func equalASCIIName(x, y dnsmessage.Name) bool { - if x.Length != y.Length { - return false - } - for i := 0; i < int(x.Length); i++ { - a := x.Data[i] - b := y.Data[i] - if 'A' <= a && a <= 'Z' { - a += 0x20 - } - if 'A' <= b && b <= 'Z' { - b += 0x20 - } - if a != b { - return false - } - } - return true -} - -func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { - if !respHdr.Response { - return false - } - if reqID != respHdr.ID { - return false - } - if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { - return false - } - return true -} - -func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { - if _, err := c.Write(b); err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - b = make([]byte, 512) - for { - n, err := c.Read(b) - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - var p dnsmessage.Parser - h, err := p.Start(b[:n]) - if err != nil { - continue - } - q, err := p.Question() - if err != nil || !checkResponse(id, query, h, q) { - continue - } - return p, h, nil - } -} - -func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { - if _, err := c.Write(b); err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - b = make([]byte, 1280) - if _, err := io.ReadFull(c, b[:2]); err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - l := int(b[0])<<8 | int(b[1]) - if l > len(b) { - b = make([]byte, l) - } - n, err := io.ReadFull(c, b[:l]) - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - var p dnsmessage.Parser - h, err := p.Start(b[:n]) - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage - } - q, err := p.Question() - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage - } - if !checkResponse(id, query, h, q) { - return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse - } - return p, h, nil -} - -func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { - q.Class = dnsmessage.ClassINET - id, udpReq, tcpReq, err := newRequest(q) - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage - } - - for _, useUDP := range []bool{true, false} { - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) - defer cancel() - - var c net.Conn - var err error - if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53}) - } else { - c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53}) - } - - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - if d, ok := ctx.Deadline(); ok && !d.IsZero() { - c.SetDeadline(d) - } - var p dnsmessage.Parser - var h dnsmessage.Header - if useUDP { - p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) - } else { - p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) - } - c.Close() - if err != nil { - if err == context.Canceled { - err = errCanceled - } else if err == context.DeadlineExceeded { - err = errTimeout - } - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { - return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse - } - if h.Truncated { - continue - } - return p, h, nil - } - return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer -} - -func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { - if h.RCode == dnsmessage.RCodeNameError { - return errNoSuchHost - } - _, err := p.AnswerHeader() - if err != nil && err != dnsmessage.ErrSectionDone { - return errCannotUnmarshalDNSMessage - } - if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { - return errLameReferral - } - if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { - if h.RCode == dnsmessage.RCodeServerFailure { - return errServerTemporarilyMisbehaving - } - return errServerMisbehaving - } - return nil -} - -func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { - for { - h, err := p.AnswerHeader() - if err == dnsmessage.ErrSectionDone { - return errNoSuchHost - } - if err != nil { - return errCannotUnmarshalDNSMessage - } - if h.Type == qtype { - return nil - } - if err := p.SkipAnswer(); err != nil { - return errCannotUnmarshalDNSMessage - } - } -} - -func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { - var lastErr error - - n, err := dnsmessage.NewName(name) - if err != nil { - return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage - } - q := dnsmessage.Question{ - Name: n, - Type: qtype, - Class: dnsmessage.ClassINET, - } - - for i := 0; i < 2; i++ { - for _, server := range tnet.dnsServers { - p, h, err := tnet.exchange(ctx, server, q, time.Second*5) - if err != nil { - dnsErr := &net.DNSError{ - Err: err.Error(), - Name: name, - Server: server.String(), - } - if nerr, ok := err.(net.Error); ok && nerr.Timeout() { - dnsErr.IsTimeout = true - } - if _, ok := err.(*net.OpError); ok { - dnsErr.IsTemporary = true - } - lastErr = dnsErr - continue - } - - if err := checkHeader(&p, h); err != nil { - dnsErr := &net.DNSError{ - Err: err.Error(), - Name: name, - Server: server.String(), - } - if err == errServerTemporarilyMisbehaving { - dnsErr.IsTemporary = true - } - if err == errNoSuchHost { - dnsErr.IsNotFound = true - return p, server.String(), dnsErr - } - lastErr = dnsErr - continue - } - - err = skipToAnswer(&p, qtype) - if err == nil { - return p, server.String(), nil - } - lastErr = &net.DNSError{ - Err: err.Error(), - Name: name, - Server: server.String(), - } - if err == errNoSuchHost { - lastErr.(*net.DNSError).IsNotFound = true - return p, server.String(), lastErr - } - } - } - return dnsmessage.Parser{}, "", lastErr -} - -func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) { - if host == "" || (!tnet.hasV6 && !tnet.hasV4) { - return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} - } - zlen := len(host) - if strings.IndexByte(host, ':') != -1 { - if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { - zlen = zidx - } - } - if ip := net.ParseIP(host[:zlen]); ip != nil { - return []string{host[:zlen]}, nil - } - - if !isDomainName(host) { - return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} - } - type result struct { - p dnsmessage.Parser - server string - error - } - var addrsV4, addrsV6 []net.IP - lanes := 0 - if tnet.hasV4 { - lanes++ - } - if tnet.hasV6 { - lanes++ - } - lane := make(chan result, lanes) - var lastErr error - if tnet.hasV4 { - go func() { - p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA) - lane <- result{p, server, err} - }() - } - if tnet.hasV6 { - go func() { - p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA) - lane <- result{p, server, err} - }() - } - for l := 0; l < lanes; l++ { - result := <-lane - if result.error != nil { - if lastErr == nil { - lastErr = result.error - } - continue - } - - loop: - for { - h, err := result.p.AnswerHeader() - if err != nil && err != dnsmessage.ErrSectionDone { - lastErr = &net.DNSError{ - Err: errCannotMarshalDNSMessage.Error(), - Name: host, - Server: result.server, - } - } - if err != nil { - break - } - switch h.Type { - case dnsmessage.TypeA: - a, err := result.p.AResource() - if err != nil { - lastErr = &net.DNSError{ - Err: errCannotMarshalDNSMessage.Error(), - Name: host, - Server: result.server, - } - break loop - } - addrsV4 = append(addrsV4, net.IP(a.A[:])) - - case dnsmessage.TypeAAAA: - aaaa, err := result.p.AAAAResource() - if err != nil { - lastErr = &net.DNSError{ - Err: errCannotMarshalDNSMessage.Error(), - Name: host, - Server: result.server, - } - break loop - } - addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:])) - - default: - if err := result.p.SkipAnswer(); err != nil { - lastErr = &net.DNSError{ - Err: errCannotMarshalDNSMessage.Error(), - Name: host, - Server: result.server, - } - break loop - } - continue - } - } - } - // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled - var addrs []net.IP - if tnet.hasV6 { - addrs = append(addrsV6, addrsV4...) - } else { - addrs = append(addrsV4, addrsV6...) - } - - if len(addrs) == 0 && lastErr != nil { - return nil, lastErr - } - saddrs := make([]string, 0, len(addrs)) - for _, ip := range addrs { - saddrs = append(saddrs, ip.String()) - } - return saddrs, nil -} +func (net *Net) Resolver() *net.Resolver { return net.resolver } func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { if deadline.IsZero() { @@ -748,20 +306,25 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. if err != nil || port < 0 || port > 65535 { return nil, &net.OpError{Op: "dial", Err: errNumericPort} } - allAddr, err := tnet.LookupContextHost(ctx, host) - if err != nil { - return nil, &net.OpError{Op: "dial", Err: err} - } + var addrs []net.IP - for _, addr := range allAddr { - if strings.IndexByte(addr, ':') != -1 && acceptV6 { - addrs = append(addrs, net.ParseIP(addr)) - } else if strings.IndexByte(addr, '.') != -1 && acceptV4 { - addrs = append(addrs, net.ParseIP(addr)) + if addr := net.ParseIP(host); addr != nil { + addrs = []net.IP{addr} + } else { + allAddr, err := tnet.resolver.LookupHost(ctx, host) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + for _, addr := range allAddr { + if strings.IndexByte(addr, ':') != -1 && acceptV6 { + addrs = append(addrs, net.ParseIP(addr)) + } else if strings.IndexByte(addr, '.') != -1 && acceptV4 { + addrs = append(addrs, net.ParseIP(addr)) + } + } + if len(addrs) == 0 && len(allAddr) != 0 { + return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} } - } - if len(addrs) == 0 && len(allAddr) != 0 { - return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} } var firstErr error @@ -816,3 +379,32 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. func (tnet *Net) Dial(network, address string) (net.Conn, error) { return tnet.DialContext(context.Background(), network, address) } + +func (tnet *Net) dialDNS(ctx context.Context, network, address string) (net.Conn, error) { + if len(tnet.dnsServers) == 0 { + return tnet.DialContext(ctx, network, address) + } + + dnsIPs := make(map[string]struct{}, len(tnet.dnsServers)) + for _, dnsServer := range tnet.dnsServers { + ipAddress := dnsServer.String() + if host, _, err := net.SplitHostPort(address); err != nil && host == ipAddress { + return tnet.DialContext(ctx, network, address) + } + dnsIPs[ipAddress] = struct{}{} + } + + var lastErr error + for ipAddress := range dnsIPs { + conn, err := tnet.DialContext(ctx, network, net.JoinHostPort(ipAddress, "53")) + if err != nil { + if nerr, ok := err.(*net.OpError); ok { + lastErr = nerr + continue + } + return nil, err + } + return conn, nil + } + return nil, lastErr +} diff --git a/tun/wintun/dll_fromrsrc_windows.go b/tun/wintun/dll_fromrsrc_windows.go index d107ba98b..dc70486fd 100644 --- a/tun/wintun/dll_fromrsrc_windows.go +++ b/tun/wintun/dll_fromrsrc_windows.go @@ -16,7 +16,6 @@ import ( "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/tun/wintun/memmod" - "golang.zx2c4.com/wireguard/tun/wintun/resource" ) type lazyDLL struct { @@ -37,11 +36,11 @@ func (d *lazyDLL) Load() error { } const ourModule windows.Handle = 0 - resInfo, err := resource.FindByName(ourModule, d.Name, resource.RT_RCDATA) + resInfo, err := windows.FindResource(ourModule, d.Name, windows.RT_RCDATA) if err != nil { return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err) } - data, err := resource.Load(ourModule, resInfo) + data, err := windows.LoadResourceData(ourModule, resInfo) if err != nil { return fmt.Errorf("Unable to load resource: %w", err) } diff --git a/tun/wintun/memmod/memmod_windows.go b/tun/wintun/memmod/memmod_windows.go index a9514c400..c75de5ade 100644 --- a/tun/wintun/memmod/memmod_windows.go +++ b/tun/wintun/memmod/memmod_windows.go @@ -312,7 +312,7 @@ func (module *Module) buildImportTable() error { module.modules = make([]windows.Handle, 0, 16) importDesc := (*IMAGE_IMPORT_DESCRIPTOR)(a2p(module.codeBase + uintptr(directory.VirtualAddress))) - for !isBadReadPtr(uintptr(unsafe.Pointer(importDesc)), unsafe.Sizeof(*importDesc)) && importDesc.Name != 0 { + for importDesc.Name != 0 { handle, err := windows.LoadLibraryEx(windows.BytePtrToString((*byte)(a2p(module.codeBase+uintptr(importDesc.Name)))), 0, windows.LOAD_LIBRARY_SEARCH_SYSTEM32) if err != nil { return fmt.Errorf("Error loading module: %w", err) diff --git a/tun/wintun/memmod/mksyscall.go b/tun/wintun/memmod/mksyscall.go deleted file mode 100644 index a78f613e9..000000000 --- a/tun/wintun/memmod/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package memmod - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go diff --git a/tun/wintun/memmod/syscall_windows.go b/tun/wintun/memmod/syscall_windows.go index 11715c03b..31dd0b5d9 100644 --- a/tun/wintun/memmod/syscall_windows.go +++ b/tun/wintun/memmod/syscall_windows.go @@ -324,8 +324,6 @@ const ( DLL_PROCESS_DETACH = 0 ) -//sys isBadReadPtr(addr uintptr, ucb uintptr) (ret bool) = kernel32.IsBadReadPtr - type SYSTEM_INFO struct { ProcessorArchitecture uint16 Reserved uint16 diff --git a/tun/wintun/memmod/zsyscall_windows.go b/tun/wintun/memmod/zsyscall_windows.go deleted file mode 100644 index 6a5b76f59..000000000 --- a/tun/wintun/memmod/zsyscall_windows.go +++ /dev/null @@ -1,50 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package memmod - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) - errERROR_EINVAL error = syscall.EINVAL -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return errERROR_EINVAL - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - - procIsBadReadPtr = modkernel32.NewProc("IsBadReadPtr") -) - -func isBadReadPtr(addr uintptr, ucb uintptr) (ret bool) { - r0, _, _ := syscall.Syscall(procIsBadReadPtr.Addr(), 2, uintptr(addr), uintptr(ucb), 0) - ret = r0 != 0 - return -} diff --git a/tun/wintun/resource/mksyscall.go b/tun/wintun/resource/mksyscall.go deleted file mode 100644 index 2013f24f2..000000000 --- a/tun/wintun/resource/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package resource - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go resource_windows.go diff --git a/tun/wintun/resource/resource_windows.go b/tun/wintun/resource/resource_windows.go deleted file mode 100644 index f63751830..000000000 --- a/tun/wintun/resource/resource_windows.go +++ /dev/null @@ -1,143 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package resource - -import ( - "errors" - "fmt" - "unsafe" - - "golang.org/x/sys/windows" -) - -func MAKEINTRESOURCE(i uint16) *uint16 { - return (*uint16)(unsafe.Pointer(uintptr(i))) -} - -// Predefined Resource Types -var ( - VS_VERSION_INFO uint16 = 1 - - RT_CURSOR = MAKEINTRESOURCE(1) - RT_BITMAP = MAKEINTRESOURCE(2) - RT_ICON = MAKEINTRESOURCE(3) - RT_MENU = MAKEINTRESOURCE(4) - RT_DIALOG = MAKEINTRESOURCE(5) - RT_STRING = MAKEINTRESOURCE(6) - RT_FONTDIR = MAKEINTRESOURCE(7) - RT_FONT = MAKEINTRESOURCE(8) - RT_ACCELERATOR = MAKEINTRESOURCE(9) - RT_RCDATA = MAKEINTRESOURCE(10) - RT_MESSAGETABLE = MAKEINTRESOURCE(11) - RT_GROUP_CURSOR = MAKEINTRESOURCE(12) - RT_GROUP_ICON = MAKEINTRESOURCE(14) - RT_VERSION = MAKEINTRESOURCE(16) - RT_DLGINCLUDE = MAKEINTRESOURCE(17) - RT_PLUGPLAY = MAKEINTRESOURCE(19) - RT_VXD = MAKEINTRESOURCE(20) - RT_ANICURSOR = MAKEINTRESOURCE(21) - RT_ANIICON = MAKEINTRESOURCE(22) - RT_HTML = MAKEINTRESOURCE(23) - RT_MANIFEST = MAKEINTRESOURCE(24) - CREATEPROCESS_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(1) - ISOLATIONAWARE_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(2) - ISOLATIONAWARE_NOSTATICIMPORT_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(3) - ISOLATIONPOLICY_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(4) - ISOLATIONPOLICY_BROWSER_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(5) - MINIMUM_RESERVED_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(1 /*inclusive*/) - MAXIMUM_RESERVED_MANIFEST_RESOURCE_ID = MAKEINTRESOURCE(16 /*inclusive*/) -) - -//sys findResource(module windows.Handle, name *uint16, resType *uint16) (resInfo windows.Handle, err error) = kernel32.FindResourceW - -func FindByID(module windows.Handle, id uint16, resType *uint16) (resInfo windows.Handle, err error) { - return findResource(module, MAKEINTRESOURCE(id), resType) -} - -func FindByName(module windows.Handle, name string, resType *uint16) (resInfo windows.Handle, err error) { - var name16 *uint16 - name16, err = windows.UTF16PtrFromString(name) - if err != nil { - return - } - resInfo, err = findResource(module, name16, resType) - return -} - -//sys sizeofResource(module windows.Handle, resInfo windows.Handle) (size uint32, err error) = kernel32.SizeofResource -//sys loadResource(module windows.Handle, resInfo windows.Handle) (resData windows.Handle, err error) = kernel32.LoadResource -//sys lockResource(resData windows.Handle) (addr uintptr, err error) = kernel32.LockResource - -func Load(module, resInfo windows.Handle) (data []byte, err error) { - size, err := sizeofResource(module, resInfo) - if err != nil { - err = fmt.Errorf("Unable to size resource: %w", err) - return - } - resData, err := loadResource(module, resInfo) - if err != nil { - err = fmt.Errorf("Unable to load resource: %w", err) - return - } - ptr, err := lockResource(resData) - if err != nil { - err = fmt.Errorf("Unable to lock resource: %w", err) - return - } - unsafeSlice(unsafe.Pointer(&data), unsafe.Pointer(ptr), int(size)) - return -} - -type VS_FIXEDFILEINFO struct { - Signature uint32 - StrucVersion uint32 - FileVersionMS uint32 - FileVersionLS uint32 - ProductVersionMS uint32 - ProductVersionLS uint32 - FileFlagsMask uint32 - FileFlags uint32 - FileOS uint32 - FileType uint32 - FileSubtype uint32 - FileDateMS uint32 - FileDateLS uint32 -} - -//sys verQueryValue(block *byte, section *uint16, value **byte, size *uint32) (err error) = version.VerQueryValueW - -func VerQueryRootValue(block []byte) (ffi *VS_FIXEDFILEINFO, err error) { - var data *byte - var size uint32 - err = verQueryValue(&block[0], windows.StringToUTF16Ptr("\\"), &data, &size) - if err != nil { - return - } - if uintptr(size) < unsafe.Sizeof(VS_FIXEDFILEINFO{}) { - err = errors.New("Incomplete VS_FIXEDFILEINFO") - return - } - ffi = (*VS_FIXEDFILEINFO)(unsafe.Pointer(data)) - return -} - -// unsafeSlice updates the slice slicePtr to be a slice -// referencing the provided data with its length & capacity set to -// lenCap. -// -// TODO: when Go 1.16 or Go 1.17 is the minimum supported version, -// update callers to use unsafe.Slice instead of this. -func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) { - type sliceHeader struct { - Data unsafe.Pointer - Len int - Cap int - } - h := (*sliceHeader)(slicePtr) - h.Data = data - h.Len = lenCap - h.Cap = lenCap -} diff --git a/tun/wintun/resource/zsyscall_windows.go b/tun/wintun/resource/zsyscall_windows.go deleted file mode 100644 index e4c4bf1c3..000000000 --- a/tun/wintun/resource/zsyscall_windows.go +++ /dev/null @@ -1,112 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package resource - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - modversion = windows.NewLazySystemDLL("version.dll") - - procFindResourceW = modkernel32.NewProc("FindResourceW") - procSizeofResource = modkernel32.NewProc("SizeofResource") - procLoadResource = modkernel32.NewProc("LoadResource") - procLockResource = modkernel32.NewProc("LockResource") - procVerQueryValueW = modversion.NewProc("VerQueryValueW") -) - -func findResource(module windows.Handle, name *uint16, resType *uint16) (resInfo windows.Handle, err error) { - r0, _, e1 := syscall.Syscall(procFindResourceW.Addr(), 3, uintptr(module), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(resType))) - resInfo = windows.Handle(r0) - if resInfo == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func sizeofResource(module windows.Handle, resInfo windows.Handle) (size uint32, err error) { - r0, _, e1 := syscall.Syscall(procSizeofResource.Addr(), 2, uintptr(module), uintptr(resInfo), 0) - size = uint32(r0) - if size == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func loadResource(module windows.Handle, resInfo windows.Handle) (resData windows.Handle, err error) { - r0, _, e1 := syscall.Syscall(procLoadResource.Addr(), 2, uintptr(module), uintptr(resInfo), 0) - resData = windows.Handle(r0) - if resData == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func lockResource(resData windows.Handle) (addr uintptr, err error) { - r0, _, e1 := syscall.Syscall(procLockResource.Addr(), 1, uintptr(resData), 0, 0) - addr = uintptr(r0) - if addr == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func verQueryValue(block *byte, section *uint16, value **byte, size *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procVerQueryValueW.Addr(), 4, uintptr(unsafe.Pointer(block)), uintptr(unsafe.Pointer(section)), uintptr(unsafe.Pointer(value)), uintptr(unsafe.Pointer(size)), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -}