Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows: Don't error on broken non UTF-8 output #134534

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion library/std/src/io/buffered/linewritershim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,14 @@ impl<'a, W: ?Sized + Write> Write for LineWriterShim<'a, W> {
// the buffer?
// - If not, scan for the last newline that *does* fit in the buffer
let tail = if flushed >= newline_idx {
&buf[flushed..]
let tail = &buf[flushed..];
// Avoid unnecessary short writes by not splitting the remaining
// bytes if they're larger than the buffer.
// They can be written in full by the next call to write.
if tail.len() >= self.buffer.capacity() {
return Ok(flushed);
}
tail
} else if newline_idx - flushed <= self.buffer.capacity() {
&buf[flushed..newline_idx]
} else {
Expand Down
18 changes: 13 additions & 5 deletions library/std/src/io/buffered/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -847,22 +847,30 @@ fn long_line_flushed() {
}

/// Test that, given a very long partial line *after* successfully
/// flushing a complete line, the very long partial line is buffered
/// unconditionally, and no additional writes take place. This assures
/// flushing a complete line, no additional writes take place. This assures
/// the property that `write` should make at-most-one attempt to write
/// new data.
#[test]
fn line_long_tail_not_flushed() {
let writer = ProgrammableSink::default();
let mut writer = LineWriter::with_capacity(5, writer);

// Assert that Line 1\n is flushed, and 01234 is buffered
assert_eq!(writer.write(b"Line 1\n0123456789").unwrap(), 12);
// Assert that Line 1\n is flushed and the long tail isn't.
let bytes = b"Line 1\n0123456789";
writer.write(bytes).unwrap();
assert_eq!(&writer.get_ref().buffer, b"Line 1\n");
}

// Test that appending to a full buffer emits a single write, flushing the buffer.
#[test]
fn line_full_buffer_flushed() {
let writer = ProgrammableSink::default();
let mut writer = LineWriter::with_capacity(5, writer);
assert_eq!(writer.write(b"01234").unwrap(), 5);

// Because the buffer is full, this subsequent write will flush it
assert_eq!(writer.write(b"5").unwrap(), 1);
assert_eq!(&writer.get_ref().buffer, b"Line 1\n01234");
assert_eq!(&writer.get_ref().buffer, b"01234");
}

/// Test that, if an attempt to pre-flush buffered data returns Ok(0),
Expand Down
1 change: 1 addition & 0 deletions library/std/src/sys/pal/windows/c/bindings.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2425,6 +2425,7 @@ Windows.Win32.System.Console.ENABLE_VIRTUAL_TERMINAL_PROCESSING
Windows.Win32.System.Console.ENABLE_WINDOW_INPUT
Windows.Win32.System.Console.ENABLE_WRAP_AT_EOL_OUTPUT
Windows.Win32.System.Console.GetConsoleMode
Windows.Win32.System.Console.GetConsoleOutputCP
Windows.Win32.System.Console.GetStdHandle
Windows.Win32.System.Console.ReadConsoleW
Windows.Win32.System.Console.STD_ERROR_HANDLE
Expand Down
2 changes: 2 additions & 0 deletions library/std/src/sys/pal/windows/c/windows_sys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ windows_targets::link!("kernel32.dll" "system" fn FreeEnvironmentStringsW(penv :
windows_targets::link!("kernel32.dll" "system" fn GetActiveProcessorCount(groupnumber : u16) -> u32);
windows_targets::link!("kernel32.dll" "system" fn GetCommandLineW() -> PCWSTR);
windows_targets::link!("kernel32.dll" "system" fn GetConsoleMode(hconsolehandle : HANDLE, lpmode : *mut CONSOLE_MODE) -> BOOL);
windows_targets::link!("kernel32.dll" "system" fn GetConsoleOutputCP() -> u32);
windows_targets::link!("kernel32.dll" "system" fn GetCurrentDirectoryW(nbufferlength : u32, lpbuffer : PWSTR) -> u32);
windows_targets::link!("kernel32.dll" "system" fn GetCurrentProcess() -> HANDLE);
windows_targets::link!("kernel32.dll" "system" fn GetCurrentProcessId() -> u32);
Expand Down Expand Up @@ -3317,6 +3318,7 @@ pub struct XSAVE_FORMAT {
pub XmmRegisters: [M128A; 8],
pub Reserved4: [u8; 224],
}

#[cfg(target_arch = "arm")]
#[repr(C)]
pub struct WSADATA {
Expand Down
216 changes: 93 additions & 123 deletions library/std/src/sys/pal/windows/stdio.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
#![unstable(issue = "none", feature = "windows_stdio")]

use core::str::utf8_char_width;

use super::api::{self, WinError};
use crate::mem::MaybeUninit;
use crate::os::windows::io::{FromRawHandle, IntoRawHandle};
use crate::sys::handle::Handle;
use crate::sys::{c, cvt};
use crate::{cmp, io, ptr, str};
use crate::{cmp, io, ptr};

#[cfg(test)]
mod tests;
Expand All @@ -19,13 +17,9 @@ pub struct Stdin {
incomplete_utf8: IncompleteUtf8,
}

pub struct Stdout {
incomplete_utf8: IncompleteUtf8,
}
pub struct Stdout {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for the removal of incomplete UTF-8 handling at the end of the string is not clear to me from the commit description. Why was that removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because it now simply truncates the write to remove incomplete UTF-8 from the end and instead leaves the buffering to buffer types, i.e. LineWriter in this case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A LineWriter will flush an incomplete line if its buffer capacity is exceeded. If that happens, the output must support partial UTF-8 writes, or non-ASCII characters might get lost or replaced with the replacement character.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That can only result in broken UTF-8 if the user writes incomplete UTF-8 to LineWriter themselves.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, digging through the source code, BufWriter makes sure to not split writes that the user issued.

What is the motivation for truncating invalid UTF-8 at the end of the string?

All else being equal, I'd rather expect the previous behavior, that I can construct UTF-8 output byte-by-byte.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than having a secret stack buffer that can't be inspected or flushed, I'd strongly prefer buffering be done at a higher level. It's also a lot of added complexity for an edge case where the better solution is to set the console code page.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In any case, if that behavior is wanted, it should probably be documented in the commit message so that it is clear to future readers that this change was on purpose.

Rather than ignoring trailing invalid UTF-8, I think it'd be better to replace it with a replacement character so that it becomes clear that something was removed.

Copy link
Member Author

@ChrisDenton ChrisDenton Dec 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than ignoring trailing invalid UTF-8, I think it'd be better to replace it with a replacement character so that it becomes clear that something was removed.

That's what happens in this code. No bytes are ever lost. Either the caller is informed that less bytes were written than were provided or, if there is only an incomplete code point, then that is written to the console (which will be converted to replacement characters when lossy translating to UTF-16).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, that makes sense. I didn't realize the caller would be informed by the return value of write.

What is the motivation for special casing trailing invalid UTF-8? It seems to increase the code complexity a little as well, and is not necessary for std's own use cases.

Is it for supporting a potential non-std buffered writer?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, it could be removed. Stderr is not buffered by us though and there have been proposals for unbuffered stdout.


pub struct Stderr {
incomplete_utf8: IncompleteUtf8,
}
pub struct Stderr {}

struct IncompleteUtf8 {
bytes: [u8; 4],
Expand Down Expand Up @@ -84,140 +78,69 @@ fn is_console(handle: c::HANDLE) -> bool {
unsafe { c::GetConsoleMode(handle, &mut mode) != 0 }
}

fn write(handle_id: u32, data: &[u8], incomplete_utf8: &mut IncompleteUtf8) -> io::Result<usize> {
/// Returns true if the attached console's code page is currently UTF-8.
#[cfg(not(target_vendor = "win7"))]
fn is_utf8_console() -> bool {
unsafe { c::GetConsoleOutputCP() == c::CP_UTF8 }
}

#[cfg(target_vendor = "win7")]
fn is_utf8_console() -> bool {
// Windows 7 has a fun "feature" where WriteFile on a console handle will return
// the number of UTF-16 code units written and not the number of bytes from the input string.
// So we always claim the console isn't UTF-8 to trigger the WriteConsole fallback code.
false
}

fn write(handle_id: u32, data: &[u8]) -> io::Result<usize> {
if data.is_empty() {
return Ok(0);
}

let handle = get_handle(handle_id)?;
if !is_console(handle) {
if !is_console(handle) || is_utf8_console() {
unsafe {
let handle = Handle::from_raw_handle(handle);
let ret = handle.write(data);
let _ = handle.into_raw_handle(); // Don't close the handle
return ret;
}
} else {
write_console_utf16(data, handle)
}

if incomplete_utf8.len > 0 {
assert!(
incomplete_utf8.len < 4,
"Unexpected number of bytes for incomplete UTF-8 codepoint."
);
if data[0] >> 6 != 0b10 {
// not a continuation byte - reject
incomplete_utf8.len = 0;
return Err(io::const_error!(
io::ErrorKind::InvalidData,
"Windows stdio in console mode does not support writing non-UTF-8 byte sequences",
));
}
incomplete_utf8.bytes[incomplete_utf8.len as usize] = data[0];
incomplete_utf8.len += 1;
let char_width = utf8_char_width(incomplete_utf8.bytes[0]);
if (incomplete_utf8.len as usize) < char_width {
// more bytes needed
return Ok(1);
}
let s = str::from_utf8(&incomplete_utf8.bytes[0..incomplete_utf8.len as usize]);
incomplete_utf8.len = 0;
match s {
Ok(s) => {
assert_eq!(char_width, s.len());
let written = write_valid_utf8_to_console(handle, s)?;
assert_eq!(written, s.len()); // guaranteed by write_valid_utf8_to_console() for single codepoint writes
return Ok(1);
}
Err(_) => {
return Err(io::const_error!(
io::ErrorKind::InvalidData,
"Windows stdio in console mode does not support writing non-UTF-8 byte sequences",
));
}
}
}

// As the console is meant for presenting text, we assume bytes of `data` are encoded as UTF-8,
// which needs to be encoded as UTF-16.
//
// If the data is not valid UTF-8 we write out as many bytes as are valid.
// If the first byte is invalid it is either first byte of a multi-byte sequence but the
// provided byte slice is too short or it is the first byte of an invalid multi-byte sequence.
let len = cmp::min(data.len(), MAX_BUFFER_SIZE / 2);
let utf8 = match str::from_utf8(&data[..len]) {
Ok(s) => s,
Err(ref e) if e.valid_up_to() == 0 => {
let first_byte_char_width = utf8_char_width(data[0]);
if first_byte_char_width > 1 && data.len() < first_byte_char_width {
incomplete_utf8.bytes[0] = data[0];
incomplete_utf8.len = 1;
return Ok(1);
} else {
return Err(io::const_error!(
io::ErrorKind::InvalidData,
"Windows stdio in console mode does not support writing non-UTF-8 byte sequences",
));
}
}
Err(e) => str::from_utf8(&data[..e.valid_up_to()]).unwrap(),
};

write_valid_utf8_to_console(handle, utf8)
}

fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usize> {
debug_assert!(!utf8.is_empty());

let mut utf16 = [MaybeUninit::<u16>::uninit(); MAX_BUFFER_SIZE / 2];
let utf8 = &utf8[..utf8.floor_char_boundary(utf16.len())];
fn write_console_utf16(data: &[u8], handle: c::HANDLE) -> io::Result<usize> {
let mut buffer = [MaybeUninit::<u16>::uninit(); MAX_BUFFER_SIZE / 2];
let data = &data[..data.len().min(buffer.len())];

// Split off any trailing incomplete UTF-8 from the end of the input.
let utf8 = trim_last_char_boundary(data);
let utf16 = utf8_to_utf16_lossy(utf8, &mut buffer);
debug_assert!(!utf16.is_empty());

// Write the UTF-16 chars to the console.
// This will succeed in one write so long as our [u16] slice is smaller than the console's buffer,
// which we've ensured by truncating the input (see `MAX_BUFFER_SIZE`).
let written = write_u16s(handle, &utf16)?;
debug_assert_eq!(written, utf16.len());
Ok(utf8.len())
}

let utf16: &[u16] = unsafe {
// Note that this theoretically checks validity twice in the (most common) case
// where the underlying byte sequence is valid utf-8 (given the check in `write()`).
fn utf8_to_utf16_lossy<'a>(utf8: &[u8], utf16: &'a mut [MaybeUninit<u16>]) -> &'a [u16] {
unsafe {
let result = c::MultiByteToWideChar(
c::CP_UTF8, // CodePage
c::MB_ERR_INVALID_CHARS, // dwFlags
0, // dwFlags
utf8.as_ptr(), // lpMultiByteStr
utf8.len() as i32, // cbMultiByte
utf16.as_mut_ptr() as *mut c::WCHAR, // lpWideCharStr
utf16.len() as i32, // cchWideChar
);
assert!(result != 0, "Unexpected error in MultiByteToWideChar");

// The only way an error can happen here is if we've messed up.
debug_assert!(result != 0, "Unexpected error in MultiByteToWideChar");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
debug_assert!(result != 0, "Unexpected error in MultiByteToWideChar");
assert!(result != 0, "Unexpected error in MultiByteToWideChar");

I think this should be an assert since this isn't performance critical — we've just done a syscall.

// Safety: MultiByteToWideChar initializes `result` values.
MaybeUninit::slice_assume_init_ref(&utf16[..result as usize])
};

let mut written = write_u16s(handle, utf16)?;

// Figure out how many bytes of as UTF-8 were written away as UTF-16.
if written == utf16.len() {
Ok(utf8.len())
} else {
// Make sure we didn't end up writing only half of a surrogate pair (even though the chance
// is tiny). Because it is not possible for user code to re-slice `data` in such a way that
// a missing surrogate can be produced (and also because of the UTF-8 validation above),
// write the missing surrogate out now.
// Buffering it would mean we have to lie about the number of bytes written.
let first_code_unit_remaining = utf16[written];
if matches!(first_code_unit_remaining, 0xDCEE..=0xDFFF) {
// low surrogate
// We just hope this works, and give up otherwise
let _ = write_u16s(handle, &utf16[written..written + 1]);
written += 1;
}
// Calculate the number of bytes of `utf8` that were actually written.
let mut count = 0;
for ch in utf16[..written].iter() {
count += match ch {
0x0000..=0x007F => 1,
0x0080..=0x07FF => 2,
0xDCEE..=0xDFFF => 1, // Low surrogate. We already counted 3 bytes for the other.
_ => 3,
};
}
debug_assert!(String::from_utf16(&utf16[..written]).unwrap() == utf8[..count]);
Ok(count)
}
}

Expand Down Expand Up @@ -410,13 +333,13 @@ impl IncompleteUtf8 {

impl Stdout {
pub const fn new() -> Stdout {
Stdout { incomplete_utf8: IncompleteUtf8::new() }
Stdout {}
}
}

impl io::Write for Stdout {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
write(c::STD_OUTPUT_HANDLE, buf, &mut self.incomplete_utf8)
write(c::STD_OUTPUT_HANDLE, buf)
}

fn flush(&mut self) -> io::Result<()> {
Expand All @@ -426,13 +349,13 @@ impl io::Write for Stdout {

impl Stderr {
pub const fn new() -> Stderr {
Stderr { incomplete_utf8: IncompleteUtf8::new() }
Stderr {}
}
}

impl io::Write for Stderr {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
write(c::STD_ERROR_HANDLE, buf, &mut self.incomplete_utf8)
write(c::STD_ERROR_HANDLE, buf)
}

fn flush(&mut self) -> io::Result<()> {
Expand All @@ -447,3 +370,50 @@ pub fn is_ebadf(err: &io::Error) -> bool {
pub fn panic_output() -> Option<impl io::Write> {
Some(Stderr::new())
}

/// Trim one incomplete UTF-8 char from the end of a byte slice.
///
/// If trimming would lead to an empty slice then it returns `bytes` instead.
///
/// Note: This function is optimized for size rather than speed.
pub fn trim_last_char_boundary(bytes: &[u8]) -> &[u8] {
// UTF-8's multiple-byte encoding uses the leading bits to encode the length of a code point.
// The bits of a multi-byte sequence are (where `n` is a placeholder for any bit):
//
// 11110nnn 10nnnnnn 10nnnnnn 10nnnnnn
// 1110nnnn 10nnnnnn 10nnnnnn
// 110nnnnn 10nnnnnn
//
// So if follows that an incomplete sequence is one of these:
// 11110nnn 10nnnnnn 10nnnnnn
// 11110nnn 10nnnnnn
// 1110nnnn 10nnnnnn
// 11110nnn
// 1110nnnn
// 110nnnnn

// Get up to three bytes from the end of the slice and encode them as a u32
// because it turns out the compiler is very good at optimizing numbers.
let u = match bytes {
[.., b1, b2, b3] => (*b1 as u32) << 16 | (*b2 as u32) << 8 | *b3 as u32,
[.., b1, b2] => (*b1 as u32) << 8 | *b2 as u32,
// If it's just a single byte or empty then we return the full slice
_ => return bytes,
};
if (u & 0b_11111000_11000000_11000000 == 0b_11110000_10000000_10000000) && bytes.len() >= 4 {
&bytes[..bytes.len() - 3]
} else if (u & 0b_11111000_11000000 == 0b_11110000_10000000
|| u & 0b_11110000_11000000 == 0b_11100000_10000000)
&& bytes.len() >= 3
{
&bytes[..bytes.len() - 2]
} else if (u & 0b_1111_1000 == 0b_1111_0000
|| u & 0b_11110000 == 0b_11100000
|| u & 0b_11100000 == 0b_11000000)
&& bytes.len() >= 2
{
&bytes[..bytes.len() - 1]
} else {
bytes
}
}
Loading