Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(rust): Add output_schema to all PhysNodes #18272

Merged
merged 4 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 31 additions & 31 deletions crates/polars-core/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ where
}

impl Schema {
/// Create a new, empty schema
/// Create a new, empty schema.
pub fn new() -> Self {
    // A brand-new schema is just one created with room for zero fields.
    let no_fields = 0;
    Self::with_capacity(no_fields)
}

/// Create a new, empty schema with capacity
/// Create a new, empty schema with the given capacity.
///
/// If you know the number of fields you have ahead of time, using this is more efficient than using
/// [`new`][Self::new]. Also consider using [`Schema::from_iter`] if you have the collection of fields available
Expand All @@ -87,7 +87,7 @@ impl Schema {
self.inner.reserve(additional);
}

/// The number of fields in the schema
/// The number of fields in the schema.
#[inline]
pub fn len(&self) -> usize {
self.inner.len()
Expand All @@ -98,7 +98,7 @@ impl Schema {
self.inner.is_empty()
}

/// Rename field `old` to `new`, and return the (owned) old name
/// Rename field `old` to `new`, and return the (owned) old name.
///
/// If `old` is not present in the schema, the schema is not modified and `None` is returned. Otherwise the schema
/// is updated and `Some(old_name)` is returned.
Expand All @@ -114,7 +114,7 @@ impl Schema {
Some(old_name)
}

/// Create a new schema from this one, inserting a field with `name` and `dtype` at the given `index`
/// Create a new schema from this one, inserting a field with `name` and `dtype` at the given `index`.
///
/// If a field named `name` already exists, it is updated with the new dtype. Regardless, the field named `name` is
/// always moved to the given index. Valid indices range from `0` (front of the schema) to `self.len()` (after the
Expand Down Expand Up @@ -150,7 +150,7 @@ impl Schema {
Ok(new)
}

/// Insert a field with `name` and `dtype` at the given `index` into this schema
/// Insert a field with `name` and `dtype` at the given `index` into this schema.
///
/// If a field named `name` already exists, it is updated with the new dtype. Regardless, the field named `name` is
/// always moved to the given index. Valid indices range from `0` (front of the schema) to `self.len()` (after the
Expand Down Expand Up @@ -189,32 +189,32 @@ impl Schema {
Ok(old_dtype)
}

/// Get a reference to the dtype of the field named `name`, or `None` if the field doesn't exist
/// Get a reference to the dtype of the field named `name`, or `None` if the field doesn't exist.
pub fn get(&self, name: &str) -> Option<&DataType> {
    // Pure delegation to the inner collection; `None` means no field is called `name`.
    let fields = &self.inner;
    fields.get(name)
}

/// Get a reference to the dtype of the field named `name`, or `Err(PolarsErr)` if the field doesn't exist
/// Get a reference to the dtype of the field named `name`, or `Err(PolarsErr)` if the field doesn't exist.
pub fn try_get(&self, name: &str) -> PolarsResult<&DataType> {
    // Same lookup as `get`, but a missing field becomes a `SchemaFieldNotFound` error
    // carrying the requested name.
    match self.get(name) {
        Some(dtype) => Ok(dtype),
        None => Err(polars_err!(SchemaFieldNotFound: "{}", name)),
    }
}

/// Get a mutable reference to the dtype of the field named `name`, or `Err(PolarsErr)` if the field doesn't exist
/// Get a mutable reference to the dtype of the field named `name`, or `Err(PolarsErr)` if the field doesn't exist.
pub fn try_get_mut(&mut self, name: &str) -> PolarsResult<&mut DataType> {
    // A missing field is reported as a `SchemaFieldNotFound` error carrying the requested name.
    match self.inner.get_mut(name) {
        Some(dtype) => Ok(dtype),
        None => Err(polars_err!(SchemaFieldNotFound: "{}", name)),
    }
}

/// Return all data about the field named `name`: its index in the schema, its name, and its dtype
/// Return all data about the field named `name`: its index in the schema, its name, and its dtype.
///
/// Returns `Some((index, &name, &dtype))` if the field exists, `None` if it doesn't.
pub fn get_full(&self, name: &str) -> Option<(usize, &SmartString, &DataType)> {
    // The inner collection already yields the `(index, name, dtype)` triple; forward it as is.
    let fields = &self.inner;
    fields.get_full(name)
}

/// Return all data about the field named `name`: its index in the schema, its name, and its dtype
/// Return all data about the field named `name`: its index in the schema, its name, and its dtype.
///
/// Returns `Ok((index, &name, &dtype))` if the field exists, `Err(PolarsErr)` if it doesn't.
pub fn try_get_full(&self, name: &str) -> PolarsResult<(usize, &SmartString, &DataType)> {
Expand All @@ -223,7 +223,7 @@ impl Schema {
.ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
}

/// Look up the name in the schema and return an owned [`Field`] by cloning the data
/// Look up the name in the schema and return an owned [`Field`] by cloning the data.
///
/// Returns `None` if the field does not exist.
///
Expand All @@ -235,7 +235,7 @@ impl Schema {
.map(|dtype| Field::new(name, dtype.clone()))
}

/// Look up the name in the schema and return an owned [`Field`] by cloning the data
/// Look up the name in the schema and return an owned [`Field`] by cloning the data.
///
/// Returns `Err(PolarsErr)` if the field does not exist.
///
Expand All @@ -248,7 +248,7 @@ impl Schema {
.map(|dtype| Field::new(name, dtype.clone()))
}

/// Get references to the name and dtype of the field at `index`
/// Get references to the name and dtype of the field at `index`.
///
/// If `index` is in bounds, returns `Some((&name, &dtype))`, else `None`. See
/// [`get_at_index_mut`][Self::get_at_index_mut] for a mutable version.
Expand All @@ -260,15 +260,15 @@ impl Schema {
self.inner.get_index(index).ok_or_else(|| polars_err!(ComputeError: "index {index} out of bounds with 'schema' of len: {}", self.len()))
}

/// Get mutable references to the name and dtype of the field at `index`
/// Get mutable references to the name and dtype of the field at `index`.
///
/// If `index` is in bounds, returns `Some((&mut name, &mut dtype))`, else `None`. See
/// [`get_at_index`][Self::get_at_index] for an immutable version.
pub fn get_at_index_mut(&mut self, index: usize) -> Option<(&mut SmartString, &mut DataType)> {
    // `get_index_mut2` hands out mutable access to both the name and the dtype at `index`.
    let fields = &mut self.inner;
    fields.get_index_mut2(index)
}

/// Swap-remove a field by name and, if the field existed, return its dtype
/// Swap-remove a field by name and, if the field existed, return its dtype.
///
/// If the field does not exist, the schema is not modified and `None` is returned.
///
Expand All @@ -279,7 +279,7 @@ impl Schema {
self.inner.swap_remove(name)
}

/// Remove a field by name, preserving order, and, if the field existed, return its dtype
/// Remove a field by name, preserving order, and, if the field existed, return its dtype.
///
/// If the field does not exist, the schema is not modified and `None` is returned.
///
Expand All @@ -289,7 +289,7 @@ impl Schema {
self.inner.shift_remove(name)
}

/// Remove a field by name, preserving order, and, if the field existed, return its dtype
/// Remove the field at the given `index`, preserving order, and, if the field existed, return its name and dtype.
///
/// If the field does not exist, the schema is not modified and `None` is returned.
///
Expand All @@ -299,12 +299,12 @@ impl Schema {
self.inner.shift_remove_index(index)
}

/// Whether the schema contains a field named `name`
/// Whether the schema contains a field named `name`.
pub fn contains(&self, name: &str) -> bool {
    // Membership only: ask the inner map directly instead of fetching the dtype and discarding it.
    self.inner.contains_key(name)
}

/// Change the field named `name` to the given `dtype` and return the previous dtype
/// Change the field named `name` to the given `dtype` and return the previous dtype.
///
/// If `name` doesn't already exist in the schema, the schema is not modified and `None` is returned. Otherwise
/// returns `Some(old_dtype)`.
Expand All @@ -316,7 +316,7 @@ impl Schema {
Some(std::mem::replace(old_dtype, dtype))
}

/// Change the field at the given index to the given `dtype` and return the previous dtype
/// Change the field at the given index to the given `dtype` and return the previous dtype.
///
/// If the index is out of bounds, the schema is not modified and `None` is returned. Otherwise returns
/// `Some(old_dtype)`.
Expand All @@ -328,7 +328,7 @@ impl Schema {
Some(std::mem::replace(old_dtype, dtype))
}

/// Insert a new column in the [`Schema`]
/// Insert a new column in the [`Schema`].
///
/// If an equivalent name already exists in the schema: the name remains and
/// retains its place in the order, its corresponding value is updated
Expand All @@ -344,7 +344,7 @@ impl Schema {
self.inner.insert(name, dtype)
}

/// Merge `other` into `self`
/// Merge `other` into `self`.
///
/// Merging logic:
/// - Fields that occur in `self` but not `other` are unmodified
Expand All @@ -355,7 +355,7 @@ impl Schema {
self.inner.extend(other.inner)
}

/// Merge borrowed `other` into `self`
/// Merge borrowed `other` into `self`.
///
/// Merging logic:
/// - Fields that occur in `self` but not `other` are unmodified
Expand All @@ -370,7 +370,7 @@ impl Schema {
)
}

/// Convert self to `ArrowSchema` by cloning the fields
/// Convert self to `ArrowSchema` by cloning the fields.
pub fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema {
let fields: Vec<_> = self
.inner
Expand All @@ -380,7 +380,7 @@ impl Schema {
ArrowSchema::from(fields)
}

/// Iterates the [`Field`]s in this schema, constructing them anew by cloning each `(&name, &dtype)` pair
/// Iterates the [`Field`]s in this schema, constructing them anew by cloning each `(&name, &dtype)` pair.
///
/// Note that this clones each name and dtype in order to form an owned [`Field`]. For a clone-free version, use
/// [`iter`][Self::iter], which returns `(&name, &dtype)`.
Expand All @@ -390,22 +390,22 @@ impl Schema {
.map(|(name, dtype)| Field::new(name, dtype.clone()))
}

/// Iterates over references to the dtypes in this schema
/// Iterates over references to the dtypes in this schema.
pub fn iter_dtypes(&self) -> impl '_ + ExactSizeIterator<Item = &DataType> {
    // The dtypes are the values of the inner map, yielded in schema order.
    self.inner.values()
}

/// Iterates over mut references to the dtypes in this schema
/// Iterates over mut references to the dtypes in this schema.
pub fn iter_dtypes_mut(&mut self) -> impl '_ + ExactSizeIterator<Item = &mut DataType> {
    // Mutable access to the values only; the field names are not exposed here.
    self.inner.values_mut()
}

/// Iterates over references to the names in this schema
/// Iterates over references to the names in this schema.
pub fn iter_names(&self) -> impl '_ + ExactSizeIterator<Item = &SmartString> {
    // The field names are the keys of the inner map, yielded in schema order.
    self.inner.keys()
}

