From 1a187179cd02ea2b21d95f07d376756b65438859 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Tue, 30 Jul 2024 01:39:26 -0400 Subject: [PATCH] Use indexmap to preserve insertion order of python arrays --- Cargo.lock | 2 + Cargo.toml | 1 + arro3-core/python/arro3/core/_core.pyi | 78 +++++++++++++++++++++++++- pyo3-arrow/Cargo.toml | 3 +- pyo3-arrow/src/record_batch.rs | 6 +- pyo3-arrow/src/table.rs | 4 +- tests/core/test_table.py | 17 ++++++ 7 files changed, 104 insertions(+), 7 deletions(-) create mode 100644 tests/core/test_table.py diff --git a/Cargo.lock b/Cargo.lock index 5fdeeee..a05cc90 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1047,6 +1047,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" dependencies = [ "cfg-if", + "indexmap", "indoc", "libc", "memoffset", @@ -1067,6 +1068,7 @@ dependencies = [ "arrow-buffer", "arrow-schema", "arrow-select", + "indexmap", "numpy", "pyo3", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index 8acc12b..2e9dce2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ arrow-csv = "52" arrow-ipc = "52" arrow-schema = "52" arrow-select = "52" +indexmap = "2" numpy = "0.21" parquet = "52" pyo3 = "0.21" diff --git a/arro3-core/python/arro3/core/_core.pyi b/arro3-core/python/arro3/core/_core.pyi index 8d6cc42..58d58c5 100644 --- a/arro3-core/python/arro3/core/_core.pyi +++ b/arro3-core/python/arro3/core/_core.pyi @@ -1,4 +1,4 @@ -from typing import Any, Literal, Sequence +from typing import Any, Literal, Sequence, overload import numpy as np from numpy.typing import NDArray @@ -922,6 +922,46 @@ class Table: def __eq__(self, other) -> bool: ... def __len__(self) -> int: ... def __repr__(self) -> str: ... 
+ @overload + @classmethod + def from_arrays( + cls, + arrays: Sequence[ArrowArrayExportable | ArrowStreamExportable], + *, + names: Sequence[str], + schema: None = None, + metadata: dict[str, str] | dict[bytes, bytes] | None = None, + ) -> Table: ... + @overload + @classmethod + def from_arrays( + cls, + arrays: Sequence[ArrowArrayExportable | ArrowStreamExportable], + *, + names: None = None, + schema: ArrowSchemaExportable, + metadata: None = None, + ) -> Table: ... + @classmethod + def from_arrays( + cls, + arrays: Sequence[ArrowArrayExportable | ArrowStreamExportable], + *, + names: Sequence[str] | None = None, + schema: ArrowSchemaExportable | None = None, + metadata: dict[str, str] | dict[bytes, bytes] | None = None, + ) -> Table: + """Construct a Table from Arrow arrays. + + Args: + arrays: Equal-length arrays that should form the table. + names: Names for the table columns. If not passed, `schema` must be passed. Defaults to None. + schema: Schema for the created table. If not passed, `names` must be passed. Defaults to None. + metadata: Optional metadata for the schema (if inferred). Defaults to None. + + Returns: + new table + """ @classmethod def from_arrow(cls, input: ArrowStreamExportable) -> Table: """ @@ -949,6 +989,42 @@ class Table: Returns: _description_ """ + @overload + @classmethod + def from_pydict( + cls, + mapping: dict[str, ArrowArrayExportable | ArrowStreamExportable], + *, + schema: None = None, + metadata: dict[str, str] | dict[bytes, bytes] | None = None, + ) -> Table: ... + @overload + @classmethod + def from_pydict( + cls, + mapping: dict[str, ArrowArrayExportable | ArrowStreamExportable], + *, + schema: ArrowSchemaExportable, + metadata: None = None, + ) -> Table: ... 
+ @classmethod + def from_pydict( + cls, + mapping: dict[str, ArrowArrayExportable | ArrowStreamExportable], + *, + schema: ArrowSchemaExportable | None = None, + metadata: dict[str, str] | dict[bytes, bytes] | None = None, + ) -> Table: + """Construct a Table or RecordBatch from Arrow arrays or columns. + + Args: + mapping: A mapping of strings to Arrays. + schema: If not passed, will be inferred from the Mapping values. Defaults to None. + metadata: Optional metadata for the schema (if inferred). Defaults to None. + + Returns: + new table + """ def add_column( self, i: int, field: str | ArrowSchemaExportable, column: ArrowStreamExportable ) -> RecordBatch: diff --git a/pyo3-arrow/Cargo.toml b/pyo3-arrow/Cargo.toml index ad41528..d46f83e 100644 --- a/pyo3-arrow/Cargo.toml +++ b/pyo3-arrow/Cargo.toml @@ -16,7 +16,8 @@ arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-schema = { workspace = true } arrow = { workspace = true, features = ["ffi"] } -pyo3 = { workspace = true, features = ["abi3-py38"] } +pyo3 = { workspace = true, features = ["abi3-py38", "indexmap"] } +indexmap = { workspace = true } numpy = { workspace = true, features = ["half"] } thiserror = { workspace = true } diff --git a/pyo3-arrow/src/record_batch.rs b/pyo3-arrow/src/record_batch.rs index d5c4be5..ce23a5d 100644 --- a/pyo3-arrow/src/record_batch.rs +++ b/pyo3-arrow/src/record_batch.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; use std::fmt::Display; use std::sync::Arc; use arrow::array::AsArray; use arrow_array::{Array, ArrayRef, RecordBatch, StructArray}; use arrow_schema::{DataType, Field, Schema, SchemaBuilder}; +use indexmap::IndexMap; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::intern; use pyo3::prelude::*; @@ -103,7 +103,7 @@ impl PyRecordBatch { &py.get_type_bound::<PyRecordBatch>(), data, )?) 
- } else if let Ok(mapping) = data.extract::<HashMap<String, PyArray>>() { + } else if let Ok(mapping) = data.extract::<IndexMap<String, PyArray>>() { Self::from_pydict(&py.get_type_bound::<PyRecordBatch>(), mapping, metadata) } else { Err(PyTypeError::new_err("unsupported input").into()) @@ -173,7 +173,7 @@ impl PyRecordBatch { #[pyo3(signature = (mapping, *, metadata=None))] pub fn from_pydict( _cls: &Bound<PyType>, - mapping: HashMap<String, PyArray>, + mapping: IndexMap<String, PyArray>, metadata: Option<MetadataInput>, ) -> PyArrowResult<Self> { let mut fields = vec![]; diff --git a/pyo3-arrow/src/table.rs b/pyo3-arrow/src/table.rs index dd9cc52..2c5a4c4 100644 --- a/pyo3-arrow/src/table.rs +++ b/pyo3-arrow/src/table.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::fmt::Display; use std::sync::Arc; @@ -7,6 +6,7 @@ use arrow::ffi_stream::ArrowArrayStreamReader as ArrowRecordBatchStreamReader; use arrow_array::{ArrayRef, RecordBatchReader, StructArray}; use arrow_array::{RecordBatch, RecordBatchIterator}; use arrow_schema::{ArrowError, Field, Schema, SchemaRef}; +use indexmap::IndexMap; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::intern; use pyo3::prelude::*; @@ -144,7 +144,7 @@ impl PyTable { #[pyo3(signature = (mapping, *, schema=None, metadata=None))] pub fn from_pydict( cls: &Bound<PyType>, - mapping: HashMap<String, AnyArray>, + mapping: IndexMap<String, AnyArray>, schema: Option<PySchema>, metadata: Option<MetadataInput>, ) -> PyArrowResult<PyObject> { diff --git a/tests/core/test_table.py b/tests/core/test_table.py new file mode 100644 index 0000000..95e0d12 --- /dev/null +++ b/tests/core/test_table.py @@ -0,0 +1,17 @@ +import pyarrow as pa +from arro3.core import Table + + +def test_table_from_arrays(): + a = pa.array([1, 2, 3, 4]) + b = pa.array(["a", "b", "c", "d"]) + arro3_table = Table.from_arrays([a, b], names=["a", "b"]) + pa_table = pa.Table.from_arrays([a, b], names=["a", "b"]) + assert pa.table(arro3_table) == pa_table + + +def test_table_from_pydict(): + mapping = {"a": pa.array([1, 2, 3, 4]), "b": pa.array(["a", "b", "c", "d"])} + arro3_table = Table.from_pydict(mapping) + pa_table = pa.Table.from_pydict(mapping) 
assert pa.table(arro3_table) == pa_table