From 140234723446af8257ca9ea08e5d7c74195b2f64 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Wed, 21 Sep 2022 07:29:31 -0400
Subject: [PATCH] feat(api): implement `__array__`

Accept the optional `dtype` argument of the NumPy `__array__` protocol and
forward it to the executed result, so that `np.array(expr, dtype=...)`
works in addition to `np.array(expr)`.
---
NOTE(review): this patch was restored from a whitespace-collapsed copy and
then hand-edited. Context-line indentation is reconstructed (verify against
the tree, or apply with `git apply --ignore-whitespace`), and the blob
hashes on the `index` lines still describe the original post-image.

 ibis/backends/tests/test_client.py | 21 +++++++++++++++++++++
 ibis/expr/types/generic.py         |  8 ++++++++
 ibis/expr/types/relations.py       |  8 ++++++++
 3 files changed, 37 insertions(+)

diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py
index 4be11da2f541..42bc660483ce 100644
--- a/ibis/backends/tests/test_client.py
+++ b/ibis/backends/tests/test_client.py
@@ -1,6 +1,7 @@
 import platform
 import re
 
+import numpy as np
 import pandas as pd
 import pandas.testing as tm
 import pytest
@@ -691,3 +692,23 @@ def test_default_backend():
   SUM\\((\\w+)\\.a\\) AS sum
 FROM \\w+ AS \\1"""
     assert re.match(rx, sql) is not None
+
+
+def test_dunder_array_table(alltypes, df):
+    expr = alltypes.group_by("string_col").int_col.sum().sort_by("string_col")
+    result = np.array(expr)
+    expected = np.array(
+        df.groupby("string_col")
+        .int_col.sum()
+        .reset_index()
+        .sort_values(["string_col"])
+    )
+    np.testing.assert_array_equal(result, expected)
+
+
+@pytest.mark.broken(["dask"], reason="Dask backend duplicates data")
+def test_dunder_array_column(alltypes, df):
+    expr = alltypes.sort_by("id").head(10).int_col
+    result = np.array(expr)
+    expected = df.sort_values(["id"]).head(10).int_col
+    np.testing.assert_array_equal(result, expected)
diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py
index dc5bc967b077..455a2074a17b 100644
--- a/ibis/expr/types/generic.py
+++ b/ibis/expr/types/generic.py
@@ -513,6 +513,14 @@ def _repr_html_(self) -> str | None:
 
 @public
 class Column(Value):
+    def __array__(self, dtype=None):
+        """Execute this column and return the result as a NumPy array.
+
+        `dtype` is what NumPy passes for `np.array(expr, dtype=...)`; it is
+        forwarded to the executed result's own `__array__`.
+        """
+        return self.execute().__array__(dtype)
+
     def _repr_html_(self) -> str | None:
         if not ibis.options.interactive:
             return None
diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py
index ccf64a383c1a..be3683ed697a 100644
--- a/ibis/expr/types/relations.py
+++ b/ibis/expr/types/relations.py
@@ -83,6 +83,14 @@ def f(
 
 @public
 class Table(Expr):
+    def __array__(self, dtype=None):
+        """Execute this table and return the result as a NumPy array.
+
+        `dtype` is what NumPy passes for `np.array(expr, dtype=...)`; it is
+        forwarded to the executed result's own `__array__`.
+        """
+        return self.execute().__array__(dtype)
+
     def __contains__(self, name):
         return name in self.schema()
 