From ebadc3dfd70d70c22ffcb7906a2fc290e1e0be70 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Wed, 31 Jul 2024 00:16:49 -0400 Subject: [PATCH] Implement `__getitem__` for Table, RecordBatch, Schema (#84) --- arro3-core/python/arro3/core/_core.pyi | 3 +++ pyo3-arrow/src/record_batch.rs | 16 ++++------------ pyo3-arrow/src/schema.rs | 4 ++++ pyo3-arrow/src/table.rs | 4 ++++ tests/core/test_table.py | 10 ++++++++++ 5 files changed, 25 insertions(+), 12 deletions(-) diff --git a/arro3-core/python/arro3/core/_core.pyi b/arro3-core/python/arro3/core/_core.pyi index a10922e..37a6337 100644 --- a/arro3-core/python/arro3/core/_core.pyi +++ b/arro3-core/python/arro3/core/_core.pyi @@ -591,6 +591,7 @@ class RecordBatch: self, requested_schema: object | None = None ) -> tuple[object, object]: ... def __eq__(self, other) -> bool: ... + def __getitem__(self, key: int | str) -> Array: ... def __repr__(self) -> str: ... @classmethod def from_arrays( @@ -767,6 +768,7 @@ class Schema: """ def __eq__(self, other) -> bool: ... + def __getitem__(self, key: int | str) -> Field: ... def __len__(self) -> int: ... def __repr__(self) -> str: ... @classmethod @@ -928,6 +930,7 @@ class Table: _description_ """ def __eq__(self, other) -> bool: ... + def __getitem__(self, key: int | str) -> ChunkedArray: ... def __len__(self) -> int: ... def __repr__(self) -> str: ... 
@overload diff --git a/pyo3-arrow/src/record_batch.rs b/pyo3-arrow/src/record_batch.rs index 3e1e36c..63ee228 100644 --- a/pyo3-arrow/src/record_batch.rs +++ b/pyo3-arrow/src/record_batch.rs @@ -127,24 +127,16 @@ impl PyRecordBatch { let field = Field::new_struct("", self.0.schema_ref().fields().clone(), false); let array: ArrayRef = Arc::new(StructArray::from(self.0.clone())); to_array_pycapsules(py, field.into(), &array, requested_schema) - - // let schema = self.0.schema(); - // let array = StructArray::from(self.0.clone()); - - // let ffi_schema = FFI_ArrowSchema::try_from(schema.as_ref())?; - // let ffi_array = FFI_ArrowArray::new(&array.to_data()); - - // let schema_capsule_name = CString::new("arrow_schema").unwrap(); - // let array_capsule_name = CString::new("arrow_array").unwrap(); - // let schema_capsule = PyCapsule::new_bound(py, ffi_schema, Some(schema_capsule_name))?; - // let array_capsule = PyCapsule::new_bound(py, ffi_array, Some(array_capsule_name))?; - // Ok(PyTuple::new_bound(py, vec![schema_capsule, array_capsule])) } pub fn __eq__(&self, other: &PyRecordBatch) -> bool { self.0 == other.0 } + fn __getitem__(&self, py: Python, key: FieldIndexInput) -> PyResult<PyObject> { + self.column(py, key) + } + pub fn __repr__(&self) -> String { self.to_string() } diff --git a/pyo3-arrow/src/schema.rs b/pyo3-arrow/src/schema.rs index 80a571d..2db845d 100644 --- a/pyo3-arrow/src/schema.rs +++ b/pyo3-arrow/src/schema.rs @@ -127,6 +127,10 @@ impl PySchema { self.0 == other.0 } + fn __getitem__(&self, py: Python, key: FieldIndexInput) -> PyArrowResult<PyObject> { + self.field(py, key) + } + pub fn __len__(&self) -> usize { self.0.fields().len() } diff --git a/pyo3-arrow/src/table.rs b/pyo3-arrow/src/table.rs index f885b15..d072083 100644 --- a/pyo3-arrow/src/table.rs +++ b/pyo3-arrow/src/table.rs @@ -110,6 +110,10 @@ impl PyTable { self.batches == other.batches && self.schema == other.schema } + fn __getitem__(&self, py: Python, key: FieldIndexInput) -> PyArrowResult<PyObject> { 
self.column(py, key) + } + pub fn __len__(&self) -> usize { self.batches.iter().fold(0, |acc, x| acc + x.num_rows()) } diff --git a/tests/core/test_table.py b/tests/core/test_table.py index 95e0d12..21d17df 100644 --- a/tests/core/test_table.py +++ b/tests/core/test_table.py @@ -2,6 +2,16 @@ from arro3.core import Table +def test_table_getitem(): + a = pa.chunked_array([[1, 2, 3, 4]]) + b = pa.chunked_array([["a", "b", "c", "d"]]) + table = Table.from_pydict({"a": a, "b": b}) + assert a == pa.chunked_array(table["a"]) + assert b == pa.chunked_array(table["b"]) + assert a == pa.chunked_array(table[0]) + assert b == pa.chunked_array(table[1]) + + def test_table_from_arrays(): a = pa.array([1, 2, 3, 4]) b = pa.array(["a", "b", "c", "d"])