From e4f2b3a0fdbbb2555d75d2a72a134bd18238a306 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Sun, 28 Jul 2024 00:24:03 -0400 Subject: [PATCH] Allow more str | field --- arro3-core/python/arro3/core/_rust.pyi | 6 +++--- pyo3-arrow/src/record_batch.rs | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/arro3-core/python/arro3/core/_rust.pyi b/arro3-core/python/arro3/core/_rust.pyi index 27b63ab..7eed0a5 100644 --- a/arro3-core/python/arro3/core/_rust.pyi +++ b/arro3-core/python/arro3/core/_rust.pyi @@ -213,10 +213,10 @@ class RecordBatch: def from_arrow_pycapsule(cls, schema_capsule, array_capsule) -> RecordBatch: """Construct this object from bare Arrow PyCapsules""" def add_column( - self, i: int, field: ArrowSchemaExportable, column: ArrowArrayExportable + self, i: int, field: str | ArrowSchemaExportable, column: ArrowArrayExportable ) -> RecordBatch: ... def append_column( - self, field: ArrowSchemaExportable, column: ArrowArrayExportable + self, field: str | ArrowSchemaExportable, column: ArrowArrayExportable ) -> RecordBatch: """Append column at end of columns. @@ -286,7 +286,7 @@ class RecordBatch: """Access the schema of this RecordBatch""" def select(self, columns: list[int] | list[str]) -> RecordBatch: ... def set_column( - self, i: int, field: ArrowSchemaExportable, column: ArrowArrayExportable + self, i: int, field: str | ArrowSchemaExportable, column: ArrowArrayExportable ) -> RecordBatch: ... @property def shape(self) -> tuple[int, int]: ... diff --git a/pyo3-arrow/src/record_batch.rs b/pyo3-arrow/src/record_batch.rs index 76f7c9d..b82593a 100644 --- a/pyo3-arrow/src/record_batch.rs +++ b/pyo3-arrow/src/record_batch.rs @@ -14,7 +14,7 @@ use crate::error::PyArrowResult; use crate::ffi::from_python::utils::import_array_pycapsules; use crate::ffi::to_python::nanoarrow::to_nanoarrow_array; use crate::ffi::to_python::to_array_pycapsules; -use crate::input::{FieldIndexInput, MetadataInput, SelectIndices}; +use crate::input::{FieldIndexInput, MetadataInput, NameOrField, SelectIndices}; use crate::schema::display_schema; use crate::{PyArray, PyField, PySchema}; @@ -258,11 +258,11 @@ impl PyRecordBatch { &self, py: Python, i: usize, - field: PyField, + field: NameOrField, column: PyArray, ) -> PyArrowResult { let mut fields = self.0.schema_ref().fields().to_vec(); - fields.insert(i, field.into_inner()); + fields.insert(i, field.into_field(column.field())); let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone()); let mut arrays = self.0.columns().to_vec(); @@ -275,11 +275,11 @@ impl PyRecordBatch { pub fn append_column( &self, py: Python, - field: PyField, + field: NameOrField, column: PyArray, ) -> PyArrowResult { let mut fields = self.0.schema_ref().fields().to_vec(); - fields.push(field.into_inner()); + fields.push(field.into_field(column.field())); let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone()); let mut arrays = self.0.columns().to_vec(); @@ -360,11 +360,11 @@ impl PyRecordBatch { &self, py: Python, i: usize, - field: PyField, + field: NameOrField, column: PyArray, ) -> PyArrowResult { let mut fields = self.0.schema_ref().fields().to_vec(); - fields[i] = field.into_inner(); + fields[i] = field.into_field(column.field()); let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone()); let mut arrays = self.0.columns().to_vec();