/// Iterates over the `(&name, &dtype)` pairs in this schema
/// Iterates over the `(&name, &dtype)` pairs in this schema.
///
/// For an owned version, use [`iter_fields`][Self::iter_fields], which clones the data to iterate owned `Field`s.
pub fn iter(&self) -> impl Iterator<Item = (&SmartString, &DataType)> + '_ {
Expand Down Expand Up @@ -439,7 +439,7 @@ impl IntoIterator for Schema {
}
}

/// This trait exists to be unify the API of polars Schema and arrows Schema
/// This trait exists to unify the API of polars' Schema and arrow's Schema.
pub trait IndexOfSchema: Debug {
/// Get the index of a column by name.
fn index_of(&self, name: &str) -> Option<usize>;
Expand Down
56 changes: 56 additions & 0 deletions crates/polars-plan/src/plans/ir/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,60 @@ impl IR {
};
Cow::Borrowed(schema)
}

/// Get the schema of the logical plan node `node`, memoizing results in `cache`.
///
/// On a cache hit the stored schema is cloned and returned right away. Otherwise the schema is
/// derived from the node stored in `arena` (recursing into inputs for nodes that forward their
/// input's schema), recorded in `cache` under `node`, and then returned.
///
/// # Panics
///
/// Panics if a `MapFunction` node fails to compute its schema (the result is unwrapped), if
/// indexing the first input of a `Union` fails, or if the node is `IR::Invalid`.
#[recursive]
pub fn schema_with_cache<'a>(
    node: Node,
    arena: &'a Arena<IR>,
    cache: &mut PlHashMap<Node, Arc<Schema>>,
) -> Arc<Schema> {
    use IR::*;

    // Fast path: this node has been resolved before.
    if let Some(cached) = cache.get(&node) {
        return cached.clone();
    }

    let resolved = match arena.get(node) {
        // These nodes report exactly the schema of their single input.
        Cache { input, .. }
        | Sort { input, .. }
        | Filter { input, .. }
        | Distinct { input, .. }
        | Sink { input, .. }
        | Slice { input, .. } => IR::schema_with_cache(*input, arena, cache),
        // A union reports the schema of its first input.
        Union { inputs, .. } => IR::schema_with_cache(inputs[0], arena, cache),
        // Scans use `output_schema` when it is set and fall back to the source schema otherwise.
        #[cfg(feature = "python")]
        PythonScan { options } => match &options.output_schema {
            Some(projected) => projected.clone(),
            None => options.schema.clone(),
        },
        Scan {
            output_schema,
            file_info,
            ..
        } => match output_schema {
            Some(projected) => projected.clone(),
            None => file_info.schema.clone(),
        },
        DataFrameScan {
            schema,
            output_schema,
            ..
        } => match output_schema {
            Some(projected) => projected.clone(),
            None => schema.clone(),
        },
        // These nodes carry their own schema.
        HConcat { schema, .. }
        | Select { schema, .. }
        | Reduce { schema, .. }
        | GroupBy { schema, .. }
        | Join { schema, .. }
        | HStack { schema, .. }
        | ExtContext { schema, .. }
        | SimpleProjection {
            columns: schema, ..
        } => schema.clone(),
        // The function derives its output schema from the (cached) schema of its input.
        MapFunction { input, function } => {
            let upstream = IR::schema_with_cache(*input, arena, cache);
            function.schema(&upstream).unwrap().into_owned()
        },
        Invalid => unreachable!(),
    };

    // Remember the result so later requests for this node hit the fast path.
    cache.insert(node, resolved.clone());
    resolved
}
}
Loading