Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Ensure Binary -> Binview cast doesn't overflow the buffer size #15408

Merged
merged 2 commits into from
Mar 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 99 additions & 7 deletions crates/polars-arrow/src/compute/cast/utf8_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use polars_utils::slice::GetSaferUnchecked;
use polars_utils::vec::PushUnchecked;

use crate::array::*;
use crate::buffer::Buffer;
use crate::datatypes::ArrowDataType;
use crate::offset::Offset;
use crate::types::NativeType;
Expand Down Expand Up @@ -69,14 +70,51 @@ pub fn utf8_to_binary<O: Offset>(
}
}

// Integer type that buffer offsets are checked against. Two aliases exist so that the
// overflow path can be exercised in tests.
// Non-test builds use `u32`.
#[cfg(not(test))]
type OffsetType = u32;

// Test builds use `i8`, so that small inputs already trigger the overflow path.
#[cfg(test)]
type OffsetType = i8;

/// Returns a clone of `buf` sliced to at most twice `OffsetType::MAX` bytes.
///
/// The values are split over multiple buffers so that offsets never overflow the offset
/// type. If the buffers were not truncated like this, the GC of binview would trigger.
fn truncate_buffer(buf: &Buffer<u8>) -> Buffer<u8> {
    // * 2, as it must be able to hold a maximal offset + a maximal len.
    // `saturating_mul` keeps the bound well-defined on targets where `usize` is only
    // 32 bits wide; there `u32::MAX as usize * 2` would overflow.
    let max_len = (OffsetType::MAX as usize).saturating_mul(2);
    buf.clone().sliced(0, buf.len().min(max_len))
}

pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray {
let buffer_idx = 0_u32;
let base_ptr = arr.values().as_ptr() as usize;
// Ensure we didn't accidentally set wrong type
#[cfg(not(debug_assertions))]
{
assert_eq!(
std::mem::size_of::<u32>(),
std::mem::size_of::<OffsetType>()
);
}

let mut views = Vec::with_capacity(arr.len());
let mut uses_buffer = false;

let mut base_buffer = arr.values().clone();
// Base address of the current buffer; offsets are measured relative to it
let mut base_ptr = base_buffer.as_ptr() as usize;

// Index of the current buffer within the binview buffers
let mut buffer_idx = 0_u32;

// Binview buffers
// Note that a buffer may look far further than u32::MAX, but that is fine as we don't clone data
let mut buffers = vec![truncate_buffer(&base_buffer)];

for bytes in arr.values_iter() {
let len: u32 = bytes.len().try_into().unwrap();
let len: u32 = bytes
.len()
.try_into()
.expect("max string/binary length exceeded");

let mut payload = [0; 16];
payload[0..4].copy_from_slice(&len.to_le_bytes());
Expand All @@ -85,18 +123,42 @@ pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray {
payload[4..4 + bytes.len()].copy_from_slice(bytes);
} else {
uses_buffer = true;

// Copy the parts we know are correct.
unsafe { payload[4..8].copy_from_slice(bytes.get_unchecked_release(0..4)) };
let offset = (bytes.as_ptr() as usize - base_ptr) as u32;
payload[0..4].copy_from_slice(&len.to_le_bytes());
payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());
payload[12..16].copy_from_slice(&offset.to_le_bytes());

let current_bytes_ptr = bytes.as_ptr() as usize;
let offset = current_bytes_ptr - base_ptr;

// Here we check the overflow of the buffer offset.
if let Ok(offset) = OffsetType::try_from(offset) {
#[allow(clippy::unnecessary_cast)]
let offset = offset as u32;
payload[12..16].copy_from_slice(&offset.to_le_bytes());
payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());
} else {
let len = base_buffer.len() - offset;

// Set new buffer
base_buffer = base_buffer.clone().sliced(offset, len);
base_ptr = base_buffer.as_ptr() as usize;

// And add the (truncated) one to the buffers
buffers.push(truncate_buffer(&base_buffer));
buffer_idx = buffer_idx.checked_add(1).expect("max buffers exceeded");

let offset = 0u32;
payload[12..16].copy_from_slice(&offset.to_le_bytes());
payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());
}
}

let value = View::from_le_bytes(payload);
unsafe { views.push_unchecked(value) };
}
let buffers = if uses_buffer {
Arc::from([arr.values().clone()])
Arc::from(buffers)
} else {
Arc::from([])
};
Expand All @@ -114,3 +176,33 @@ pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray {
/// Casts a [`Utf8Array`] to a [`Utf8ViewArray`].
///
/// The array is first viewed as binary, converted with [`binary_to_binview`], and the
/// result is then reinterpreted as a utf8 view without re-validating the bytes.
pub fn utf8_to_utf8view<O: Offset>(arr: &Utf8Array<O>) -> Utf8ViewArray {
    // SAFETY: the bytes originate from a `Utf8Array`, which is presumed to hold only
    // valid UTF-8, so skipping validation on the way back is expected to be sound.
    unsafe {
        let as_binary = arr.to_binary();
        let as_binview = binary_to_binview(&as_binary);
        as_binview.to_utf8view_unchecked()
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Long value that is repeated throughout the test input.
    const LONG: &str =
        "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf";

    /// Casting must spill into several buffers (the test offset type is tiny) and still
    /// round-trip every value unchanged.
    #[test]
    fn overflowing_utf8_to_binview() {
        let values = [
            LONG, "123", LONG, LONG, LONG, "234", LONG, LONG, LONG, LONG, "324",
        ];
        let array = Utf8Array::<i64>::from_slice(values);
        let out = utf8_to_utf8view(&array);

        // Ensure we hit the multiple buffers part.
        assert_eq!(out.buffers().len(), 6);

        // Ensure we created a valid binview
        let roundtripped: Vec<_> = out.values_iter().collect();
        assert_eq!(roundtripped, values);
    }
}
Loading