Use indexmap to preserve insertion order of python arrays
kylebarron committed Jul 30, 2024
1 parent 9e54d5b commit 1a18717
Showing 7 changed files with 104 additions and 7 deletions.
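The user-visible effect of this change: columns built from a Python dict should come out in the dict's insertion order rather than in hash order. A minimal sketch of that behavior, assuming the arro3.core.Table API declared in the stubs below and using pyarrow to round-trip the result, as the new tests in this commit do:

import pyarrow as pa
from arro3.core import Table

# Keys deliberately not in sorted order; with IndexMap on the Rust side,
# the resulting columns should keep insertion order.
mapping = {"z": pa.array([1, 2]), "a": pa.array([3, 4])}
table = Table.from_pydict(mapping)
assert pa.table(table).column_names == ["z", "a"]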
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -11,6 +11,7 @@ arrow-csv = "52"
arrow-ipc = "52"
arrow-schema = "52"
arrow-select = "52"
+indexmap = "2"
numpy = "0.21"
parquet = "52"
pyo3 = "0.21"
78 changes: 77 additions & 1 deletion arro3-core/python/arro3/core/_core.pyi
@@ -1,4 +1,4 @@
-from typing import Any, Literal, Sequence
+from typing import Any, Literal, Sequence, overload
import numpy as np
from numpy.typing import NDArray

@@ -922,6 +922,46 @@ class Table:
def __eq__(self, other) -> bool: ...
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
@overload
@classmethod
def from_arrays(
cls,
arrays: Sequence[ArrowArrayExportable | ArrowStreamExportable],
*,
names: Sequence[str],
schema: None = None,
metadata: dict[str, str] | dict[bytes, bytes] | None = None,
) -> Table: ...
@overload
@classmethod
def from_arrays(
cls,
arrays: Sequence[ArrowArrayExportable | ArrowStreamExportable],
*,
names: None = None,
schema: ArrowSchemaExportable,
metadata: None = None,
) -> Table: ...
@classmethod
def from_arrays(
cls,
arrays: Sequence[ArrowArrayExportable | ArrowStreamExportable],
*,
names: Sequence[str] | None = None,
schema: ArrowSchemaExportable | None = None,
metadata: dict[str, str] | dict[bytes, bytes] | None = None,
) -> Table:
"""Construct a Table from Arrow arrays.
Args:
arrays: Equal-length arrays that should form the table.
names: Names for the table columns. If not passed, `schema` must be passed. Defaults to None.
schema: Schema for the created table. If not passed, `names` must be passed. Defaults to None.
metadata: Optional metadata for the schema (if inferred). Defaults to None.
Returns:
new table
"""
@classmethod
def from_arrow(cls, input: ArrowStreamExportable) -> Table:
"""
@@ -949,6 +989,42 @@ class Table:
Returns:
_description_
"""
@overload
@classmethod
def from_pydict(
cls,
mapping: dict[str, ArrowArrayExportable | ArrowStreamExportable],
*,
schema: None = None,
metadata: dict[str, str] | dict[bytes, bytes] | None = None,
) -> Table: ...
@overload
@classmethod
def from_pydict(
cls,
mapping: dict[str, ArrowArrayExportable | ArrowStreamExportable],
*,
schema: ArrowSchemaExportable,
metadata: None = None,
) -> Table: ...
@classmethod
def from_pydict(
cls,
mapping: dict[str, ArrowArrayExportable | ArrowStreamExportable],
*,
schema: ArrowSchemaExportable | None = None,
metadata: dict[str, str] | dict[bytes, bytes] | None = None,
) -> Table:
"""Construct a Table or RecordBatch from Arrow arrays or columns.
Args:
mapping: A mapping of strings to Arrays.
schema: If not passed, will be inferred from the Mapping values. Defaults to None.
metadata: Optional metadata for the schema (if inferred). Defaults to None.
Returns:
new table
"""
def add_column(
self, i: int, field: str | ArrowSchemaExportable, column: ArrowStreamExportable
) -> RecordBatch:
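The overloads above encode two mutually exclusive call styles: pass names and let the schema be inferred, or pass schema (and no names) when the schema is known up front. A hedged usage sketch, assuming pyarrow arrays and schemas satisfy the ArrowArrayExportable and ArrowSchemaExportable protocols through the Arrow PyCapsule interface:

import pyarrow as pa
from arro3.core import Table

a = pa.array([1, 2, 3])
b = pa.array(["x", "y", "z"])

# First overload: column names given, schema inferred from the array types.
t1 = Table.from_arrays([a, b], names=["ints", "strs"])

# Second overload: an explicit schema instead of names.
schema = pa.schema([("ints", pa.int64()), ("strs", pa.string())])
t2 = Table.from_arrays([a, b], schema=schema)

assert pa.table(t1) == pa.table(t2)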
3 changes: 2 additions & 1 deletion pyo3-arrow/Cargo.toml
@@ -16,7 +16,8 @@ arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
arrow-schema = { workspace = true }
arrow = { workspace = true, features = ["ffi"] }
-pyo3 = { workspace = true, features = ["abi3-py38"] }
+pyo3 = { workspace = true, features = ["abi3-py38", "indexmap"] }
+indexmap = { workspace = true }
numpy = { workspace = true, features = ["half"] }
thiserror = { workspace = true }

6 changes: 3 additions & 3 deletions pyo3-arrow/src/record_batch.rs
@@ -1,10 +1,10 @@
-use std::collections::HashMap;
use std::fmt::Display;
use std::sync::Arc;

use arrow::array::AsArray;
use arrow_array::{Array, ArrayRef, RecordBatch, StructArray};
use arrow_schema::{DataType, Field, Schema, SchemaBuilder};
+use indexmap::IndexMap;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::intern;
use pyo3::prelude::*;
@@ -103,7 +103,7 @@ impl PyRecordBatch {
&py.get_type_bound::<PyRecordBatch>(),
data,
)?)
-} else if let Ok(mapping) = data.extract::<HashMap<String, PyArray>>() {
+} else if let Ok(mapping) = data.extract::<IndexMap<String, PyArray>>() {
Self::from_pydict(&py.get_type_bound::<PyRecordBatch>(), mapping, metadata)
} else {
Err(PyTypeError::new_err("unsupported input").into())
@@ -173,7 +173,7 @@ impl PyRecordBatch {
#[pyo3(signature = (mapping, *, metadata=None))]
pub fn from_pydict(
_cls: &Bound<PyType>,
-mapping: HashMap<String, PyArray>,
+mapping: IndexMap<String, PyArray>,
metadata: Option<MetadataInput>,
) -> PyArrowResult<Self> {
let mut fields = vec![];
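The same HashMap-to-IndexMap swap applies to the RecordBatch dict path, so field order should also follow insertion order there. A sketch under two assumptions: that arro3.core exposes RecordBatch.from_pydict with the pyo3 signature shown above, and that a recent pyarrow can consume the result via the Arrow C data interface:

import pyarrow as pa
from arro3.core import RecordBatch

mapping = {"y": pa.array([1, 2]), "x": pa.array([3, 4])}
batch = RecordBatch.from_pydict(mapping)
# Field order should match the dict's insertion order.
assert pa.record_batch(batch).schema.names == ["y", "x"]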
4 changes: 2 additions & 2 deletions pyo3-arrow/src/table.rs
@@ -1,4 +1,3 @@
-use std::collections::HashMap;
use std::fmt::Display;
use std::sync::Arc;

@@ -7,6 +6,7 @@ use arrow::ffi_stream::ArrowArrayStreamReader as ArrowRecordBatchStreamReader;
use arrow_array::{ArrayRef, RecordBatchReader, StructArray};
use arrow_array::{RecordBatch, RecordBatchIterator};
use arrow_schema::{ArrowError, Field, Schema, SchemaRef};
+use indexmap::IndexMap;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::intern;
use pyo3::prelude::*;
@@ -144,7 +144,7 @@ impl PyTable {
#[pyo3(signature = (mapping, *, schema=None, metadata=None))]
pub fn from_pydict(
cls: &Bound<PyType>,
-mapping: HashMap<String, AnyArray>,
+mapping: IndexMap<String, AnyArray>,
schema: Option<PySchema>,
metadata: Option<MetadataInput>,
) -> PyArrowResult<Self> {
17 changes: 17 additions & 0 deletions tests/core/test_table.py
@@ -0,0 +1,17 @@
import pyarrow as pa
from arro3.core import Table


def test_table_from_arrays():
a = pa.array([1, 2, 3, 4])
b = pa.array(["a", "b", "c", "d"])
arro3_table = Table.from_arrays([a, b], names=["a", "b"])
pa_table = pa.Table.from_arrays([a, b], names=["a", "b"])
assert pa.table(arro3_table) == pa_table


def test_table_from_pydict():
mapping = {"a": pa.array([1, 2, 3, 4]), "b": pa.array(["a", "b", "c", "d"])}
arro3_table = Table.from_pydict(mapping)
pa_table = pa.Table.from_pydict(mapping)
assert pa.table(arro3_table) == pa_table
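The new tests compare against pyarrow for equality but do not pin column order directly; a possible follow-up assertion (not part of this commit) in the same style:

def test_table_from_pydict_preserves_order():
    # Keys intentionally out of sorted order; the resulting column order
    # should match the dict's insertion order.
    mapping = {"b": pa.array([1, 2, 3, 4]), "a": pa.array(["a", "b", "c", "d"])}
    arro3_table = Table.from_pydict(mapping)
    assert pa.table(arro3_table).column_names == ["b", "a"]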
