Skip to content

Commit

Permalink
fix: Infer reshape dims when determining schema (#18923)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Sep 26, 2024
1 parent 68b6f0e commit aec911f
Show file tree
Hide file tree
Showing 18 changed files with 285 additions and 110 deletions.
2 changes: 2 additions & 0 deletions crates/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ mod any_value;
mod dtype;
mod field;
mod into_scalar;
mod reshape;
#[cfg(feature = "object")]
mod static_array_collect;
mod time_unit;
Expand Down Expand Up @@ -41,6 +42,7 @@ use polars_utils::abs_diff::AbsDiff;
use polars_utils::float::IsFloat;
use polars_utils::min_max::MinMax;
use polars_utils::nulls::IsNull;
pub use reshape::*;
#[cfg(feature = "serde")]
use serde::de::{EnumAccess, Error, Unexpected, VariantAccess, Visitor};
#[cfg(any(feature = "serde", feature = "serde-lazy"))]
Expand Down
118 changes: 118 additions & 0 deletions crates/polars-core/src/datatypes/reshape.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use std::fmt;
use std::hash::Hash;
use std::num::NonZeroU64;

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(transparent)]
pub struct Dimension(NonZeroU64);

/// A dimension in a reshape.
///
/// Any dimension smaller than 0 is seen as an `infer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ReshapeDimension {
Infer,
Specified(Dimension),
}

impl fmt::Debug for Dimension {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.get().fmt(f)
}
}

impl fmt::Display for ReshapeDimension {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Infer => f.write_str("inferred"),
Self::Specified(v) => v.get().fmt(f),
}
}
}

impl Hash for ReshapeDimension {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.to_repr().hash(state)
}
}

impl Dimension {
#[inline]
pub const fn new(v: u64) -> Self {
assert!(v <= i64::MAX as u64);

// SAFETY: Bounds check done before
let dim = unsafe { NonZeroU64::new_unchecked(v.wrapping_add(1)) };
Self(dim)
}

#[inline]
pub const fn get(self) -> u64 {
self.0.get() - 1
}
}

impl ReshapeDimension {
#[inline]
pub const fn new(v: i64) -> Self {
if v < 0 {
Self::Infer
} else {
// SAFETY: We have bounds checked for -1
let dim = unsafe { NonZeroU64::new_unchecked((v as u64).wrapping_add(1)) };
Self::Specified(Dimension(dim))
}
}

#[inline]
fn to_repr(self) -> u64 {
match self {
Self::Infer => 0,
Self::Specified(dim) => dim.0.get(),
}
}

#[inline]
pub const fn get(self) -> Option<u64> {
match self {
ReshapeDimension::Infer => None,
ReshapeDimension::Specified(dim) => Some(dim.get()),
}
}

#[inline]
pub const fn get_or_infer(self, inferred: u64) -> u64 {
match self {
ReshapeDimension::Infer => inferred,
ReshapeDimension::Specified(dim) => dim.get(),
}
}

#[inline]
pub fn get_or_infer_with(self, f: impl Fn() -> u64) -> u64 {
match self {
ReshapeDimension::Infer => f(),
ReshapeDimension::Specified(dim) => dim.get(),
}
}

pub const fn new_dimension(dimension: u64) -> ReshapeDimension {
Self::Specified(Dimension::new(dimension))
}
}

impl TryFrom<i64> for Dimension {
type Error = ();

#[inline]
fn try_from(value: i64) -> Result<Self, Self::Error> {
let ReshapeDimension::Specified(v) = ReshapeDimension::new(value) else {
return Err(());
};

Ok(v)
}
}
5 changes: 3 additions & 2 deletions crates/polars-core/src/frame/column/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use polars_utils::pl_str::PlSmallStr;
use self::gather::check_bounds_ca;
use crate::chunked_array::cast::CastOptions;
use crate::chunked_array::metadata::{MetadataFlags, MetadataTrait};
use crate::datatypes::ReshapeDimension;
use crate::prelude::*;
use crate::series::{BitRepr, IsSorted, SeriesPhysIter};
use crate::utils::{slice_offsets, Container};
Expand Down Expand Up @@ -730,15 +731,15 @@ impl Column {
self.as_materialized_series().unique().map(Column::from)
}

