From b6acf6780e27f4d88ea35d31e388183f19f62c0e Mon Sep 17 00:00:00 2001 From: ritchie Date: Wed, 6 Mar 2024 10:15:50 +0100 Subject: [PATCH 1/3] fix: Assert chunks are equal after physical cast to prevent OOB --- crates/polars-core/src/datatypes/dtype.rs | 21 +++++++++++ .../src/chunked_array/gather/chunked.rs | 23 +++++++++++- .../src/frame/join/hash_join/mod.rs | 11 ++++-- crates/polars/tests/it/chunks/join.rs | 37 +++++++++++++++++++ crates/polars/tests/it/chunks/mod.rs | 1 + crates/polars/tests/it/main.rs | 1 + 6 files changed, 88 insertions(+), 6 deletions(-) create mode 100644 crates/polars/tests/it/chunks/join.rs create mode 100644 crates/polars/tests/it/chunks/mod.rs diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index 9fb5ea7f9ac9..1187402231e2 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -209,6 +209,27 @@ impl DataType { matches!(self, DataType::Boolean) } + /// Check if this [`DataType`] is a list + pub fn is_list(&self) -> bool { + matches!(self, DataType::List(_)) + } + + pub fn is_nested(&self) -> bool { + self.is_list() || self.is_struct() + } + + /// Check if this [`DataType`] is a struct + pub fn is_struct(&self) -> bool { + #[cfg(feature = "dtype-struct")] + { + matches!(self, DataType::Struct(_)) + } + #[cfg(not(feature = "dtype-struct"))] + { + false + } + } + pub fn is_binary(&self) -> bool { matches!(self, DataType::Binary) } diff --git a/crates/polars-ops/src/chunked_array/gather/chunked.rs b/crates/polars-ops/src/chunked_array/gather/chunked.rs index 3872c4603a71..044f2f5e4af9 100644 --- a/crates/polars-ops/src/chunked_array/gather/chunked.rs +++ b/crates/polars-ops/src/chunked_array/gather/chunked.rs @@ -1,3 +1,6 @@ +use std::borrow::Cow; +use std::fmt::Debug; + use polars_core::prelude::gather::_update_gather_sorted_flag; use polars_core::prelude::*; use polars_core::series::IsSorted; @@ -67,9 +70,24 @@ pub trait TakeChunked { 
unsafe fn take_opt_chunked_unchecked(&self, by: &[ChunkId]) -> Self; } +fn prepare_series(s: &Series) -> Cow<Series> { + let phys = if s.dtype().is_nested() { + Cow::Borrowed(s) + } else { + s.to_physical_repr() + }; + // If this is hit the cast rechunked the data and the gather will OOB + assert_eq!( + phys.chunks().len(), + s.chunks().len(), + "implementation error" + ); + phys +} + impl TakeChunked for Series { unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { - let phys = self.to_physical_repr(); + let phys = prepare_series(self); use DataType::*; let out = match phys.dtype() { dt if dt.is_numeric() => { @@ -122,7 +140,7 @@ impl TakeChunked for Series { /// Take function that checks of null state in `ChunkIdx`. unsafe fn take_opt_chunked_unchecked(&self, by: &[NullableChunkId]) -> Self { - let phys = self.to_physical_repr(); + let phys = prepare_series(self); use DataType::*; let out = match phys.dtype() { dt if dt.is_numeric() => { @@ -177,6 +195,7 @@ impl TakeChunked for Series { impl<T> TakeChunked for ChunkedArray<T> where T: PolarsDataType, + T::Array: Debug, { unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { let arrow_dtype = self.dtype().to_arrow(true); diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs index dfe4ae71c00b..f07667130cc5 100644 --- a/crates/polars-ops/src/frame/join/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -173,11 +173,11 @@ pub trait JoinDispatch: IntoDf { args: JoinArgs, verbose: bool, ) -> PolarsResult<DataFrame> { - let ca_self = self.to_df(); + let df_self = self.to_df(); #[cfg(feature = "dtype-categorical")] _check_categorical_src(s_left.dtype(), s_right.dtype())?; - let mut left = ca_self.clone(); + let mut left = df_self.clone(); let mut s_left = s_left.clone(); // Eagerly limit left if possible. 
if let Some((offset, len)) = args.slice { @@ -188,16 +188,19 @@ pub trait JoinDispatch: IntoDf { } // Ensure that the chunks are aligned otherwise we go OOB. - let mut right = other.clone(); + let mut right = Cow::Borrowed(other); let mut s_right = s_right.clone(); if left.should_rechunk() { left.as_single_chunk_par(); s_left = s_left.rechunk(); } if right.should_rechunk() { - right.as_single_chunk_par(); + let mut other = other.clone(); + other.as_single_chunk_par(); + right = Cow::Owned(other); s_right = s_right.rechunk(); } + let ids = sort_or_hash_left(&s_left, &s_right, verbose, args.validation, args.join_nulls)?; left._finish_left_join(ids, &right.drop(s_right.name()).unwrap(), args) } diff --git a/crates/polars/tests/it/chunks/join.rs b/crates/polars/tests/it/chunks/join.rs new file mode 100644 index 000000000000..bd5c7cb91baa --- /dev/null +++ b/crates/polars/tests/it/chunks/join.rs @@ -0,0 +1,37 @@ +use std::io::{Seek, SeekFrom}; + +use polars::prelude::*; + +fn test_cast_join_14872() { + let df1 = df![ + "ints" => [1] + ] + .unwrap(); + + let mut df2 = df![ + "ints" => [0, 1], + "strings" => vec![Series::new("", ["a"]); 2], + ] + .unwrap(); + + let mut buf = std::io::Cursor::new(vec![]); + ParquetWriter::new(&mut buf) + .with_row_group_size(Some(1)) + .finish(&mut df2) + .unwrap(); + + let _ = buf.seek(SeekFrom::Start(0)); + let df2 = ParquetReader::new(buf).finish().unwrap(); + + let out = df1 + .join(&df2, ["ints"], ["ints"], JoinArgs::new(JoinType::Left)) + .unwrap(); + + let expected = df![ + "ints" => [1], + "strings" => vec![Series::new("", ["a"]); 1], + ] + .unwrap(); + + assert!(expected.equals(&out)); +} diff --git a/crates/polars/tests/it/chunks/mod.rs b/crates/polars/tests/it/chunks/mod.rs new file mode 100644 index 000000000000..670b0d2b7cde --- /dev/null +++ b/crates/polars/tests/it/chunks/mod.rs @@ -0,0 +1 @@ +mod join; diff --git a/crates/polars/tests/it/main.rs b/crates/polars/tests/it/main.rs index de6fd0d7d33e..4395ce47028f 100644 --- 
a/crates/polars/tests/it/main.rs +++ b/crates/polars/tests/it/main.rs @@ -7,5 +7,6 @@ mod schema; mod time; mod arrow; +mod chunks; pub static FOODS_CSV: &str = "../../examples/datasets/foods1.csv"; From 397fe7cf3eee5329a6e720306488be6b84089bde Mon Sep 17 00:00:00 2001 From: ritchie Date: Wed, 6 Mar 2024 10:19:26 +0100 Subject: [PATCH 2/3] tag test --- crates/polars/tests/it/chunks/join.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/polars/tests/it/chunks/join.rs b/crates/polars/tests/it/chunks/join.rs index bd5c7cb91baa..26c37566845a 100644 --- a/crates/polars/tests/it/chunks/join.rs +++ b/crates/polars/tests/it/chunks/join.rs @@ -2,6 +2,7 @@ use std::io::{Seek, SeekFrom}; use polars::prelude::*; +#[test] fn test_cast_join_14872() { let df1 = df![ "ints" => [1] From d0845e3f9acf7ff3a96c99237f534772ca0320bf Mon Sep 17 00:00:00 2001 From: ritchie Date: Wed, 6 Mar 2024 10:27:50 +0100 Subject: [PATCH 3/3] features --- crates/polars/tests/it/chunks/mod.rs | 3 ++- crates/polars/tests/it/chunks/{join.rs => parquet.rs} | 0 2 files changed, 2 insertions(+), 1 deletion(-) rename crates/polars/tests/it/chunks/{join.rs => parquet.rs} (100%) diff --git a/crates/polars/tests/it/chunks/mod.rs b/crates/polars/tests/it/chunks/mod.rs index 670b0d2b7cde..ab7fe5c8ec35 100644 --- a/crates/polars/tests/it/chunks/mod.rs +++ b/crates/polars/tests/it/chunks/mod.rs @@ -1 +1,2 @@ -mod join; +#[cfg(feature = "parquet")] +mod parquet; diff --git a/crates/polars/tests/it/chunks/join.rs b/crates/polars/tests/it/chunks/parquet.rs similarity index 100% rename from crates/polars/tests/it/chunks/join.rs rename to crates/polars/tests/it/chunks/parquet.rs