Skip to content

Commit

Permalink
perf: Batch utf8-validation in csv 18% / 25% on 1.9.0 (#19124)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 7, 2024
1 parent addaf83 commit 1e28cc7
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 28 deletions.
60 changes: 33 additions & 27 deletions crates/polars-io/src/csv/read/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ where

pub struct Utf8Field {
name: PlSmallStr,
mutable: MutableBinaryViewArray<str>,
mutable: MutableBinaryViewArray<[u8]>,
scratch: Vec<u8>,
quote_char: u8,
encoding: CsvEncoding,
Expand All @@ -172,7 +172,7 @@ impl Utf8Field {
}

#[inline]
fn validate_utf8(bytes: &[u8]) -> bool {
pub(super) fn validate_utf8(bytes: &[u8]) -> bool {
simdutf8::basic::from_utf8(bytes).is_ok()
}

Expand All @@ -190,7 +190,7 @@ impl ParsedBuffer for Utf8Field {
if missing_is_null {
self.mutable.push_null()
} else {
self.mutable.push(Some(""))
self.mutable.push(Some([]))
}
return Ok(());
}
Expand All @@ -199,7 +199,7 @@ impl ParsedBuffer for Utf8Field {
let escaped_bytes = if needs_escaping {
self.scratch.clear();
self.scratch.reserve(bytes.len());
polars_ensure!(bytes.len() > 1, ComputeError: "invalid csv file\n\nField `{}` is not properly escaped.", std::str::from_utf8(bytes).map_err(to_compute_err)?);
polars_ensure!(bytes.len() > 1 && bytes.last() == Some(&self.quote_char), ComputeError: "invalid csv file\n\nField `{}` is not properly escaped.", std::str::from_utf8(bytes).map_err(to_compute_err)?);

// SAFETY:
// we just allocated enough capacity and data_len is correct.
Expand All @@ -208,36 +208,41 @@ impl ParsedBuffer for Utf8Field {
escape_field(bytes, self.quote_char, self.scratch.spare_capacity_mut());
self.scratch.set_len(n_written);
}

self.scratch.as_slice()
} else {
bytes
};

// It is important that this happens after escaping, as invalid escaped string can produce
// invalid utf8.
let parse_result = validate_utf8(escaped_bytes);
if matches!(self.encoding, CsvEncoding::LossyUtf8) | ignore_errors {
// It is important that this happens after escaping, as invalid escaped string can produce
// invalid utf8.
let parse_result = validate_utf8(escaped_bytes);

match parse_result {
true => {
let value = unsafe { std::str::from_utf8_unchecked(escaped_bytes) };
self.mutable.push_value(value)
},
false => {
if matches!(self.encoding, CsvEncoding::LossyUtf8) {
// TODO! do this without allocating
let s = String::from_utf8_lossy(escaped_bytes);
self.mutable.push_value(s.as_ref())
} else if ignore_errors {
self.mutable.push_null()
} else {
// If field before escaping is valid utf8, the escaping is incorrect.
if needs_escaping && validate_utf8(bytes) {
polars_bail!(ComputeError: "string field is not properly escaped");
match parse_result {
true => {
let value = escaped_bytes;
self.mutable.push_value(value)
},
false => {
if matches!(self.encoding, CsvEncoding::LossyUtf8) {
// TODO! do this without allocating
let s = String::from_utf8_lossy(escaped_bytes);
self.mutable.push_value(s.as_ref().as_bytes())
} else if ignore_errors {
self.mutable.push_null()
} else {
polars_bail!(ComputeError: "invalid utf-8 sequence");
// If field before escaping is valid utf8, the escaping is incorrect.
if needs_escaping && validate_utf8(bytes) {
polars_bail!(ComputeError: "string field is not properly escaped");
} else {
polars_bail!(ComputeError: "invalid utf-8 sequence");
}
}
}
},
},
}
} else {
self.mutable.push_value(escaped_bytes)
}

Ok(())
Expand Down Expand Up @@ -631,7 +636,8 @@ impl Buffer {

Buffer::Utf8(v) => {
let arr = v.mutable.freeze();
StringChunked::with_chunk(v.name.clone(), arr).into_series()
StringChunked::with_chunk(v.name.clone(), unsafe { arr.to_utf8view_unchecked() })
.into_series()
},
#[allow(unused_variables)]
Buffer::Categorical(buf) => {
Expand Down
11 changes: 11 additions & 0 deletions crates/polars-io/src/csv/read/read_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,12 +479,23 @@ impl<'a> CoreReader<'a> {
total_offset += position + 1;
(b, count)
};
let check_utf8 = matches!(self.encoding, CsvEncoding::Utf8)
&& self.schema.iter_fields().any(|f| f.dtype().is_string());

if !b.is_empty() {
let results = results.clone();
let projection = projection.as_ref();
let slf = &(*self);
s.spawn(move |_| {
if check_utf8 && !super::buffer::validate_utf8(b) {
let mut results = results.lock().unwrap();
results.push((
b.as_ptr() as usize,
Err(polars_err!(ComputeError: "invalid utf-8 sequence")),
));
return;
}

let result = slf
.read_chunk(b, projection, 0, count, starting_point_offset, b.len())
.and_then(|mut df| {
Expand Down
7 changes: 6 additions & 1 deletion py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,10 +2064,15 @@ def test_read_csv_single_column(columns: list[str] | str) -> None:


def test_csv_invalid_escape_utf8_14960() -> None:
with pytest.raises(ComputeError, match=r"field is not properly escaped"):
with pytest.raises(ComputeError, match=r"Field .* is not properly escaped"):
pl.read_csv('col1\n""•'.encode())


# Checks that reading a CSV whose data row opens a quoted field but never
# closes it (`"a,b` has no terminating quote) raises ComputeError.
def test_csv_invalid_escape() -> None:
with pytest.raises(ComputeError):
pl.read_csv(b'col1,col2\n"a,b')


@pytest.mark.slow
@pytest.mark.write_disk
def test_read_csv_only_loads_selected_columns(
Expand Down

0 comments on commit 1e28cc7

Please sign in to comment.