pub fn reshape_list(&self, dimensions: &[i64]) -> PolarsResult<Self> {
pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Self> {
// @scalar-opt
self.as_materialized_series()
.reshape_list(dimensions)
.map(Self::from)
}

#[cfg(feature = "dtype-array")]
pub fn reshape_array(&self, dimensions: &[i64]) -> PolarsResult<Self> {
pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Self> {
// @scalar-opt
self.as_materialized_series()
.reshape_array(dimensions)
Expand Down
10 changes: 6 additions & 4 deletions crates/polars-core/src/series/arithmetic/borrowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,18 @@ impl NumOpsDispatchInner for BooleanType {
}

#[cfg(feature = "dtype-array")]
fn array_shape(dt: &DataType, infer: bool) -> Vec<i64> {
fn inner(dt: &DataType, buf: &mut Vec<i64>) {
fn array_shape(dt: &DataType, infer: bool) -> Vec<ReshapeDimension> {
fn inner(dt: &DataType, buf: &mut Vec<ReshapeDimension>) {
if let DataType::Array(_, size) = dt {
buf.push(*size as i64)
buf.push(ReshapeDimension::Specified(
Dimension::try_from(*size as i64).unwrap(),
))
}
}

let mut buf = vec![];
if infer {
buf.push(-1)
buf.push(ReshapeDimension::Infer)
}
inner(dt, &mut buf);
buf
Expand Down
118 changes: 61 additions & 57 deletions crates/polars-core/src/series/ops/reshape.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
use std::borrow::Cow;
#[cfg(feature = "dtype-array")]
use std::cmp::Ordering;
#[cfg(feature = "dtype-array")]
use std::collections::VecDeque;

use arrow::array::*;
use arrow::legacy::kernels::list::array_to_unit_list;
use arrow::offset::Offsets;
use polars_error::{polars_bail, polars_ensure, PolarsResult};
#[cfg(feature = "dtype-array")]
use polars_utils::format_tuple;

use crate::chunked_array::builder::get_list_builder;
Expand Down Expand Up @@ -90,70 +85,70 @@ impl Series {
}

#[cfg(feature = "dtype-array")]
pub fn reshape_array(&self, dimensions: &[i64]) -> PolarsResult<Series> {
pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {
polars_ensure!(
!dimensions.is_empty(),
InvalidOperation: "at least one dimension must be specified"
);

let mut dims = dimensions.iter().copied().collect::<VecDeque<_>>();

let leaf_array = self.get_leaf_array();
let size = leaf_array.len();

let mut total_dim_size = 1;
let mut infer_dim_index: Option<usize> = None;
for (index, &dim) in dims.iter().enumerate() {
match dim.cmp(&0) {
Ordering::Greater => total_dim_size *= dim as usize,
Ordering::Equal => {
let mut num_infers = 0;
for (index, &dim) in dimensions.iter().enumerate() {
match dim {
ReshapeDimension::Infer => {
polars_ensure!(
index == 0,
InvalidOperation: "cannot reshape array into shape containing a zero dimension after the first: {}",
format_tuple!(dims)
num_infers == 0,
InvalidOperation: "can only specify one inferred dimension"
);
total_dim_size = 0;
// We can early exit here, as empty arrays will error with multiple dimensions,
// and non-empty arrays will error when the first dimension is zero.
break;
num_infers += 1;
},
Ordering::Less => {
polars_ensure!(
infer_dim_index.is_none(),
InvalidOperation: "can only specify one unknown dimension"
);
infer_dim_index = Some(index);
ReshapeDimension::Specified(dim) => {
let dim = dim.get();

if dim > 0 {
total_dim_size *= dim as usize
} else {
polars_ensure!(
index == 0,
InvalidOperation: "cannot reshape array into shape containing a zero dimension after the first: {}",
format_tuple!(dimensions)
);
total_dim_size = 0;
// We can early exit here, as empty arrays will error with multiple dimensions,
// and non-empty arrays will error when the first dimension is zero.
break;
}
},
}
}

if size == 0 {
if dims.len() > 1 || (infer_dim_index.is_none() && total_dim_size != 0) {
polars_bail!(InvalidOperation: "cannot reshape empty array into shape {}", format_tuple!(dims))
if dimensions.len() > 1 || (num_infers == 0 && total_dim_size != 0) {
polars_bail!(InvalidOperation: "cannot reshape empty array into shape {}", format_tuple!(dimensions))
}
} else if total_dim_size == 0 {
polars_bail!(InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", format_tuple!(dims))
polars_bail!(InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", format_tuple!(dimensions))
} else {
polars_ensure!(
size % total_dim_size == 0,
InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dims)
InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions)
);
}

// Infer dimension
if let Some(index) = infer_dim_index {
let inferred_dim = size / total_dim_size;
let item = dims.get_mut(index).unwrap();
*item = i64::try_from(inferred_dim).unwrap();
}

let leaf_array = leaf_array.rechunk();
let mut prev_dtype = leaf_array.dtype().clone();
let mut prev_array = leaf_array.chunks()[0].clone();

// We pop the outer dimension as that is the height of the series.
let _ = dims.pop_front();
while let Some(dim) = dims.pop_back() {
for idx in (1..dimensions.len()).rev() {
// Infer dimension if needed
let dim = dimensions[idx].get_or_infer_with(|| {
debug_assert!(num_infers > 0);
(size / total_dim_size) as u64
});
prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);

prev_array = FixedSizeListArray::new(
Expand All @@ -172,7 +167,7 @@ impl Series {
})
}

pub fn reshape_list(&self, dimensions: &[i64]) -> PolarsResult<Series> {
pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {
polars_ensure!(
!dimensions.is_empty(),
InvalidOperation: "at least one dimension must be specified"
Expand All @@ -187,38 +182,43 @@ impl Series {

let s_ref = s.as_ref();

let dimensions = dimensions.to_vec();
// let dimensions = dimensions.to_vec();

match dimensions.len() {
1 => {
polars_ensure!(
dimensions[0] as usize == s_ref.len() || dimensions[0] == -1_i64,
dimensions[0].get().map_or(true, |dim| dim as usize == s_ref.len()),
InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,
);
Ok(s_ref.clone())
},
2 => {
let mut rows = dimensions[0];
let mut cols = dimensions[1];
let rows = dimensions[0];
let cols = dimensions[1];

if s_ref.len() == 0_usize {
if (rows == -1 || rows == 0) && (cols == -1 || cols == 0 || cols == 1) {
if rows.get_or_infer(0) == 0 && cols.get_or_infer(0) <= 1 {
let s = reshape_fast_path(s.name().clone(), s_ref);
return Ok(s);
} else {
polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {:?}", dimensions,)
polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {}", format_tuple!(dimensions))
}
}

use ReshapeDimension as RD;
// Infer dimension.
if rows == -1 && cols >= 1 {
rows = s_ref.len() as i64 / cols
} else if cols == -1 && rows >= 1 {
cols = s_ref.len() as i64 / rows
} else if rows == -1 && cols == -1 {
rows = s_ref.len() as i64;
cols = 1_i64;
}

let (rows, cols) = match (rows, cols) {
(RD::Infer, RD::Specified(cols)) if cols.get() >= 1 => {
(s_ref.len() as u64 / cols.get(), cols.get())
},
(RD::Specified(rows), RD::Infer) if rows.get() >= 1 => {
(rows.get(), s_ref.len() as u64 / rows.get())
},
(RD::Infer, RD::Infer) => (s_ref.len() as u64, 1u64),
(RD::Specified(rows), RD::Specified(cols)) => (rows.get(), cols.get()),
_ => polars_bail!(InvalidOperation: "reshape of non-zero list into zero list"),
};

// Fast path, we can create a unit list so we only allocate offsets.
if rows as usize == s_ref.len() && cols == 1 {
Expand All @@ -234,9 +234,9 @@ impl Series {
let mut builder =
get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone())?;

let mut offset = 0i64;
let mut offset = 0u64;
for _ in 0..rows {
let row = s_ref.slice(offset, cols as usize);
let row = s_ref.slice(offset as i64, cols as usize);
builder.append_series(&row).unwrap();
offset += cols;
}
Expand Down Expand Up @@ -279,7 +279,11 @@ mod test {
(&[-1, 2], 2),
(&[2, -1], 2),
] {
let out = s.reshape_list(dims)?;
let dims = dims
.iter()
.map(|&v| ReshapeDimension::new(v))
.collect::<Vec<_>>();
let out = s.reshape_list(&dims)?;
assert_eq!(out.len(), list_len);
assert!(matches!(out.dtype(), DataType::List(_)));
assert_eq!(out.explode()?.len(), 4);
Expand Down
Loading

0 comments on commit aec911f

Please sign in to comment.