Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(rust): Remove extra schema traits #18616

Merged
merged 4 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

119 changes: 35 additions & 84 deletions crates/polars-core/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,97 +90,46 @@ impl SchemaExt for Schema {
}
}

/// This trait exists to unify the API of the polars Schema and the arrow Schema.
pub trait IndexOfSchema: Debug {
/// Get the index of a column by name.
fn index_of(&self, name: &str) -> Option<usize>;

/// Get a vector of all column names.
fn get_names(&self) -> Vec<&PlSmallStr>;

fn get_names_str(&self) -> Vec<&str>;

fn get_names_owned(&self) -> Vec<PlSmallStr>;

fn try_index_of(&self, name: &str) -> PolarsResult<usize> {
self.index_of(name).ok_or_else(|| {
polars_err!(
ColumnNotFound:
"unable to find column {:?}; valid columns: {:?}", name, self.get_names(),
)
})
}
}

impl IndexOfSchema for Schema {
fn index_of(&self, name: &str) -> Option<usize> {
self.index_of(name)
}

fn get_names(&self) -> Vec<&PlSmallStr> {
self.iter_names().collect()
}

fn get_names_owned(&self) -> Vec<PlSmallStr> {
self.iter_names().cloned().collect()
}
pub trait SchemaNamesAndDtypes {
const IS_ARROW: bool;
type DataType: Debug + Clone + Default + PartialEq;

fn get_names_str(&self) -> Vec<&str> {
self.iter_names().map(|x| x.as_str()).collect()
}
fn iter_names_and_dtypes(
&self,
) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)>;
}

impl IndexOfSchema for ArrowSchema {
fn index_of(&self, name: &str) -> Option<usize> {
self.iter_values().position(|f| f.name.as_str() == name)
}

fn get_names(&self) -> Vec<&PlSmallStr> {
self.iter_values().map(|f| &f.name).collect()
}

fn get_names_owned(&self) -> Vec<PlSmallStr> {
self.iter_values().map(|f| f.name.clone()).collect()
}
impl SchemaNamesAndDtypes for ArrowSchema {
const IS_ARROW: bool = true;
type DataType = ArrowDataType;

fn get_names_str(&self) -> Vec<&str> {
self.iter_values().map(|f| f.name.as_str()).collect()
fn iter_names_and_dtypes(
&self,
) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)> {
self.iter_values().map(|x| (&x.name, &x.dtype))
}
}

pub trait SchemaNamesAndDtypes {
const IS_ARROW: bool;
type DataType: Debug + PartialEq;

/// Get a vector of (name, dtype) pairs
fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, Self::DataType)>;
}

impl SchemaNamesAndDtypes for Schema {
const IS_ARROW: bool = false;
type DataType = DataType;

fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, Self::DataType)> {
fn iter_names_and_dtypes(
&self,
) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)> {
self.iter()
.map(|(name, dtype)| (name.as_str(), dtype.clone()))
.collect()
}
}

impl SchemaNamesAndDtypes for ArrowSchema {
const IS_ARROW: bool = true;
type DataType = ArrowDataType;

fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, Self::DataType)> {
self.iter_values()
.map(|x| (x.name.as_str(), x.dtype.clone()))
.collect()
}
}

