Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Assert chunks are equal after physical cast to prevent OOB #14873

Merged
merged 3 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,27 @@ impl DataType {
matches!(self, DataType::Boolean)
}

/// Returns `true` when this [`DataType`] is the `List` variant,
/// regardless of the inner type it wraps.
pub fn is_list(&self) -> bool {
    matches!(*self, DataType::List(..))
}

/// Check if this [`DataType`] is nested, i.e. it is either a list or a struct
/// (as reported by [`DataType::is_list`] and [`DataType::is_struct`]).
pub fn is_nested(&self) -> bool {
    let list_like = self.is_list();
    list_like || self.is_struct()
}

/// Check if this [`DataType`] is a struct
///
/// The answer depends on how the crate was compiled: with the
/// `dtype-struct` feature enabled the `Struct` variant is matched;
/// without it this method always returns `false`.
pub fn is_struct(&self) -> bool {
// Feature enabled: test for the `Struct` variant.
#[cfg(feature = "dtype-struct")]
{
matches!(self, DataType::Struct(_))
}
// Feature disabled: the variant is not referenced at all, so nothing can be a struct.
#[cfg(not(feature = "dtype-struct"))]
{
false
}
}

/// Check if this [`DataType`] is binary
pub fn is_binary(&self) -> bool {
    matches!(*self, DataType::Binary)
}
Expand Down
23 changes: 21 additions & 2 deletions crates/polars-ops/src/chunked_array/gather/chunked.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::borrow::Cow;
use std::fmt::Debug;

use polars_core::prelude::gather::_update_gather_sorted_flag;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
Expand Down Expand Up @@ -67,9 +70,24 @@ pub trait TakeChunked {
unsafe fn take_opt_chunked_unchecked(&self, by: &[ChunkId]) -> Self;
}

/// Produce the series a chunked gather should operate on.
///
/// Nested dtypes (see `DataType::is_nested`) are borrowed unchanged; every
/// other dtype is converted with `to_physical_repr`.
///
/// # Panics
///
/// Panics with "implementation error" when the returned series does not have
/// the same number of chunks as `s`. If that happened, the cast rechunked the
/// data and gathering by chunk id would read out of bounds.
fn prepare_series(s: &Series) -> Cow<Series> {
    let physical = match s.dtype().is_nested() {
        true => Cow::Borrowed(s),
        false => s.to_physical_repr(),
    };

    // The chunk ids used by the gather refer to the chunk layout of `s`, so
    // the layout has to survive the conversion above.
    let n_chunks_physical = physical.chunks().len();
    let n_chunks_original = s.chunks().len();
    assert_eq!(n_chunks_physical, n_chunks_original, "implementation error");

    physical
}

impl TakeChunked for Series {
unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self {
let phys = self.to_physical_repr();
let phys = prepare_series(self);
use DataType::*;
let out = match phys.dtype() {
dt if dt.is_numeric() => {
Expand Down Expand Up @@ -122,7 +140,7 @@ impl TakeChunked for Series {

/// Take function that checks the null state of each `ChunkIdx`.
unsafe fn take_opt_chunked_unchecked(&self, by: &[NullableChunkId]) -> Self {
let phys = self.to_physical_repr();
let phys = prepare_series(self);
use DataType::*;
let out = match phys.dtype() {
dt if dt.is_numeric() => {
Expand Down Expand Up @@ -177,6 +195,7 @@ impl TakeChunked for Series {
impl<T> TakeChunked for ChunkedArray<T>
where
T: PolarsDataType,
T::Array: Debug,
{
unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self {
let arrow_dtype = self.dtype().to_arrow(true);
Expand Down
11 changes: 7 additions & 4 deletions crates/polars-ops/src/frame/join/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ pub trait JoinDispatch: IntoDf {
args: JoinArgs,
verbose: bool,
) -> PolarsResult<DataFrame> {
let ca_self = self.to_df();
let df_self = self.to_df();
#[cfg(feature = "dtype-categorical")]
_check_categorical_src(s_left.dtype(), s_right.dtype())?;

let mut left = ca_self.clone();
let mut left = df_self.clone();
let mut s_left = s_left.clone();
// Eagerly limit left if possible.
if let Some((offset, len)) = args.slice {
Expand All @@ -188,16 +188,19 @@ pub trait JoinDispatch: IntoDf {
}

// Ensure that the chunks are aligned; otherwise we go out of bounds (OOB).
let mut right = other.clone();
let mut right = Cow::Borrowed(other);
let mut s_right = s_right.clone();
if left.should_rechunk() {
left.as_single_chunk_par();
s_left = s_left.rechunk();
}
if right.should_rechunk() {
right.as_single_chunk_par();
let mut other = other.clone();
other.as_single_chunk_par();
right = Cow::Owned(other);
s_right = s_right.rechunk();
}

let ids = sort_or_hash_left(&s_left, &s_right, verbose, args.validation, args.join_nulls)?;
left._finish_left_join(ids, &right.drop(s_right.name()).unwrap(), args)
}
Expand Down
2 changes: 2 additions & 0 deletions crates/polars/tests/it/chunks/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#[cfg(feature = "parquet")]
mod parquet;
38 changes: 38 additions & 0 deletions crates/polars/tests/it/chunks/parquet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use std::io::{Seek, SeekFrom};

use polars::prelude::*;

// Regression test (the name refers to issue 14872): left-joins a small frame
// against a frame that has been round-tripped through an in-memory parquet
// buffer written with a row group size of 1.
#[test]
fn test_cast_join_14872() {
// Left side: a single row with key 1.
let df1 = df![
"ints" => [1]
]
.unwrap();

// Right side: two rows (keys 0 and 1); the "strings" column is built from
// one `Series` per row — presumably this yields a list column; TODO confirm.
let mut df2 = df![
"ints" => [0, 1],
"strings" => vec![Series::new("", ["a"]); 2],
]
.unwrap();

// Write the right side to an in-memory buffer with one row per row group.
// NOTE(review): presumably the reader then returns one chunk per row group,
// giving a multi-chunk frame — confirm against the parquet reader.
let mut buf = std::io::Cursor::new(vec![]);
ParquetWriter::new(&mut buf)
.with_row_group_size(Some(1))
.finish(&mut df2)
.unwrap();

// Rewind the cursor so the reader starts at the beginning of the buffer.
let _ = buf.seek(SeekFrom::Start(0));
let df2 = ParquetReader::new(buf).finish().unwrap();

// The join under test: left join on "ints".
let out = df1
.join(&df2, ["ints"], ["ints"], JoinArgs::new(JoinType::Left))
.unwrap();

// Only key 1 exists on the left, so exactly one row with its "strings"
// value is expected.
let expected = df![
"ints" => [1],
"strings" => vec![Series::new("", ["a"]); 1],
]
.unwrap();

assert!(expected.equals(&out));
}
1 change: 1 addition & 0 deletions crates/polars/tests/it/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ mod schema;
mod time;

mod arrow;
mod chunks;

pub static FOODS_CSV: &str = "../../examples/datasets/foods1.csv";
Loading