diff --git a/arro3-core/python/arro3/core/_core.pyi b/arro3-core/python/arro3/core/_core.pyi
index 91a4a14..8d6cc42 100644
--- a/arro3-core/python/arro3/core/_core.pyi
+++ b/arro3-core/python/arro3/core/_core.pyi
@@ -591,44 +591,124 @@ class RecordBatch:
     @classmethod
     def from_arrays(
         cls, arrays: Sequence[ArrowArrayExportable], *, schema: ArrowSchemaExportable
-    ) -> RecordBatch: ...
+    ) -> RecordBatch:
+        """Construct a RecordBatch from multiple Arrow arrays.
+
+        Args:
+            arrays: One array for each field in the schema.
+            schema: Schema of the created batch.
+
+        Returns:
+            New RecordBatch.
+        """
     @classmethod
     def from_pydict(
         cls,
         mapping: dict[str, ArrowArrayExportable],
         *,
         metadata: ArrowSchemaExportable | None = None,
-    ) -> RecordBatch: ...
+    ) -> RecordBatch:
+        """Construct a RecordBatch from a mapping of strings to Arrow arrays.
+
+        Args:
+            mapping: A mapping of strings to Arrays.
+            metadata: Optional metadata for the schema (if inferred). Defaults to None.
+
+        Returns:
+            New RecordBatch.
+        """
     @classmethod
-    def from_struct_array(cls, struct_array: ArrowArrayExportable) -> RecordBatch: ...
+    def from_struct_array(cls, struct_array: ArrowArrayExportable) -> RecordBatch:
+        """Construct a RecordBatch from a StructArray.
+
+        Each field in the StructArray will become a column in the resulting
+        RecordBatch.
+
+        Args:
+            struct_array: Array to construct the record batch from.
+
+        Returns:
+            New RecordBatch.
+        """
     @classmethod
     def from_arrow(cls, input: ArrowArrayExportable) -> RecordBatch: ...
     @classmethod
     def from_arrow_pycapsule(cls, schema_capsule, array_capsule) -> RecordBatch:
         """Construct this object from bare Arrow PyCapsules"""
     def add_column(
-        self, i: int, field: ArrowSchemaExportable, column: ArrowArrayExportable
+        self, i: int, field: str | ArrowSchemaExportable, column: ArrowArrayExportable
     ) -> RecordBatch: ...
     def append_column(
-        self, field: ArrowSchemaExportable, column: ArrowArrayExportable
-    ) -> RecordBatch: ...
-    def column(self, i: int) -> ChunkedArray: ...
+        self, field: str | ArrowSchemaExportable, column: ArrowArrayExportable
+    ) -> RecordBatch:
+        """Append column at end of columns.
+
+        Args:
+            field: If a string is passed then the type is deduced from the column data.
+            column: Column data.
+
+        Returns:
+            New RecordBatch with the appended column.
+        """
+
+    def column(self, i: int | str) -> Array:
+        """Select a single column from the RecordBatch.
+
+        Args:
+            i: The index or name of the column to retrieve.
+
+        Returns:
+            The selected column.
+        """
     @property
-    def column_names(self) -> list[str]: ...
+    def column_names(self) -> list[str]:
+        """Names of the RecordBatch columns."""
     @property
-    def columns(self) -> list[Array]: ...
-    def equals(self, other: ArrowArrayExportable) -> bool: ...
-    def field(self, i: int) -> Field: ...
+    def columns(self) -> list[Array]:
+        """List of all columns in numerical order."""
+    def equals(self, other: ArrowArrayExportable) -> bool:
+        """Check if contents of two record batches are equal.
+
+        Args:
+            other: RecordBatch to compare against.
+
+        Returns:
+            True if the two batches have equal contents.
+        """
+
+    def field(self, i: int | str) -> Field:
+        """Select a schema field by its column name or numeric index.
+
+        Args:
+            i: The index or name of the field to retrieve.
+
+        Returns:
+            The selected Field.
+        """
     @property
-    def num_columns(self) -> int: ...
+    def num_columns(self) -> int:
+        """Number of columns."""
     @property
-    def num_rows(self) -> int: ...
-    def remove_column(self, i: int) -> RecordBatch: ...
+    def num_rows(self) -> int:
+        """Number of rows.
+
+        Due to the definition of a RecordBatch, all columns have the same number of
+        rows.
+        """
+    def remove_column(self, i: int) -> RecordBatch:
+        """Create new RecordBatch with the indicated column removed.
+
+        Args:
+            i: Index of column to remove.
+
+        Returns:
+            New record batch without the column.
+        """
     @property
-    def schema(self) -> Schema: ...
-    def select(self, columns: list[int]) -> RecordBatch: ...
+    def schema(self) -> Schema:
+        """Access the schema of this RecordBatch."""
+    def select(self, columns: list[int] | list[str]) -> RecordBatch: ...
     def set_column(
-        self, i: int, field: ArrowSchemaExportable, column: ArrowArrayExportable
+        self, i: int, field: str | ArrowSchemaExportable, column: ArrowArrayExportable
     ) -> RecordBatch: ...
     @property
     def shape(self) -> tuple[int, int]: ...
@@ -870,7 +950,7 @@ class Table:
             _description_
         """
     def add_column(
-        self, i: int, field: ArrowSchemaExportable, column: ArrowStreamExportable
+        self, i: int, field: str | ArrowSchemaExportable, column: ArrowStreamExportable
     ) -> RecordBatch:
         """Add column to Table at position.
 
@@ -885,7 +965,7 @@ class Table:
             New table with the passed column added.
         """
     def append_column(
-        self, field: ArrowSchemaExportable, column: ArrowStreamExportable
+        self, field: str | ArrowSchemaExportable, column: ArrowStreamExportable
     ) -> RecordBatch:
         """Append column at end of columns.
 
@@ -899,7 +979,7 @@ class Table:
     @property
     def chunk_lengths(self) -> list[int]:
         """The number of rows in each internal chunk."""
-    def column(self, i: int) -> ChunkedArray:
+    def column(self, i: int | str) -> ChunkedArray:
         """Select single column from Table or RecordBatch.
 
         Args:
@@ -949,15 +1029,20 @@
         Due to the definition of a table, all columns have the same number of
         rows.
         """
-    def set_column(
-        self, i: int, field: ArrowSchemaExportable, column: ArrowStreamExportable
-    ) -> Table:
-        """Replace column in Table at position.
+    def remove_column(self, i: int) -> Table:
+        """Create new Table with the indicated column removed.
 
         Args:
-            i: Index to place the column at.
-            field: _description_
-            column: Column data.
+            i: Index of column to remove.
+
+        Returns:
+            New table without the column.
+        """
+    def rename_columns(self, names: Sequence[str]) -> Table:
+        """Create new table with columns renamed to provided names.
+
+        Args:
+            names: List of new column names.
 
         Returns:
             _description_
         """
@@ -966,6 +1051,30 @@
     @property
     def schema(self) -> Schema:
         """Schema of the table and its columns.
+
+        Returns:
+            The schema of this table.
+        """
+    def select(self, columns: Sequence[int] | Sequence[str]) -> Table:
+        """Select columns of the Table.
+
+        Returns a new Table with the specified columns, and metadata preserved.
+
+        Args:
+            columns: The column names or integer indices to select.
+
+        Returns:
+            New table with only the selected columns.
+        """
+    def set_column(
+        self, i: int, field: str | ArrowSchemaExportable, column: ArrowStreamExportable
+    ) -> Table:
+        """Replace column in Table at position.
+
+        Args:
+            i: Index to place the column at.
+            field: Field for the new column. If a string is passed, the type is
+                deduced from the column data.
+            column: Column data.
 
         Returns:
             _description_
         """
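The stub changes above boil down to "a name is accepted anywhere an index or a Field used to be required." A minimal sketch of the resulting ergonomics, assuming pyarrow supplies the `ArrowArrayExportable` inputs (any object implementing the Arrow PyCapsule interface should work):

```python
import pyarrow as pa  # assumed input source; any ArrowArrayExportable works
from arro3.core import RecordBatch

batch = RecordBatch.from_pydict({"a": pa.array([1, 2, 3])})

# Columns and fields are now addressable by name as well as by position.
col = batch.column("a")   # previously only batch.column(0)
field = batch.field("a")  # previously only batch.field(0)

# A bare string is enough when appending: the field's type is deduced
# from the column data instead of requiring an explicit Field.
batch2 = batch.append_column("b", pa.array([4.0, 5.0, 6.0]))

# select() accepts either integer positions or column names.
assert batch2.select(["b"]).column_names == batch2.select([1]).column_names
```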
diff --git a/pyo3-arrow/src/chunked.rs b/pyo3-arrow/src/chunked.rs
index c215d37..36a57dd 100644
--- a/pyo3-arrow/src/chunked.rs
+++ b/pyo3-arrow/src/chunked.rs
@@ -3,7 +3,7 @@ use std::sync::Arc;
 
 use arrow::compute::concat;
 use arrow_array::{make_array, Array, ArrayRef};
-use arrow_schema::{ArrowError, Field, FieldRef};
+use arrow_schema::{ArrowError, DataType, Field, FieldRef};
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::intern;
 use pyo3::prelude::*;
@@ -38,6 +38,10 @@ impl PyChunkedArray {
         Self { chunks, field }
     }
 
+    pub fn data_type(&self) -> &DataType {
+        self.field.data_type()
+    }
+
     pub fn from_arrays<A: Array>(chunks: &[A]) -> PyArrowResult<Self> {
         let arrays = chunks
             .iter()
diff --git a/pyo3-arrow/src/input.rs b/pyo3-arrow/src/input.rs
index f5a9e58..8e893af 100644
--- a/pyo3-arrow/src/input.rs
+++ b/pyo3-arrow/src/input.rs
@@ -5,14 +5,17 @@
 use std::collections::HashMap;
 use std::string::FromUtf8Error;
+use std::sync::Arc;
 
 use arrow_array::{RecordBatchIterator, RecordBatchReader};
-use arrow_schema::{FieldRef, SchemaRef};
+use arrow_schema::{ArrowError, Field, FieldRef, Fields, Schema, SchemaRef};
+use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 
 use crate::array_reader::PyArrayReader;
+use crate::error::PyArrowResult;
 use crate::ffi::{ArrayIterator, ArrayReader};
-use crate::{PyArray, PyRecordBatch, PyRecordBatchReader};
+use crate::{PyArray, PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader};
 
 /// An enum over [PyRecordBatch] and [PyRecordBatchReader], used when a function accepts either
 /// Arrow object as input.
@@ -49,6 +52,22 @@ pub enum AnyArray {
 }
 
 impl AnyArray {
+    pub fn into_chunked_array(self) -> PyArrowResult<PyChunkedArray> {
+        match self {
+            Self::Array(array) => {
+                let (array, field) = array.into_inner();
+                Ok(PyChunkedArray::new(vec![array], field))
+            }
+            Self::Stream(stream) => {
+                let field = stream.field_ref()?;
+                let chunks = stream
+                    .into_reader()?
+                    .collect::<Result<Vec<_>, ArrowError>>()?;
+                Ok(PyChunkedArray::new(chunks, field))
+            }
+        }
+    }
+
     pub fn into_reader(self) -> PyResult<Box<dyn ArrayReader + Send>> {
         match self {
             Self::Array(array) => {
@@ -97,6 +116,62 @@ impl Default for MetadataInput {
 
 #[derive(FromPyObject)]
 pub enum FieldIndexInput {
-    String(String),
-    Int(usize),
+    Name(String),
+    Position(usize),
+}
+
+impl FieldIndexInput {
+    pub fn into_position(self, schema: &Schema) -> PyArrowResult<usize> {
+        match self {
+            Self::Name(name) => Ok(schema.index_of(name.as_ref())?),
+            Self::Position(position) => Ok(position),
+        }
+    }
+}
+
+#[derive(FromPyObject)]
+pub enum NameOrField {
+    Name(String),
+    Field(PyField),
+}
+
+impl NameOrField {
+    pub fn into_field(self, source_field: &Field) -> FieldRef {
+        match self {
+            Self::Name(name) => Arc::new(
+                Field::new(
+                    name,
+                    source_field.data_type().clone(),
+                    source_field.is_nullable(),
+                )
+                .with_metadata(source_field.metadata().clone()),
+            ),
+            Self::Field(field) => field.into_inner(),
+        }
+    }
+}
+
+#[derive(FromPyObject)]
+pub enum SelectIndices {
+    Names(Vec<String>),
+    Positions(Vec<usize>),
+}
+
+impl SelectIndices {
+    pub fn into_positions(self, fields: &Fields) -> PyResult<Vec<usize>> {
+        match self {
+            Self::Names(names) => {
+                let mut positions = Vec::with_capacity(names.len());
+                for name in names {
+                    let index = fields
+                        .iter()
+                        .position(|field| field.name() == &name)
+                        .ok_or(PyValueError::new_err(format!("{name} not in schema.")))?;
+                    positions.push(index);
+                }
+                Ok(positions)
+            }
+            Self::Positions(positions) => Ok(positions),
+        }
+    }
 }
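On the Python side, the three new `FromPyObject` enums surface as plain unions: `FieldIndexInput` is `int | str`, `NameOrField` is `str | Field`, and `SelectIndices` is `list[int] | list[str]`. A hedged sketch of the behavior they give `Table` and `RecordBatch` (pyarrow assumed for input data):

```python
import pyarrow as pa  # assumed input source
from arro3.core import Table

table = Table.from_pydict({"a": pa.array([1, 2]), "b": pa.array([3, 4])})

table.field(0)        # FieldIndexInput::Position
table.field("a")      # FieldIndexInput::Name, resolved via Schema::index_of
table.select([1, 0])  # SelectIndices::Positions
table.select(["b"])   # SelectIndices::Names

# Unknown names surface as ValueError from SelectIndices::into_positions.
try:
    table.select(["missing"])
except ValueError as exc:
    assert "not in schema" in str(exc)
```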
diff --git a/pyo3-arrow/src/record_batch.rs b/pyo3-arrow/src/record_batch.rs
index 88a2dab..d5c4be5 100644
--- a/pyo3-arrow/src/record_batch.rs
+++ b/pyo3-arrow/src/record_batch.rs
@@ -14,7 +14,7 @@ use crate::error::PyArrowResult;
 use crate::ffi::from_python::utils::import_array_pycapsules;
 use crate::ffi::to_python::nanoarrow::to_nanoarrow_array;
 use crate::ffi::to_python::to_array_pycapsules;
-use crate::input::MetadataInput;
+use crate::input::{FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
 use crate::schema::display_schema;
 use crate::{PyArray, PyField, PySchema};
 
@@ -258,11 +258,11 @@ impl PyRecordBatch {
         &self,
         py: Python,
         i: usize,
-        field: PyField,
+        field: NameOrField,
         column: PyArray,
     ) -> PyArrowResult<PyObject> {
         let mut fields = self.0.schema_ref().fields().to_vec();
-        fields.insert(i, field.into_inner());
+        fields.insert(i, field.into_field(column.field()));
         let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
 
         let mut arrays = self.0.columns().to_vec();
@@ -275,11 +275,11 @@ impl PyRecordBatch {
     pub fn append_column(
         &self,
         py: Python,
-        field: PyField,
+        field: NameOrField,
         column: PyArray,
     ) -> PyArrowResult<PyObject> {
         let mut fields = self.0.schema_ref().fields().to_vec();
-        fields.push(field.into_inner());
+        fields.push(field.into_field(column.field()));
         let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
 
         let mut arrays = self.0.columns().to_vec();
@@ -290,9 +290,10 @@ impl PyRecordBatch {
     }
 
     /// Select single column from RecordBatch
-    pub fn column(&self, py: Python, i: usize) -> PyResult<PyObject> {
-        let field = self.0.schema().field(i).clone();
-        let array = self.0.column(i).clone();
+    pub fn column(&self, py: Python, i: FieldIndexInput) -> PyResult<PyObject> {
+        let column_index = i.into_position(self.0.schema_ref())?;
+        let field = self.0.schema().field(column_index).clone();
+        let array = self.0.column(column_index).clone();
         PyArray::new(array, field.into()).to_arro3(py)
     }
 
@@ -311,7 +312,7 @@ impl PyRecordBatch {
     #[getter]
     pub fn columns(&self, py: Python) -> PyResult<Vec<PyObject>> {
         (0..self.num_columns())
-            .map(|i| self.column(py, i))
+            .map(|i| self.column(py, FieldIndexInput::Position(i)))
             .collect()
     }
 
@@ -320,8 +321,10 @@ impl PyRecordBatch {
     }
 
     /// Select a schema field by its numeric index.
-    pub fn field(&self, py: Python, i: usize) -> PyResult<PyObject> {
-        PyField::new(self.0.schema().field(i).clone().into()).to_arro3(py)
+    pub fn field(&self, py: Python, i: FieldIndexInput) -> PyResult<PyObject> {
+        let schema_ref = self.0.schema_ref();
+        let field = schema_ref.field(i.into_position(schema_ref)?);
+        PyField::new(field.clone().into()).to_arro3(py)
     }
 
     /// Number of columns in this RecordBatch.
@@ -342,14 +345,14 @@ impl PyRecordBatch {
         PyRecordBatch::new(rb).to_arro3(py)
     }
 
-    /// Access the schema of this RecordBatch
     #[getter]
     pub fn schema(&self, py: Python) -> PyResult<PyObject> {
         PySchema::new(self.0.schema()).to_arro3(py)
     }
 
-    pub fn select(&self, py: Python, columns: Vec<usize>) -> PyArrowResult<PyObject> {
-        let new_rb = self.0.project(&columns)?;
+    pub fn select(&self, py: Python, columns: SelectIndices) -> PyArrowResult<PyObject> {
+        let positions = columns.into_positions(self.0.schema_ref().fields())?;
+        let new_rb = self.0.project(&positions)?;
         Ok(PyRecordBatch::new(new_rb).to_arro3(py)?)
     }
 
@@ -357,11 +360,11 @@ impl PyRecordBatch {
         &self,
         py: Python,
         i: usize,
-        field: PyField,
+        field: NameOrField,
         column: PyArray,
     ) -> PyArrowResult<PyObject> {
         let mut fields = self.0.schema_ref().fields().to_vec();
-        fields[i] = field.into_inner();
+        fields[i] = field.into_field(column.field());
         let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
 
         let mut arrays = self.0.columns().to_vec();
diff --git a/pyo3-arrow/src/schema.rs b/pyo3-arrow/src/schema.rs
index a3cccd6..1610ffb 100644
--- a/pyo3-arrow/src/schema.rs
+++ b/pyo3-arrow/src/schema.rs
@@ -167,10 +167,8 @@ impl PySchema {
     }
 
     pub fn field(&self, py: Python, i: FieldIndexInput) -> PyArrowResult<PyObject> {
-        let field = match i {
-            FieldIndexInput::String(name) => self.0.field_with_name(&name)?,
-            FieldIndexInput::Int(i) => self.0.field(i),
-        };
+        let index = i.into_position(&self.0)?;
+        let field = self.0.field(index);
         Ok(PyField::new(field.clone().into()).to_arro3(py)?)
     }
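Because `add_column`/`append_column`/`set_column` now take `NameOrField`, passing a bare name inherits the data type, nullability, and metadata from the incoming column's own field (`NameOrField::into_field` above), while an explicit field is used as-is. A sketch, with pyarrow objects standing in for the Arrow PyCapsule protocols:

```python
import pyarrow as pa  # assumed input source
from arro3.core import RecordBatch

batch = RecordBatch.from_pydict({"a": pa.array([1, 2])})

# Name only: type, nullability, and metadata come from the column itself.
named = batch.append_column("b", pa.array(["x", "y"]))

# Explicit field (anything ArrowSchemaExportable, e.g. a pyarrow Field):
# used as-is instead of the column's own field.
replaced = named.set_column(0, pa.field("a2", pa.int64()), pa.array([3, 4]))
```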
diff --git a/pyo3-arrow/src/table.rs b/pyo3-arrow/src/table.rs
index 55a4d69..dd9cc52 100644
--- a/pyo3-arrow/src/table.rs
+++ b/pyo3-arrow/src/table.rs
@@ -1,3 +1,4 @@
+use std::collections::HashMap;
 use std::fmt::Display;
 use std::sync::Arc;
 
@@ -16,7 +17,7 @@ use crate::ffi::from_python::utils::import_stream_pycapsule;
 use crate::ffi::to_python::chunked::ArrayIterator;
 use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream;
 use crate::ffi::to_python::to_stream_pycapsule;
-use crate::input::FieldIndexInput;
+use crate::input::{AnyArray, FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
 use crate::schema::display_schema;
 use crate::{PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader, PySchema};
 
@@ -139,11 +140,92 @@ impl PyTable {
         Ok(Self::new(batches, schema))
     }
 
+    #[classmethod]
+    #[pyo3(signature = (mapping, *, schema=None, metadata=None))]
+    pub fn from_pydict(
+        cls: &Bound<PyType>,
+        mapping: HashMap<String, AnyArray>,
+        schema: Option<PySchema>,
+        metadata: Option<MetadataInput>,
+    ) -> PyArrowResult<Self> {
+        let (names, arrays): (Vec<_>, Vec<_>) = mapping.into_iter().unzip();
+        Self::from_arrays(cls, arrays, Some(names), schema, metadata)
+    }
+
+    #[classmethod]
+    #[pyo3(signature = (arrays, *, names=None, schema=None, metadata=None))]
+    pub fn from_arrays(
+        _cls: &Bound<PyType>,
+        arrays: Vec<AnyArray>,
+        names: Option<Vec<String>>,
+        schema: Option<PySchema>,
+        metadata: Option<MetadataInput>,
+    ) -> PyArrowResult<Self> {
+        let columns = arrays
+            .into_iter()
+            .map(|array| array.into_chunked_array())
+            .collect::<PyArrowResult<Vec<_>>>()?;
+
+        let schema: SchemaRef = if let Some(schema) = schema {
+            schema.into_inner()
+        } else {
+            let names = names.ok_or(PyValueError::new_err(
+                "names must be passed if schema is not passed.",
+            ))?;
+
+            let fields = columns
+                .iter()
+                .zip(names.iter())
+                .map(|(array, name)| Field::new(name.clone(), array.data_type().clone(), true))
+                .collect::<Vec<_>>();
+            Arc::new(
+                Schema::new(fields)
+                    .with_metadata(metadata.unwrap_or_default().into_string_hashmap()?),
+            )
+        };
+
+        if columns.is_empty() {
+            return Ok(Self::new(vec![], schema));
+        }
+
+        let column_chunk_lengths = columns
+            .iter()
+            .map(|column| {
+                let chunk_lengths = column
+                    .chunks()
+                    .iter()
+                    .map(|chunk| chunk.len())
+                    .collect::<Vec<_>>();
+                chunk_lengths
+            })
+            .collect::<Vec<_>>();
+        if !column_chunk_lengths.windows(2).all(|w| w[0] == w[1]) {
+            return Err(
+                PyValueError::new_err("All columns must have the same chunk lengths").into(),
+            );
+        }
+        let num_batches = column_chunk_lengths[0].len();
+
+        let mut batches = vec![];
+        for batch_idx in 0..num_batches {
+            let batch = RecordBatch::try_new(
+                schema.clone(),
+                columns
+                    .iter()
+                    .map(|column| column.chunks()[batch_idx].clone())
+                    .collect(),
+            )?;
+            batches.push(batch);
+        }
+
+        Ok(Self::new(batches, schema))
+    }
+
     pub fn add_column(
         &self,
         py: Python,
         i: usize,
-        field: PyField,
+        field: NameOrField,
         column: PyChunkedArray,
     ) -> PyArrowResult<PyObject> {
         if self.num_rows() != column.len() {
@@ -155,7 +237,7 @@ impl PyTable {
         let column = column.rechunk(self.chunk_lengths())?;
 
         let mut fields = self.schema.fields().to_vec();
-        fields.insert(i, field.into_inner());
+        fields.insert(i, field.into_field(column.field()));
         let new_schema = Arc::new(Schema::new_with_metadata(
             fields,
             self.schema.metadata().clone(),
@@ -184,7 +266,7 @@ impl PyTable {
     pub fn append_column(
         &self,
         py: Python,
-        field: PyField,
+        field: NameOrField,
         column: PyChunkedArray,
     ) -> PyArrowResult<PyObject> {
         if self.num_rows() != column.len() {
@@ -196,7 +278,7 @@ impl PyTable {
         let column = column.rechunk(self.chunk_lengths())?;
 
         let mut fields = self.schema.fields().to_vec();
-        fields.push(field.into_inner());
+        fields.push(field.into_field(column.field()));
         let new_schema = Arc::new(Schema::new_with_metadata(
             fields,
             self.schema.metadata().clone(),
@@ -227,14 +309,15 @@ impl PyTable {
         self.batches.iter().map(|batch| batch.num_rows()).collect()
     }
 
-    pub fn column(&self, py: Python, i: usize) -> PyResult<PyObject> {
-        let field = self.schema.field(i).clone();
+    pub fn column(&self, py: Python, i: FieldIndexInput) -> PyArrowResult<PyObject> {
+        let column_index = i.into_position(&self.schema)?;
+        let field = self.schema.field(column_index).clone();
         let chunks = self
             .batches
             .iter()
-            .map(|batch| batch.column(i).clone())
+            .map(|batch| batch.column(column_index).clone())
             .collect();
-        PyChunkedArray::new(chunks, field.into()).to_arro3(py)
+        Ok(PyChunkedArray::new(chunks, field.into()).to_arro3(py)?)
     }
 
     #[getter]
@@ -247,9 +330,9 @@ impl PyTable {
     }
 
     #[getter]
-    pub fn columns(&self, py: Python) -> PyResult<Vec<PyObject>> {
+    pub fn columns(&self, py: Python) -> PyArrowResult<Vec<PyObject>> {
         (0..self.num_columns())
-            .map(|i| self.column(py, i))
+            .map(|i| self.column(py, FieldIndexInput::Position(i)))
             .collect()
     }
 
@@ -259,11 +342,7 @@ impl PyTable {
     }
 
     pub fn field(&self, py: Python, i: FieldIndexInput) -> PyArrowResult<PyObject> {
-        let schema = &self.schema;
-        let field = match i {
-            FieldIndexInput::String(name) => schema.field_with_name(&name)?,
-            FieldIndexInput::Int(i) => schema.field(i),
-        };
+        let field = self.schema.field(i.into_position(&self.schema)?);
         Ok(PyField::new(field.clone().into()).to_arro3(py)?)
     }
 
@@ -281,11 +360,68 @@ impl PyTable {
 
     // pub fn rechunk(&self, py: Python, max_chunksize: usize) {}
 
+    pub fn remove_column(&self, py: Python, i: usize) -> PyArrowResult<PyObject> {
+        let mut fields = self.schema.fields().to_vec();
+        fields.remove(i);
+        let new_schema = Arc::new(Schema::new_with_metadata(
+            fields,
+            self.schema.metadata().clone(),
+        ));
+
+        let new_batches = self
+            .batches
+            .iter()
+            .map(|batch| {
+                let mut columns = batch.columns().to_vec();
+                columns.remove(i);
+                Ok(RecordBatch::try_new(new_schema.clone(), columns)?)
+            })
+            .collect::<Result<Vec<_>, PyArrowError>>()?;
+
+        Ok(PyTable::new(new_batches, new_schema).to_arro3(py)?)
+    }
+
+    pub fn rename_columns(&self, py: Python, names: Vec<String>) -> PyArrowResult<PyObject> {
+        if names.len() != self.num_columns() {
+            return Err(PyValueError::new_err(
+                "must pass the same number of names as there are columns.",
+            )
+            .into());
+        }
+
+        let new_fields = self
+            .schema
+            .fields()
+            .iter()
+            .zip(names)
+            .map(|(field, name)| field.as_ref().clone().with_name(name))
+            .collect::<Vec<_>>();
+        let new_schema = Arc::new(Schema::new_with_metadata(
+            new_fields,
+            self.schema.metadata().clone(),
+        ));
+        Ok(PyTable::new(self.batches.clone(), new_schema).to_arro3(py)?)
+    }
+
+    #[getter]
+    pub fn schema(&self, py: Python) -> PyResult<PyObject> {
+        PySchema::new(self.schema.clone()).to_arro3(py)
+    }
+
+    pub fn select(&self, py: Python, columns: SelectIndices) -> PyArrowResult<PyObject> {
+        let positions = columns.into_positions(self.schema.fields())?;
+
+        let new_schema = Arc::new(self.schema.project(&positions)?);
+        let new_batches = self
+            .batches
+            .iter()
+            .map(|batch| batch.project(&positions))
+            .collect::<Result<Vec<_>, ArrowError>>()?;
+        Ok(PyTable::new(new_batches, new_schema).to_arro3(py)?)
+    }
+
     pub fn set_column(
         &self,
         py: Python,
         i: usize,
-        field: PyField,
+        field: NameOrField,
         column: PyChunkedArray,
     ) -> PyArrowResult<PyObject> {
         if self.num_rows() != column.len() {
@@ -297,7 +433,7 @@ impl PyTable {
         let column = column.rechunk(self.chunk_lengths())?;
 
         let mut fields = self.schema.fields().to_vec();
-        fields[i] = field.into_inner();
+        fields[i] = field.into_field(column.field());
         let new_schema = Arc::new(Schema::new_with_metadata(
             fields,
             self.schema.metadata().clone(),
@@ -323,11 +459,6 @@ impl PyTable {
         Ok(PyTable::new(new_batches, new_schema).to_arro3(py)?)
     }
 
-    #[getter]
-    pub fn schema(&self, py: Python) -> PyResult<PyObject> {
-        PySchema::new(self.schema.clone()).to_arro3(py)
-    }
-
     #[getter]
     pub fn shape(&self) -> (usize, usize) {
         (self.num_rows(), self.num_columns())
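`Table.from_arrays` (and `from_pydict`, which delegates to it) zips per-column chunks into record batches, so every column must share the same chunk layout; a mismatch raises `ValueError("All columns must have the same chunk lengths")`. An end-to-end sketch of the new Table surface, again assuming pyarrow inputs:

```python
import pyarrow as pa  # assumed input source
from arro3.core import Table

# Both columns carry chunks of lengths [2, 1], so each chunk index
# becomes one RecordBatch in the resulting table.
a = pa.chunked_array([[1, 2], [3]])
b = pa.chunked_array([["x", "y"], ["z"]])
table = Table.from_arrays([a, b], names=["a", "b"])
assert table.chunk_lengths == [2, 1]

# The other additions compose naturally:
subset = table.rename_columns(["c", "d"]).select(["d"])
assert subset.shape == (3, 1)
```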