
Commit 9e54d5b

Update Table methods (#74)
* Update methods

* Allow more str | field

* Implement Table.from_arrays
kylebarron authored Jul 30, 2024
1 parent 0148611 commit 9e54d5b
Showing 6 changed files with 395 additions and 75 deletions.
163 changes: 136 additions & 27 deletions arro3-core/python/arro3/core/_core.pyi
@@ -591,44 +591,124 @@ class RecordBatch:
@classmethod
def from_arrays(
cls, arrays: Sequence[ArrowArrayExportable], *, schema: ArrowSchemaExportable
) -> RecordBatch: ...
) -> RecordBatch:
"""Construct a RecordBatch from multiple pyarrow.Arrays
Args:
arrays: One for each field in RecordBatch
schema: Schema for the created batch. If not passed, names must be passed
Returns:
_description_
"""
@classmethod
def from_pydict(
cls,
mapping: dict[str, ArrowArrayExportable],
*,
metadata: ArrowSchemaExportable | None = None,
) -> RecordBatch: ...
) -> RecordBatch:
"""Construct a Table or RecordBatch from Arrow arrays or columns.
Args:
mapping: A mapping of strings to Arrays.
metadata: Optional metadata for the schema (if inferred). Defaults to None.
Returns:
_description_
"""
@classmethod
def from_struct_array(cls, struct_array: ArrowArrayExportable) -> RecordBatch: ...
def from_struct_array(cls, struct_array: ArrowArrayExportable) -> RecordBatch:
"""Construct a RecordBatch from a StructArray.
Each field in the StructArray will become a column in the resulting RecordBatch.
Args:
struct_array: Array to construct the record batch from.
Returns:
New RecordBatch
"""
@classmethod
def from_arrow(cls, input: ArrowArrayExportable) -> RecordBatch: ...
@classmethod
def from_arrow_pycapsule(cls, schema_capsule, array_capsule) -> RecordBatch:
"""Construct this object from bare Arrow PyCapsules"""
def add_column(
self, i: int, field: ArrowSchemaExportable, column: ArrowArrayExportable
self, i: int, field: str | ArrowSchemaExportable, column: ArrowArrayExportable
) -> RecordBatch: ...
def append_column(
self, field: ArrowSchemaExportable, column: ArrowArrayExportable
) -> RecordBatch: ...
def column(self, i: int) -> ChunkedArray: ...
self, field: str | ArrowSchemaExportable, column: ArrowArrayExportable
) -> RecordBatch:
"""Append column at end of columns.
Args:
field: If a string is passed then the type is deduced from the column data.
column: Column data
Returns:
_description_
"""

def column(self, i: int | str) -> ChunkedArray:
"""Select single column from Table or RecordBatch.
Args:
i: The index or name of the column to retrieve.
Returns:
_description_
"""
@property
def column_names(self) -> list[str]: ...
def column_names(self) -> list[str]:
"""Names of the RecordBatch columns."""
@property
def columns(self) -> list[Array]: ...
def equals(self, other: ArrowArrayExportable) -> bool: ...
def field(self, i: int) -> Field: ...
def columns(self) -> list[Array]:
"""List of all columns in numerical order."""
def equals(self, other: ArrowArrayExportable) -> bool:
"""Check if contents of two record batches are equal.
Args:
other: RecordBatch to compare against.
Returns:
True if the contents are equal, False otherwise.
"""

def field(self, i: int | str) -> Field:
"""Select a schema field by its column name or numeric index.
Args:
i: The index or name of the field to retrieve.
Returns:
The selected Field.
"""
@property
def num_columns(self) -> int: ...
def num_columns(self) -> int:
"""Number of columns."""
@property
def num_rows(self) -> int: ...
def remove_column(self, i: int) -> RecordBatch: ...
def num_rows(self) -> int:
"""Number of rows
Due to the definition of a RecordBatch, all columns have the same number of
rows.
"""
def remove_column(self, i: int) -> RecordBatch:
"""Create new RecordBatch with the indicated column removed.
Args:
i: Index of column to remove.
Returns:
New record batch without the column.
"""
@property
def schema(self) -> Schema: ...
def select(self, columns: list[int]) -> RecordBatch: ...
def schema(self) -> Schema:
"""Access the schema of this RecordBatch"""
def select(self, columns: list[int] | list[str]) -> RecordBatch: ...
def set_column(
self, i: int, field: ArrowSchemaExportable, column: ArrowArrayExportable
self, i: int, field: str | ArrowSchemaExportable, column: ArrowArrayExportable
) -> RecordBatch: ...
@property
def shape(self) -> tuple[int, int]: ...
@@ -870,7 +950,7 @@ class Table:
_description_
"""
def add_column(
self, i: int, field: ArrowSchemaExportable, column: ArrowStreamExportable
self, i: int, field: str | ArrowSchemaExportable, column: ArrowStreamExportable
) -> RecordBatch:
"""Add column to Table at position.
@@ -885,7 +965,7 @@ class Table:
New table with the passed column added.
"""
def append_column(
self, field: ArrowSchemaExportable, column: ArrowStreamExportable
self, field: str | ArrowSchemaExportable, column: ArrowStreamExportable
) -> RecordBatch:
"""Append column at end of columns.
@@ -899,7 +979,7 @@ class Table:
@property
def chunk_lengths(self) -> list[int]:
"""The number of rows in each internal chunk."""
def column(self, i: int) -> ChunkedArray:
def column(self, i: int | str) -> ChunkedArray:
"""Select single column from Table or RecordBatch.
Args:
@@ -949,15 +1029,20 @@ class Table:
Due to the definition of a table, all columns have the same number of rows.
"""
def set_column(
self, i: int, field: ArrowSchemaExportable, column: ArrowStreamExportable
) -> Table:
"""Replace column in Table at position.
def remove_column(self, i: int) -> Table:
"""Create new Table with the indicated column removed.
Args:
i: Index to place the column at.
field: _description_
column: Column data.
i: Index of column to remove.
Returns:
New table without the column.
"""
def rename_columns(self, names: Sequence[str]) -> Table:
"""Create new table with columns renamed to provided names.
Args:
names: List of new column names.
Returns:
New table with renamed columns.
@@ -966,6 +1051,30 @@
def schema(self) -> Schema:
"""Schema of the table and its columns.
Returns:
The Schema describing the table and its columns.
"""
def select(self, columns: Sequence[int] | Sequence[str]) -> Table:
"""Select columns of the Table.
Returns a new Table with the specified columns, and metadata preserved.
Args:
columns: The column names or integer indices to select.
Returns:
New table with only the selected columns, metadata preserved.
"""
def set_column(
self, i: int, field: str | ArrowSchemaExportable, column: ArrowStreamExportable
) -> Table:
"""Replace column in Table at position.
Args:
i: Index to place the column at.
field: Field or name of the new column. If a string is passed, the type is deduced from the column data.
column: Column data.
Returns:
New table with the column replaced.
"""
6 changes: 5 additions & 1 deletion pyo3-arrow/src/chunked.rs
@@ -3,7 +3,7 @@ use std::sync::Arc;

use arrow::compute::concat;
use arrow_array::{make_array, Array, ArrayRef};
use arrow_schema::{ArrowError, Field, FieldRef};
use arrow_schema::{ArrowError, DataType, Field, FieldRef};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::intern;
use pyo3::prelude::*;
@@ -38,6 +38,10 @@ impl PyChunkedArray
Self { chunks, field }
}

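/// Return the DataType of this chunked array (taken from its Field).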
pub fn data_type(&self) -> &DataType {
self.field.data_type()
}

pub fn from_arrays<A: Array>(chunks: &[A]) -> PyArrowResult<Self> {
let arrays = chunks
.iter()
83 changes: 79 additions & 4 deletions pyo3-arrow/src/input.rs
@@ -5,14 +5,17 @@
use std::collections::HashMap;
use std::string::FromUtf8Error;
use std::sync::Arc;

use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{FieldRef, SchemaRef};
use arrow_schema::{ArrowError, Field, FieldRef, Fields, Schema, SchemaRef};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

use crate::array_reader::PyArrayReader;
use crate::error::PyArrowResult;
use crate::ffi::{ArrayIterator, ArrayReader};
use crate::{PyArray, PyRecordBatch, PyRecordBatchReader};
use crate::{PyArray, PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader};

/// An enum over [PyRecordBatch] and [PyRecordBatchReader], used when a function accepts either
/// Arrow object as input.
@@ -49,6 +52,22 @@ pub enum AnyArray
}

impl AnyArray {
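    /// Consume self: an Array becomes a single-chunk PyChunkedArray; a Stream is drained into its chunks.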
pub fn into_chunked_array(self) -> PyArrowResult<PyChunkedArray> {
match self {
Self::Array(array) => {
let (array, field) = array.into_inner();
Ok(PyChunkedArray::new(vec![array], field))
}
Self::Stream(stream) => {
let field = stream.field_ref()?;
let chunks = stream
.into_reader()?
.collect::<Result<Vec<_>, ArrowError>>()?;
Ok(PyChunkedArray::new(chunks, field))
}
}
}

pub fn into_reader(self) -> PyResult<Box<dyn ArrayReader + Send>> {
match self {
Self::Array(array) => {
@@ -97,6 +116,62 @@ impl Default for MetadataInput

#[derive(FromPyObject)]
pub enum FieldIndexInput {
String(String),
Int(usize),
Name(String),
Position(usize),
}

impl FieldIndexInput {
pub fn into_position(self, schema: &Schema) -> PyArrowResult<usize> {
match self {
Self::Name(name) => Ok(schema.index_of(name.as_ref())?),
Self::Position(position) => Ok(position),
}
}
}
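
On the Python side, this enum is what lets index arguments take either form, e.g. (hypothetical usage, given some table):

    table.field("a")   # resolved via FieldIndexInput::Name
    table.field(0)     # resolved via FieldIndexInput::Position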

#[derive(FromPyObject)]
pub enum NameOrField {
Name(String),
Field(PyField),
}

impl NameOrField {
pub fn into_field(self, source_field: &Field) -> FieldRef {
match self {
Self::Name(name) => Arc::new(
Field::new(
name,
source_field.data_type().clone(),
source_field.is_nullable(),
)
.with_metadata(source_field.metadata().clone()),
),
Self::Field(field) => field.into_inner(),
}
}
}

#[derive(FromPyObject)]
pub enum SelectIndices {
Names(Vec<String>),
Positions(Vec<usize>),
}

impl SelectIndices {
pub fn into_positions(self, fields: &Fields) -> PyResult<Vec<usize>> {
match self {
Self::Names(names) => {
let mut positions = Vec::with_capacity(names.len());
for name in names {
let index = fields
.iter()
.position(|field| field.name() == &name)
.ok_or(PyValueError::new_err(format!("{name} not in schema.")))?;
positions.push(index);
}
Ok(positions)
}
Self::Positions(positions) => Ok(positions),
}
}
}
