From 7d8fe5a4db0632f8e497eaec11a375ac5365535e Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Thu, 16 Nov 2023 08:27:13 -0500
Subject: [PATCH] fix(snowflake): fix array printing by using a pyarrow
 extension type

---
 ibis/backends/snowflake/__init__.py          |  6 +--
 ibis/backends/snowflake/converter.py         | 29 ++++++++++++++
 ibis/backends/snowflake/tests/test_client.py |  7 ++++
 ibis/expr/types/generic.py                   | 19 ++++++---
 ibis/expr/types/relations.py                 | 10 +++--
 ibis/formats/pyarrow.py                      | 42 +++++++++++++++++++-
 6 files changed, 100 insertions(+), 13 deletions(-)

diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py
index 2ac5df33cf9b..fe8ea92505c5 100644
--- a/ibis/backends/snowflake/__init__.py
+++ b/ibis/backends/snowflake/__init__.py
@@ -364,6 +364,8 @@ def to_pyarrow(
         limit: int | str | None = None,
         **_: Any,
     ) -> pa.Table:
+        from ibis.backends.snowflake.converter import SnowflakePyArrowData
+
         self._run_pre_execute_hooks(expr)

         query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
@@ -375,9 +377,7 @@ def to_pyarrow(
         if res is None:
             res = target_schema.empty_table()

-        res = res.rename_columns(target_schema.names).cast(target_schema)
-
-        return expr.__pyarrow_result__(res)
+        return expr.__pyarrow_result__(res, data_mapper=SnowflakePyArrowData)

     def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
         if (table := cursor.cursor.fetch_arrow_all()) is None:
diff --git a/ibis/backends/snowflake/converter.py b/ibis/backends/snowflake/converter.py
index 06845acabdb7..4d46ee53bbc3 100644
--- a/ibis/backends/snowflake/converter.py
+++ b/ibis/backends/snowflake/converter.py
@@ -1,6 +1,15 @@
 from __future__ import annotations

+from typing import TYPE_CHECKING
+
 from ibis.formats.pandas import PandasData
+from ibis.formats.pyarrow import PYARROW_JSON_TYPE, PyArrowData
+
+if TYPE_CHECKING:
+    import pyarrow as pa
+
+    import ibis.expr.datatypes as dt
+    from ibis.expr.schema import Schema


 class SnowflakePandasData(PandasData):
@@ -10,3 +19,23 @@ def convert_JSON(s, dtype, pandas_type):
         return s.map(converter, na_action="ignore").astype("object")

     convert_Struct = convert_Array = convert_Map = convert_JSON
+
+
+class SnowflakePyArrowData(PyArrowData):
+    @classmethod
+    def convert_table(cls, table: pa.Table, schema: Schema) -> pa.Table:
+        import pyarrow as pa
+
+        columns = [cls.convert_column(table[name], typ) for name, typ in schema.items()]
+        return pa.Table.from_arrays(columns, names=schema.names)
+
+    @classmethod
+    def convert_column(cls, column: pa.Array, dtype: dt.DataType) -> pa.Array:
+        if dtype.is_json() or dtype.is_array() or dtype.is_map() or dtype.is_struct():
+            import pyarrow as pa
+
+            if isinstance(column, pa.ChunkedArray):
+                column = column.combine_chunks()
+
+            return pa.ExtensionArray.from_storage(PYARROW_JSON_TYPE, column)
+        return super().convert_column(column, dtype)
diff --git a/ibis/backends/snowflake/tests/test_client.py b/ibis/backends/snowflake/tests/test_client.py
index 34508050b0b5..621b435292da 100644
--- a/ibis/backends/snowflake/tests/test_client.py
+++ b/ibis/backends/snowflake/tests/test_client.py
@@ -219,3 +219,10 @@ def test_read_parquet(con, data_dir):
     t = con.read_parquet(path)

     assert t.timestamp_col.type().is_timestamp()
+
+
+def test_array_repr(con, monkeypatch):
+    monkeypatch.setattr(ibis.options, "interactive", True)
+    t = con.tables.ARRAY_TYPES
+    expr = t.x
+    assert repr(expr)
diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py
index 88f0dd5f2441..fb0ff37bb2b4 100644
--- a/ibis/expr/types/generic.py
+++ b/ibis/expr/types/generic.py
@@ -18,6 +18,7 @@

     import ibis.expr.builders as bl
     import ibis.expr.types as ir
+    from ibis.formats.pyarrow import PyArrowData


 @public
@@ -1204,10 +1205,13 @@ class Scalar(Value):
     def __interactive_rich_console__(self, console, options):
         return console.render(repr(self.execute()), options=options)

-    def __pyarrow_result__(self, table: pa.Table) -> pa.Scalar:
-        from ibis.formats.pyarrow import PyArrowData
+    def __pyarrow_result__(
+        self, table: pa.Table, data_mapper: type[PyArrowData] | None = None
+    ) -> pa.Scalar:
+        if data_mapper is None:
+            from ibis.formats.pyarrow import PyArrowData as data_mapper

-        return PyArrowData.convert_scalar(table[0][0], self.type())
+        return data_mapper.convert_scalar(table[0][0], self.type())

     def __pandas_result__(self, df: pd.DataFrame) -> Any:
         return df.iat[0, 0]
@@ -1275,10 +1279,13 @@ def __interactive_rich_console__(self, console, options):
         projection = named.as_table()
         return console.render(projection, options=options)

-    def __pyarrow_result__(self, table: pa.Table) -> pa.Array | pa.ChunkedArray:
-        from ibis.formats.pyarrow import PyArrowData
+    def __pyarrow_result__(
+        self, table: pa.Table, data_mapper: type[PyArrowData] | None = None
+    ) -> pa.Array | pa.ChunkedArray:
+        if data_mapper is None:
+            from ibis.formats.pyarrow import PyArrowData as data_mapper

-        return PyArrowData.convert_column(table[0], self.type())
+        return data_mapper.convert_column(table[0], self.type())

     def __pandas_result__(self, df: pd.DataFrame) -> pd.Series:
         from ibis.formats.pandas import PandasData
diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py
index 787a27ae37fc..dbf0c497a071 100644
--- a/ibis/expr/types/relations.py
+++ b/ibis/expr/types/relations.py
@@ -32,6 +32,7 @@
     from ibis.expr.types.groupby import GroupedTable
     from ibis.expr.types.tvf import WindowedTable
     from ibis.selectors import IfAnyAll, Selector
+    from ibis.formats.pyarrow import PyArrowData


 _ALIASES = (f"_ibis_view_{n:d}" for n in itertools.count())
@@ -158,10 +159,13 @@ def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True):

         return IbisDataFrame(self, nan_as_null=nan_as_null, allow_copy=allow_copy)

-    def __pyarrow_result__(self, table: pa.Table) -> pa.Table:
-        from ibis.formats.pyarrow import PyArrowData
+    def __pyarrow_result__(
+        self, table: pa.Table, data_mapper: type[PyArrowData] | None = None
+    ) -> pa.Table:
+        if data_mapper is None:
+            from ibis.formats.pyarrow import PyArrowData as data_mapper

-        return PyArrowData.convert_table(table, self.schema())
+        return data_mapper.convert_table(table, self.schema())

     def __pandas_result__(self, df: pd.DataFrame) -> pd.DataFrame:
         from ibis.formats.pandas import PandasData
diff --git a/ibis/formats/pyarrow.py b/ibis/formats/pyarrow.py
index 5784198db7a2..03b26d1391ff 100644
--- a/ibis/formats/pyarrow.py
+++ b/ibis/formats/pyarrow.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import json
 from typing import TYPE_CHECKING, Any

 import pyarrow as pa
@@ -12,6 +13,38 @@
 if TYPE_CHECKING:
     from collections.abc import Sequence

+
+class JSONScalar(pa.ExtensionScalar):
+    def as_py(self):
+        value = self.value
+        if value is None:
+            return value
+        else:
+            return json.loads(value.as_py())
+
+
+class JSONArray(pa.ExtensionArray):
+    pass
+
+
+class JSONType(pa.ExtensionType):
+    def __init__(self):
+        super().__init__(pa.string(), "ibis.json")
+
+    def __arrow_ext_serialize__(self):
+        return b""
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        return cls()
+
+    def __arrow_ext_class__(self):
+        return JSONArray
+
+    def __arrow_ext_scalar_class__(self):
+        return JSONScalar
+
+
 _from_pyarrow_types = {
     pa.int8(): dt.Int8,
     pa.int16(): dt.Int16,
@@ -57,7 +90,6 @@
     dt.Unknown: pa.string(),
     dt.MACADDR: pa.string(),
     dt.INET: pa.string(),
-    dt.JSON: pa.string(),
 }


@@ -95,6 +127,8 @@ def to_ibis(cls, typ: pa.DataType, nullable=True) -> dt.DataType:
             key_dtype = cls.to_ibis(typ.key_type, typ.key_field.nullable)
             value_dtype = cls.to_ibis(typ.item_type, typ.item_field.nullable)
             return dt.Map(key_dtype, value_dtype, nullable=nullable)
+        elif isinstance(typ, JSONType):
+            return dt.JSON()
         else:
             return _from_pyarrow_types[typ](nullable=nullable)

@@ -154,6 +188,8 @@ def from_ibis(cls, dtype: dt.DataType) -> pa.DataType:
                 nullable=dtype.value_type.nullable,
             )
             return pa.map_(key_field, value_field, keys_sorted=False)
+        elif dtype.is_json():
+            return PYARROW_JSON_TYPE
         else:
             try:
                 return _to_pyarrow_types[type(dtype)]
@@ -254,3 +290,7 @@ def convert_table(cls, table: pa.Table, schema: Schema) -> pa.Table:
             return table.cast(desired_schema)
         else:
             return table
+
+
+PYARROW_JSON_TYPE = JSONType()
+pa.register_extension_type(PYARROW_JSON_TYPE)
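
A minimal usage sketch (not part of the patch; assumes a build of ibis that
includes this change). Snowflake returns ARRAY/OBJECT/MAP/STRUCT/VARIANT
columns as JSON-encoded strings; wrapping the string storage in
PYARROW_JSON_TYPE, as SnowflakePyArrowData.convert_column does above, makes
scalar access decode to Python objects via JSONScalar.as_py, which is what
fixes array printing:

    import pyarrow as pa

    from ibis.formats.pyarrow import PYARROW_JSON_TYPE

    # JSON-encoded strings, as Snowflake returns them for an ARRAY column
    storage = pa.array(['[1, 2, 3]', '{"a": 1}', None], type=pa.string())

    # Wrap the string storage in the registered extension type
    arr = pa.ExtensionArray.from_storage(PYARROW_JSON_TYPE, storage)

    print(arr[0].as_py())  # [1, 2, 3] -- decoded by JSONScalar.as_py
    print(arr[1].as_py())  # {'a': 1}
    print(arr[2].as_py())  # None; nulls pass through undecoded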
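
Registration also matters beyond printing: because __arrow_ext_serialize__
and __arrow_ext_deserialize__ are defined and the instance is registered via
pa.register_extension_type, the "ibis.json" type survives Arrow IPC round
trips. A short sketch under the same assumptions:

    import pyarrow as pa

    from ibis.formats.pyarrow import PYARROW_JSON_TYPE

    storage = pa.array(['{"k": [1, 2]}'], type=pa.string())
    arr = pa.ExtensionArray.from_storage(PYARROW_JSON_TYPE, storage)
    table = pa.table({"js": arr})

    # Serialize to an Arrow IPC stream and read it back
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, table.schema) as writer:
        writer.write_table(table)
    restored = pa.ipc.open_stream(sink.getvalue()).read_all()

    # The extension type is reconstructed from the registry, so scalar
    # access still decodes JSON after the round trip
    assert restored["js"].type == PYARROW_JSON_TYPE
    print(restored["js"][0].as_py())  # {'k': [1, 2]}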