diff --git a/src/fs/file_io_ext.rs b/src/fs/file_io_ext.rs index 6094f2b..36f4622 100644 --- a/src/fs/file_io_ext.rs +++ b/src/fs/file_io_ext.rs @@ -25,7 +25,7 @@ use rustix::io::{preadv, pwritev}; use std::io::{self, IoSlice, IoSliceMut, Seek, SeekFrom}; use std::slice; #[cfg(windows)] -use {cap_fs_ext::Reopen, fd_lock::RwLock, std::fs, std::os::windows::fs::FileExt}; +use {cap_fs_ext::Reopen, std::fs, std::os::windows::fs::FileExt}; #[cfg(not(windows))] use {rustix::fs::tell, rustix::fs::FileExt}; @@ -703,144 +703,37 @@ impl FileIoExt for std::fs::File { fn write_at(&self, buf: &[u8], offset: u64) -> io::Result { // Windows' `seek_write` modifies the current position in the file, so // re-open the file to leave the original open file unmodified. - // - // We take a lock so that we can test for writing past the end of the - // file and implement writing zeros if the offset is past the end. - // - // Windows documentation [says]: - // - // > A write operation increases the size of the file to the file - // > pointer position plus the size of the buffer written, which results - // > in the intervening bytes uninitialized. - // - // [says]: https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-setfilepointer#remarks - // - // However, our desired behavior is to fill the intervening bytes with - // zeros, so we do it ourselves. - let reopened = reopen_write(self)?; - let mut reopened = RwLock::new(reopened); - let reopened = reopened.write()?; - let mut seek_offset = offset; - let mut write_buf = buf; - let mut prepend_zeros; - let reopened_size = reopened.metadata()?.len(); - let num_zeros = offset.saturating_sub(reopened_size); - let num_zeros: usize = num_zeros - .try_into() - .map_err(|_| io::Error::new(io::ErrorKind::OutOfMemory, "write_all_vectored_at"))?; - if num_zeros > 0 { - prepend_zeros = vec![0_u8; num_zeros]; - prepend_zeros.extend_from_slice(buf); - seek_offset = reopened_size; - write_buf = &prepend_zeros; - } - let num_written = reopened - .seek_write(write_buf, seek_offset)? - .saturating_sub(num_zeros); - Ok(num_written) + reopen_write(self)?.seek_write(buf, offset) } #[inline] - fn write_all_at(&self, buf: &[u8], offset: u64) -> io::Result<()> { - // Similar to `read_exact_at`, re-open the file so that we can do a seek and - // leave the original file unmodified. - // - // Similar to `write_at`, we take a lock to do this. - let reopened = loop { - match reopen_write(self) { - Ok(file) => break file, - Err(err) if err.kind() == io::ErrorKind::Interrupted => continue, - Err(err) => return Err(err), - } - }; - let mut reopened = RwLock::new(reopened); - let reopened = reopened.write()?; - let mut seek_offset = offset; - let mut write_buf = buf; - let mut prepend_zeros; - let reopened_size = reopened.metadata()?.len(); - let num_zeros = offset.saturating_sub(reopened_size); - let num_zeros: usize = num_zeros - .try_into() - .map_err(|_| io::Error::new(io::ErrorKind::OutOfMemory, "write_all_vectored_at"))?; - if num_zeros > 0 { - prepend_zeros = vec![0_u8; num_zeros]; - prepend_zeros.extend_from_slice(buf); - seek_offset = reopened_size; - write_buf = &prepend_zeros; - } - loop { - match reopened.seek(SeekFrom::Start(seek_offset)) { - Ok(_) => break, - Err(err) if err.kind() == io::ErrorKind::Interrupted => continue, - Err(err) => return Err(err), - } + fn write_all_at(&self, mut buf: &[u8], mut offset: u64) -> io::Result<()> { + // Similar to `read_exact_at`, re-open the file so that we can do a seek + // and leave the original file unmodified. + let reopened = reopen_write(self)?; + while buf.len() > 0 { + let n = reopened.seek_write(buf, offset)?; + offset += u64::try_from(n).unwrap(); + buf = &buf[n..]; } - reopened.write_all(write_buf)?; Ok(()) } #[inline] fn write_vectored_at(&self, bufs: &[IoSlice], offset: u64) -> io::Result { - // Similar to `read_vectored_at`, re-open the file to avoid adjusting - // the current position of the already-open file. - // - // Similar to `write_at`, we take a lock to do this. - let reopened = reopen_write(self)?; - let mut reopened = RwLock::new(reopened); - let reopened = reopened.write()?; - let mut seek_offset = offset; - let mut write_bufs = bufs; - let zeros; - let mut prepend_zeros; - let reopened_size = reopened.metadata()?.len(); - let num_zeros = offset.saturating_sub(reopened_size); - let num_zeros: usize = num_zeros - .try_into() - .map_err(|_| io::Error::new(io::ErrorKind::OutOfMemory, "write_vectored_at"))?; - if num_zeros > 0 { - zeros = vec![0_u8; num_zeros]; - prepend_zeros = vec![IoSlice::new(&zeros)]; - prepend_zeros.extend_from_slice(bufs); - seek_offset = reopened_size; - write_bufs = &prepend_zeros; + // Windows doesn't have a vectored write for files, so pick the first + // non-empty slice and write that. + match bufs.iter().find(|p| p.len() > 0) { + Some(buf) => self.write_at(buf, offset), + None => Ok(0), } - reopened.seek(SeekFrom::Start(seek_offset))?; - let num_written = reopened - .write_vectored(write_bufs)? - .saturating_sub(num_zeros); - Ok(num_written) } #[inline] - fn write_all_vectored_at(&self, bufs: &mut [IoSlice], offset: u64) -> io::Result<()> { - // Similar to `read_vectored_at`, re-open the file to avoid adjusting - // the current position of the already-open file. - // - // Similar to `write_at`, we take a lock to do this. - let reopened = loop { - match reopen_write(self) { - Ok(file) => break file, - Err(err) if err.kind() == io::ErrorKind::Interrupted => continue, - Err(err) => return Err(err), - } - }; - let mut reopened = RwLock::new(reopened); - let reopened = reopened.write()?; - let reopened_size = reopened.metadata()?.len(); - let num_zeros = offset.saturating_sub(reopened_size); - let num_zeros: usize = num_zeros - .try_into() - .map_err(|_| io::Error::new(io::ErrorKind::OutOfMemory, "write_vectored_at"))?; - if num_zeros > 0 { - let zeros = vec![0_u8; num_zeros]; - let mut prepend_zeros = vec![IoSlice::new(&zeros)]; - prepend_zeros.extend_from_slice(bufs); - reopened.seek(SeekFrom::Start(reopened_size))?; - reopened.write_all_vectored(&mut prepend_zeros)?; - } else { - reopened.seek(SeekFrom::Start(offset))?; - reopened.write_all_vectored(bufs)?; + fn write_all_vectored_at(&self, bufs: &mut [IoSlice], mut offset: u64) -> io::Result<()> { + for buf in bufs { + self.write_all_at(buf, offset)?; + offset += u64::try_from(buf.len()).unwrap(); } Ok(()) } diff --git a/tests/vectored_at.rs b/tests/vectored_at.rs index 5917064..14efb6f 100644 --- a/tests/vectored_at.rs +++ b/tests/vectored_at.rs @@ -412,6 +412,7 @@ fn write_vectored_after_end() { let buf1 = b"MNOPQRST".to_vec(); let bufs = vec![IoSlice::new(&buf0), IoSlice::new(&buf1)]; let nwritten = check!(file.write_vectored_at(&bufs, 32)); + assert!(nwritten > 0); assert_eq!(check!(file.stream_position()), 26); let mut back = String::new(); check!(file.seek(std::io::SeekFrom::Start(0)));