Skip to content

Commit

Permalink
Implement __getitem__ for Table, RecordBatch, Schema
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron committed Jul 31, 2024
1 parent d5a529a commit c1e1c4d
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 12 deletions.
3 changes: 3 additions & 0 deletions arro3-core/python/arro3/core/_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,7 @@ class RecordBatch:
self, requested_schema: object | None = None
) -> tuple[object, object]: ...
def __eq__(self, other) -> bool: ...
def __getitem__(self, key: int | str) -> Array: ...
def __repr__(self) -> str: ...
@classmethod
def from_arrays(
Expand Down Expand Up @@ -767,6 +768,7 @@ class Schema:
"""

def __eq__(self, other) -> bool: ...
def __getitem__(self, key: int | str) -> Field: ...
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
@classmethod
Expand Down Expand Up @@ -928,6 +930,7 @@ class Table:
_description_
"""
def __eq__(self, other) -> bool: ...
def __getitem__(self, key: int | str) -> ChunkedArray: ...
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
@overload
Expand Down
16 changes: 4 additions & 12 deletions pyo3-arrow/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,24 +127,16 @@ impl PyRecordBatch {
let field = Field::new_struct("", self.0.schema_ref().fields().clone(), false);
let array: ArrayRef = Arc::new(StructArray::from(self.0.clone()));
to_array_pycapsules(py, field.into(), &array, requested_schema)

// let schema = self.0.schema();
// let array = StructArray::from(self.0.clone());

// let ffi_schema = FFI_ArrowSchema::try_from(schema.as_ref())?;
// let ffi_array = FFI_ArrowArray::new(&array.to_data());

// let schema_capsule_name = CString::new("arrow_schema").unwrap();
// let array_capsule_name = CString::new("arrow_array").unwrap();
// let schema_capsule = PyCapsule::new_bound(py, ffi_schema, Some(schema_capsule_name))?;
// let array_capsule = PyCapsule::new_bound(py, ffi_array, Some(array_capsule_name))?;
// Ok(PyTuple::new_bound(py, vec![schema_capsule, array_capsule]))
}

/// Python `==` support: returns `true` when the value wrapped by `self`
/// compares equal to the value wrapped by `other`.
pub fn __eq__(&self, other: &PyRecordBatch) -> bool {
// Equality is decided entirely by the inner value's own `==`.
self.0 == other.0
}

/// Python subscript support (`batch[key]`).
///
/// Forwards `key` unchanged to `self.column` and returns whatever that
/// returns, including any error. The Python type stub declares `key` as
/// `int | str` and the result as `Array`.
fn __getitem__(&self, py: Python, key: FieldIndexInput) -> PyResult<PyObject> {
self.column(py, key)
}

/// Python `repr()` support: returns the text produced by `self.to_string()`.
pub fn __repr__(&self) -> String {
self.to_string()
}
Expand Down
4 changes: 4 additions & 0 deletions pyo3-arrow/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ impl PySchema {
self.0 == other.0
}

/// Python subscript support (`schema[key]`).
///
/// Forwards `key` unchanged to `self.field` and returns whatever that
/// returns, including any error. The Python type stub declares `key` as
/// `int | str` and the result as `Field`.
fn __getitem__(&self, py: Python, key: FieldIndexInput) -> PyArrowResult<PyObject> {
self.field(py, key)
}

/// Python `len()` support: the number of fields in the wrapped schema.
pub fn __len__(&self) -> usize {
self.0.fields().len()
}
Expand Down
4 changes: 4 additions & 0 deletions pyo3-arrow/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ impl PyTable {
self.batches == other.batches && self.schema == other.schema
}

/// Python subscript support (`table[key]`).
///
/// Forwards `key` unchanged to `self.column` and returns whatever that
/// returns, including any error. The Python type stub declares `key` as
/// `int | str` and the result as `ChunkedArray`.
fn __getitem__(&self, py: Python, key: FieldIndexInput) -> PyArrowResult<PyObject> {
self.column(py, key)
}

/// Python `len()` support: the total number of rows, obtained by adding up
/// `num_rows()` over every batch in `self.batches`.
///
/// A table with no batches has length `0`.
pub fn __len__(&self) -> usize {
    self.batches.iter().map(|batch| batch.num_rows()).sum()
}
Expand Down
10 changes: 10 additions & 0 deletions tests/core/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
from arro3.core import Table


def test_table_getitem():
    """Columns of a ``Table`` can be looked up by name or by position."""
    a = pa.chunked_array([[1, 2, 3, 4]])
    b = pa.chunked_array([["a", "b", "c", "d"]])
    table = Table.from_pydict({"a": a, "b": b})

    # Same four checks in the same order: by name first, then by position.
    lookups = [("a", a), ("b", b), (0, a), (1, b)]
    for key, expected in lookups:
        assert expected == pa.chunked_array(table[key])


def test_table_from_arrays():
a = pa.array([1, 2, 3, 4])
b = pa.array(["a", "b", "c", "d"])
Expand Down

0 comments on commit c1e1c4d

Please sign in to comment.