pub fn ensure_matching_schema<S: SchemaNamesAndDtypes>(lhs: &S, rhs: &S) -> PolarsResult<()> {
let lhs = lhs.get_names_and_dtypes();
let rhs = rhs.get_names_and_dtypes();
pub fn ensure_matching_schema<D>(
lhs: &polars_schema::Schema<D>,
rhs: &polars_schema::Schema<D>,
) -> PolarsResult<()>
where
polars_schema::Schema<D>: SchemaNamesAndDtypes,
{
let lhs = lhs.iter_names_and_dtypes();
let rhs = rhs.iter_names_and_dtypes();

if lhs.len() != rhs.len() {
polars_bail!(
Expand All @@ -190,7 +139,7 @@ pub fn ensure_matching_schema<S: SchemaNamesAndDtypes>(lhs: &S, rhs: &S) -> Pola
);
}

for (i, ((l_name, l_dtype), (r_name, r_dtype))) in lhs.iter().zip(&rhs).enumerate() {
for (i, ((l_name, l_dtype), (r_name, r_dtype))) in lhs.zip(rhs).enumerate() {
if l_name != r_name {
polars_bail!(
SchemaMismatch:
Expand All @@ -199,18 +148,20 @@ pub fn ensure_matching_schema<S: SchemaNamesAndDtypes>(lhs: &S, rhs: &S) -> Pola
)
}
if l_dtype != r_dtype
&& (!S::IS_ARROW
&& (!polars_schema::Schema::<D>::IS_ARROW
|| unsafe {
// For timezone normalization. Easier than writing out the entire PartialEq.
DataType::from_arrow(
std::mem::transmute::<&<S as SchemaNamesAndDtypes>::DataType, &ArrowDataType>(
l_dtype,
),
std::mem::transmute::<
&<polars_schema::Schema<D> as SchemaNamesAndDtypes>::DataType,
&ArrowDataType,
>(l_dtype),
true,
) != DataType::from_arrow(
std::mem::transmute::<&<S as SchemaNamesAndDtypes>::DataType, &ArrowDataType>(
r_dtype,
),
std::mem::transmute::<
&<polars_schema::Schema<D> as SchemaNamesAndDtypes>::DataType,
&ArrowDataType,
>(r_dtype),
true,
)
})
Expand Down
1 change: 1 addition & 0 deletions crates/polars-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ polars-core = { workspace = true }
polars-error = { workspace = true }
polars-json = { workspace = true, optional = true }
polars-parquet = { workspace = true, optional = true }
polars-schema = { workspace = true }
polars-time = { workspace = true, features = [], optional = true }
polars-utils = { workspace = true, features = ['mmap'] }

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/csv/read/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::path::PathBuf;
use std::sync::Arc;

use polars_core::datatypes::{DataType, Field};
use polars_core::schema::{IndexOfSchema, Schema, SchemaRef};
use polars_core::schema::{Schema, SchemaRef};
use polars_error::PolarsResult;
use polars_utils::pl_str::PlSmallStr;
#[cfg(feature = "serde")]
Expand Down
8 changes: 6 additions & 2 deletions crates/polars-io/src/csv/write/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::io::Write;
use std::num::NonZeroUsize;

use polars_core::frame::DataFrame;
use polars_core::schema::{IndexOfSchema, Schema};
use polars_core::schema::Schema;
use polars_core::POOL;
use polars_error::PolarsResult;

Expand Down Expand Up @@ -228,7 +228,11 @@ impl<W: Write> BatchedWriter<W> {

if !self.has_written_header {
self.has_written_header = true;
let names = self.schema.get_names_str();
let names = self
.schema
.iter_names()
.map(|x| x.as_str())
.collect::<Vec<_>>();
write_header(&mut self.writer.buffer, &names, &self.writer.options)?;
};

Expand Down
5 changes: 2 additions & 3 deletions crates/polars-io/src/hive.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use polars_core::frame::DataFrame;
use polars_core::schema::IndexOfSchema;
use polars_core::series::Series;

/// Materializes hive partitions.
Expand All @@ -9,9 +8,9 @@ use polars_core::series::Series;
/// # Safety
///
/// num_rows equals the height of the df when the df height is non-zero.
pub(crate) fn materialize_hive_partitions<S: IndexOfSchema>(
pub(crate) fn materialize_hive_partitions<D>(
df: &mut DataFrame,
reader_schema: &S,
reader_schema: &polars_schema::Schema<D>,
hive_partition_columns: Option<&[Series]>,
num_rows: usize,
) {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/parquet/read/read_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ fn rg_to_dfs_prefiltered(
// We first add the columns with the live columns at the start. Then, we do a
// projections that puts the columns at the right spot.
df._add_columns(rg_columns, &rearranged_schema)?;
let df = df.select(schema.get_names_owned())?;
let df = df.select(schema.iter_names().cloned())?;

PolarsResult::Ok(Some(df))
})
Expand Down
24 changes: 4 additions & 20 deletions crates/polars-io/src/utils/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,26 +111,10 @@ pub(crate) fn columns_to_projection(
schema: &ArrowSchema,
) -> PolarsResult<Vec<usize>> {
let mut prj = Vec::with_capacity(columns.len());
if columns.len() > 100 {
let mut column_names = PlHashMap::with_capacity(schema.len());
schema.iter_values().enumerate().for_each(|(i, c)| {
column_names.insert(c.name.as_str(), i);
});

for column in columns.iter() {
let Some(&i) = column_names.get(column.as_str()) else {
polars_bail!(
ColumnNotFound:
"unable to find column {:?}; valid columns: {:?}", column, schema.get_names(),
);
};
prj.push(i);
}
} else {
for column in columns.iter() {
let i = schema.try_index_of(column)?;
prj.push(i);
}

for column in columns {
let i = schema.try_index_of(column)?;
prj.push(i);
}

Ok(prj)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ impl ProjectionSimple {
impl Executor for ProjectionSimple {
fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult<DataFrame> {
state.should_stop()?;
let columns = self.columns.get_names_owned();
let columns = self.columns.iter_names().cloned().collect::<Vec<_>>();

let profile_name = if state.has_node_timer() {
let name = comma_delimited("simple-projection".to_string(), columns.as_slice());
Expand Down
10 changes: 3 additions & 7 deletions crates/polars-pipe/src/executors/operators/reproject.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use polars_core::error::PolarsResult;
use polars_core::frame::DataFrame;
use polars_core::prelude::IndexOfSchema;
use polars_core::schema::Schema;

use crate::operators::DataChunk;
Expand All @@ -15,12 +14,9 @@ pub(crate) fn reproject_chunk(
// the positions for subsequent calls
let chunk_schema = chunk.data.schema();

let check_duplicates = false;
let out = chunk.data._select_with_schema_impl(
schema.get_names_owned().as_slice(),
&chunk_schema,
check_duplicates,
)?;
let out = chunk
.data
.select_with_schema_unchecked(schema.iter_names().cloned(), &chunk_schema)?;

*positions = out
.get_columns()
Expand Down
3 changes: 1 addition & 2 deletions crates/polars-pipe/src/executors/sinks/reproject.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::any::Any;

use polars_core::prelude::IndexOfSchema;
use polars_core::schema::SchemaRef;

use crate::executors::sources::ReProjectSource;
Expand Down Expand Up @@ -41,7 +40,7 @@ impl Sink for ReProjectSink {
fn finalize(&mut self, context: &PExecutionContext) -> PolarsResult<FinalizedSink> {
Ok(match self.sink.finalize(context)? {
FinalizedSink::Finished(df) => {
FinalizedSink::Finished(df._select_impl(self.schema.get_names_owned().as_slice())?)
FinalizedSink::Finished(df.select(self.schema.iter_names().cloned())?)
},
FinalizedSink::Source(source) => {
FinalizedSink::Source(Box::new(ReProjectSource::new(self.schema.clone(), source)))
Expand Down
10 changes: 5 additions & 5 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1073,14 +1073,14 @@ pub(crate) fn maybe_init_projection_excluding_hive(

let (first_hive_name, _) = hive_schema.get_at_index(0)?;

// TODO: Optimize this
let names = match reader_schema {
Either::Left(ref v) => {
let names = v.get_names_owned();
names.contains(first_hive_name).then_some(names)
},
Either::Left(ref v) => v
.contains(first_hive_name.as_str())
.then(|| v.iter_names().cloned().collect::<Vec<_>>()),
Either::Right(ref v) => v
.contains(first_hive_name.as_str())
.then(|| v.get_names_owned()),
.then(|| v.iter_names().cloned().collect()),
};

let names = names?;
Expand Down
Loading
Loading