Skip to content

Commit

Permalink
io: remove unsafe from ReadToString (#2384)
Browse files Browse the repository at this point in the history
  • Loading branch information
goffrie authored Apr 21, 2020
1 parent 7e88b56 commit 2da15b5
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 15 deletions.
2 changes: 1 addition & 1 deletion tokio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ default-features = false
optional = true

[dev-dependencies]
tokio-test = { version = "0.2.0" }
tokio-test = { version = "0.2.0", path = "../tokio-test" }
futures = { version = "0.3.0", features = ["async-await"] }
proptest = "0.9.4"
tempfile = "3.1.0"
Expand Down
36 changes: 22 additions & 14 deletions tokio/src/io/util/read_to_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::io::AsyncRead;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, mem, str};
use std::{io, mem};

cfg_io_util! {
/// Future for the [`read_to_string`](super::AsyncReadExt::read_to_string) method.
Expand All @@ -25,7 +25,7 @@ where
let start_len = buf.len();
ReadToString {
reader,
bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) },
bytes: mem::replace(buf, String::new()).into_bytes(),
buf,
start_len,
}
Expand All @@ -38,19 +38,20 @@ fn read_to_string_internal<R: AsyncRead + ?Sized>(
bytes: &mut Vec<u8>,
start_len: usize,
) -> Poll<io::Result<usize>> {
let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len));
if str::from_utf8(&bytes).is_err() {
Poll::Ready(ret.and_then(|_| {
Err(io::Error::new(
let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len))?;
match String::from_utf8(mem::replace(bytes, Vec::new())) {
Ok(string) => {
debug_assert!(buf.is_empty());
*buf = string;
Poll::Ready(Ok(ret))
}
Err(e) => {
*bytes = e.into_bytes();
Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
"stream did not contain valid UTF-8",
))
}))
} else {
debug_assert!(buf.is_empty());
// Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`.
mem::swap(unsafe { buf.as_mut_vec() }, bytes);
Poll::Ready(ret)
)))
}
}
}

Expand All @@ -67,7 +68,14 @@ where
bytes,
start_len,
} = &mut *self;
read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len)
let ret = read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len);
if let Poll::Ready(Err(_)) = ret {
// Put back the original string.
bytes.truncate(*start_len);
**buf = String::from_utf8(mem::replace(bytes, Vec::new()))
.expect("original string no longer utf-8");
}
ret
}
}

Expand Down
49 changes: 49 additions & 0 deletions tokio/tests/read_to_string.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use std::io;
use tokio::io::AsyncReadExt;
use tokio_test::io::Builder;

#[tokio::test]
async fn to_string_does_not_truncate_on_utf8_error() {
let data = vec![0xff, 0xff, 0xff];

let mut s = "abc".to_string();

match AsyncReadExt::read_to_string(&mut data.as_slice(), &mut s).await {
Ok(len) => panic!("Should fail: {} bytes.", len),
Err(err) if err.to_string() == "stream did not contain valid UTF-8" => {}
Err(err) => panic!("Fail: {}.", err),
}

assert_eq!(s, "abc");
}

#[tokio::test]
async fn to_string_does_not_truncate_on_io_error() {
let mut mock = Builder::new()
.read(b"def")
.read_error(io::Error::new(io::ErrorKind::Other, "whoops"))
.build();
let mut s = "abc".to_string();

match AsyncReadExt::read_to_string(&mut mock, &mut s).await {
Ok(len) => panic!("Should fail: {} bytes.", len),
Err(err) if err.to_string() == "whoops" => {}
Err(err) => panic!("Fail: {}.", err),
}

assert_eq!(s, "abc");
}

#[tokio::test]
async fn to_string_appends() {
let data = b"def".to_vec();

let mut s = "abc".to_string();

let len = AsyncReadExt::read_to_string(&mut data.as_slice(), &mut s)
.await
.unwrap();

assert_eq!(len, 3);
assert_eq!(s, "abcdef");
}

0 comments on commit 2da15b5

Please sign in to comment.