Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(rust): Refactor ArrowSchema to use polars_schema::Schema<D> #18564

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion crates/polars-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ hashbrown = { workspace = true }
num-traits = { workspace = true }
parking_lot = { workspace = true }
polars-error = { workspace = true }
polars-schema = { workspace = true }
polars-utils = { workspace = true }
serde = { workspace = true, optional = true }
simdutf8 = { workspace = true }
Expand Down Expand Up @@ -153,7 +154,7 @@ compute = [
"compute_take",
"compute_temporal",
]
serde = ["dep:serde"]
serde = ["dep:serde", "polars-schema/serde"]
simd = []

# polars-arrow
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-arrow/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ pub struct Field {
pub metadata: Metadata,
}

/// Turns a [`Field`] into a `(name, field)` pair keyed by a clone of the field's own name.
///
/// This is what makes `ArrowSchema::from_iter([field, ..])` work.
impl From<Field> for (PlSmallStr, Field) {
    fn from(field: Field) -> Self {
        // The name is cloned so the field itself can be moved into the pair unchanged.
        let name = field.name.clone();
        (name, field)
    }
}

impl Field {
/// Creates a new [`Field`].
pub fn new(name: PlSmallStr, data_type: ArrowDataType, is_nullable: bool) -> Self {
Expand Down
64 changes: 1 addition & 63 deletions crates/polars-arrow/src/datatypes/schema.rs
Original file line number Diff line number Diff line change
@@ -1,73 +1,11 @@
use std::sync::Arc;

use polars_error::{polars_bail, PolarsResult};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use super::Field;

/// An ordered sequence of [`Field`]s
///
/// [`ArrowSchema`] is an abstraction used to read from, and write to, Arrow IPC format,
/// Apache Parquet, and Apache Avro. All these formats have a concept of a schema
/// with fields and metadata.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ArrowSchema {
/// The fields composing this schema.
pub fields: Vec<Field>,
}

pub type ArrowSchema = polars_schema::Schema<Field>;
pub type ArrowSchemaRef = Arc<ArrowSchema>;

impl ArrowSchema {
#[inline]
pub fn len(&self) -> usize {
self.fields.len()
}

#[inline]
pub fn is_empty(&self) -> bool {
self.fields.is_empty()
}

/// Returns a new [`ArrowSchema`] with a subset of all fields whose `predicate`
/// evaluates to true.
pub fn filter<F: Fn(usize, &Field) -> bool>(self, predicate: F) -> Self {
let fields = self
.fields
.into_iter()
.enumerate()
.filter_map(|(index, f)| {
if (predicate)(index, &f) {
Some(f)
} else {
None
}
})
.collect();

ArrowSchema { fields }
}

pub fn try_project(&self, indices: &[usize]) -> PolarsResult<Self> {
let fields = indices.iter().map(|&i| {
let Some(out) = self.fields.get(i) else {
polars_bail!(
SchemaFieldNotFound: "projection index {} is out of bounds for schema of length {}",
i, self.fields.len()
);
};

Ok(out.clone())
}).collect::<PolarsResult<Vec<_>>>()?;

Ok(ArrowSchema { fields })
}
}

impl From<Vec<Field>> for ArrowSchema {
fn from(fields: Vec<Field>) -> Self {
Self { fields }
}
}
6 changes: 3 additions & 3 deletions crates/polars-arrow/src/io/avro/read/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ fn skip_item<'a>(
/// `fields`, `avro_fields` and `projection` must have the same length.
pub fn deserialize(
block: &Block,
fields: &[Field],
fields: &ArrowSchema,
avro_fields: &[AvroField],
projection: &[bool],
) -> PolarsResult<RecordBatchT<Box<dyn Array>>> {
Expand All @@ -479,7 +479,7 @@ pub fn deserialize(

// create mutables, one per field
let mut arrays: Vec<Box<dyn MutableArray>> = fields
.iter()
.iter_values()
.zip(avro_fields.iter())
.zip(projection.iter())
.map(|((field, avro_field), projection)| {
Expand All @@ -496,7 +496,7 @@ pub fn deserialize(
for _ in 0..rows {
let iter = arrays
.iter_mut()
.zip(fields.iter())
.zip(fields.iter_values())
.zip(avro_fields.iter())
.zip(projection.iter());

Expand Down
8 changes: 4 additions & 4 deletions crates/polars-arrow/src/io/avro/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ mod util;
pub use schema::infer_schema;

use crate::array::Array;
use crate::datatypes::Field;
use crate::datatypes::ArrowSchema;
use crate::record_batch::RecordBatchT;

/// Single threaded, blocking reader of Avro; [`Iterator`] of [`RecordBatchT`].
pub struct Reader<R: Read> {
iter: BlockStreamingIterator<R>,
avro_fields: Vec<AvroField>,
fields: Vec<Field>,
fields: ArrowSchema,
projection: Vec<bool>,
}

Expand All @@ -33,7 +33,7 @@ impl<R: Read> Reader<R> {
pub fn new(
reader: R,
metadata: FileMetadata,
fields: Vec<Field>,
fields: ArrowSchema,
projection: Option<Vec<bool>>,
) -> Self {
let projection = projection.unwrap_or_else(|| fields.iter().map(|_| true).collect());
Expand All @@ -56,7 +56,7 @@ impl<R: Read> Iterator for Reader<R> {
type Item = PolarsResult<RecordBatchT<Box<dyn Array>>>;

fn next(&mut self) -> Option<Self::Item> {
let fields = &self.fields[..];
let fields = &self.fields;
let avro_fields = &self.avro_fields;
let projection = &self.projection;

Expand Down
11 changes: 6 additions & 5 deletions crates/polars-arrow/src/io/avro/read/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,19 @@ fn external_props(schema: &AvroSchema) -> Metadata {
/// Infers an [`ArrowSchema`] from the root [`Record`].
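/// This converts each field of the record with `schema_to_field` and returns an error if any conversion fails.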
pub fn infer_schema(record: &Record) -> PolarsResult<ArrowSchema> {
Ok(record
record
.fields
.iter()
.map(|field| {
schema_to_field(
let field = schema_to_field(
&field.schema,
Some(&field.name),
external_props(&field.schema),
)
)?;

Ok((field.name.clone(), field))
})
.collect::<PolarsResult<Vec<_>>>()?
.into())
.collect::<PolarsResult<ArrowSchema>>()
}

fn schema_to_field(
Expand Down
3 changes: 1 addition & 2 deletions crates/polars-arrow/src/io/avro/write/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ use crate::datatypes::*;
pub fn to_record(schema: &ArrowSchema, name: String) -> PolarsResult<Record> {
let mut name_counter: i32 = 0;
let fields = schema
.fields
.iter()
.iter_values()
.map(|f| field_to_field(f, &mut name_counter))
.collect::<PolarsResult<_>>()?;
Ok(Record {
Expand Down
10 changes: 5 additions & 5 deletions crates/polars-arrow/src/io/flight/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pub fn serialize_schema_to_info(
let encoded_data = if let Some(ipc_fields) = ipc_fields {
schema_as_encoded_data(schema, ipc_fields)
} else {
let ipc_fields = default_ipc_fields(&schema.fields);
let ipc_fields = default_ipc_fields(schema.iter_values());
schema_as_encoded_data(schema, &ipc_fields)
};

Expand All @@ -92,7 +92,7 @@ fn _serialize_schema(schema: &ArrowSchema, ipc_fields: Option<&[IpcField]>) -> V
if let Some(ipc_fields) = ipc_fields {
write::schema_to_bytes(schema, ipc_fields)
} else {
let ipc_fields = default_ipc_fields(&schema.fields);
let ipc_fields = default_ipc_fields(schema.iter_values());
write::schema_to_bytes(schema, &ipc_fields)
}
}
Expand All @@ -113,7 +113,7 @@ pub fn deserialize_schemas(bytes: &[u8]) -> PolarsResult<(ArrowSchema, IpcSchema
/// Deserializes [`FlightData`] representing a record batch message to [`RecordBatchT`].
pub fn deserialize_batch(
data: &FlightData,
fields: &[Field],
fields: &ArrowSchema,
ipc_schema: &IpcSchema,
dictionaries: &read::Dictionaries,
) -> PolarsResult<RecordBatchT<Box<dyn Array>>> {
Expand Down Expand Up @@ -147,7 +147,7 @@ pub fn deserialize_batch(
/// Deserializes [`FlightData`], assuming it to be a dictionary message, into `dictionaries`.
pub fn deserialize_dictionary(
data: &FlightData,
fields: &[Field],
fields: &ArrowSchema,
ipc_schema: &IpcSchema,
dictionaries: &mut read::Dictionaries,
) -> PolarsResult<()> {
Expand Down Expand Up @@ -182,7 +182,7 @@ pub fn deserialize_dictionary(
/// or by upserting into `dictionaries` (when the message is a dictionary)
pub fn deserialize_message(
data: &FlightData,
fields: &[Field],
fields: &ArrowSchema,
ipc_schema: &IpcSchema,
dictionaries: &mut Dictionaries,
) -> PolarsResult<Option<RecordBatchT<Box<dyn Array>>>> {
Expand Down
38 changes: 25 additions & 13 deletions crates/polars-arrow/src/io/ipc/read/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use polars_utils::pl_str::PlSmallStr;
use super::deserialize::{read, skip};
use super::Dictionaries;
use crate::array::*;
use crate::datatypes::{ArrowDataType, Field};
use crate::datatypes::{ArrowDataType, ArrowSchema, Field};
use crate::io::ipc::read::OutOfSpecKind;
use crate::io::ipc::{IpcField, IpcSchema};
use crate::record_batch::RecordBatchT;
Expand Down Expand Up @@ -77,7 +77,7 @@ impl<'a, A, I: Iterator<Item = A>> Iterator for ProjectionIter<'a, A, I> {
#[allow(clippy::too_many_arguments)]
pub fn read_record_batch<R: Read + Seek>(
batch: arrow_format::ipc::RecordBatchRef,
fields: &[Field],
fields: &ArrowSchema,
ipc_schema: &IpcSchema,
projection: Option<&[usize]>,
limit: Option<usize>,
Expand Down Expand Up @@ -127,8 +127,10 @@ pub fn read_record_batch<R: Read + Seek>(
let mut field_nodes = field_nodes.iter().collect::<VecDeque<_>>();

let columns = if let Some(projection) = projection {
let projection =
ProjectionIter::new(projection, fields.iter().zip(ipc_schema.fields.iter()));
let projection = ProjectionIter::new(
projection,
fields.iter_values().zip(ipc_schema.fields.iter()),
);

projection
.map(|maybe_field| match maybe_field {
Expand Down Expand Up @@ -163,7 +165,7 @@ pub fn read_record_batch<R: Read + Seek>(
.collect::<PolarsResult<Vec<_>>>()?
} else {
fields
.iter()
.iter_values()
.zip(ipc_schema.fields.iter())
.map(|(field, ipc_field)| {
read(
Expand Down Expand Up @@ -227,11 +229,11 @@ fn find_first_dict_field<'a>(

pub(crate) fn first_dict_field<'a>(
id: i64,
fields: &'a [Field],
fields: &'a ArrowSchema,
ipc_fields: &'a [IpcField],
) -> PolarsResult<(&'a Field, &'a IpcField)> {
assert_eq!(fields.len(), ipc_fields.len());
for (field, ipc_field) in fields.iter().zip(ipc_fields.iter()) {
for (field, ipc_field) in fields.iter_values().zip(ipc_fields.iter()) {
if let Some(field) = find_first_dict_field(id, field, ipc_field) {
return Ok(field);
}
Expand All @@ -246,7 +248,7 @@ pub(crate) fn first_dict_field<'a>(
#[allow(clippy::too_many_arguments)]
pub fn read_dictionary<R: Read + Seek>(
batch: arrow_format::ipc::DictionaryBatchRef,
fields: &[Field],
fields: &ArrowSchema,
ipc_schema: &IpcSchema,
dictionaries: &mut Dictionaries,
reader: &mut R,
Expand Down Expand Up @@ -280,7 +282,11 @@ pub fn read_dictionary<R: Read + Seek>(
};

// Make a fake schema for the dictionary batch.
let fields = vec![Field::new(PlSmallStr::EMPTY, value_type.clone(), false)];
let fields = std::iter::once((
PlSmallStr::EMPTY,
Field::new(PlSmallStr::EMPTY, value_type.clone(), false),
))
.collect();
let ipc_schema = IpcSchema {
fields: vec![first_ipc_field.clone()],
is_little_endian: ipc_schema.is_little_endian,
Expand All @@ -305,10 +311,16 @@ pub fn read_dictionary<R: Read + Seek>(
}

pub fn prepare_projection(
fields: &[Field],
schema: &ArrowSchema,
mut projection: Vec<usize>,
) -> (Vec<usize>, PlHashMap<usize, usize>, Vec<Field>) {
let fields = projection.iter().map(|x| fields[*x].clone()).collect();
) -> (Vec<usize>, PlHashMap<usize, usize>, ArrowSchema) {
let schema = projection
.iter()
.map(|x| {
let (k, v) = schema.get_at_index(*x).unwrap();
(k.clone(), v.clone())
})
.collect();

// TODO: find a way to do this more efficiently
let mut indices = (0..projection.len()).collect::<Vec<_>>();
Expand All @@ -335,7 +347,7 @@ pub fn prepare_projection(
}
}

(projection, map, fields)
(projection, map, schema)
}

pub fn apply_projection(
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-arrow/src/io/ipc/read/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ fn read_dictionary_block<R: Read + Seek>(

read_dictionary(
batch,
&metadata.schema.fields,
&metadata.schema,
&metadata.ipc_schema,
dictionaries,
reader,
Expand Down Expand Up @@ -317,7 +317,7 @@ pub fn read_batch<R: Read + Seek>(

read_record_batch(
batch,
&metadata.schema.fields,
&metadata.schema,
&metadata.ipc_schema,
projection,
limit,
Expand Down
Loading