Skip to content

Commit

Permalink
fix(postgres): fix json type conversion in to_pyarrow output (#8439)
Browse files Browse the repository at this point in the history
Localize custom pyarrow json serialization to snowflake. This custom
type doesn't compose well (e.g., inside structs, which aren't directly
supported in snowflake anyway) and was causing problems for other
backends when trying to convert rows into a struct of the table schema
and then into a proper table.

Fixes #8318.
  • Loading branch information
cpcloud authored Feb 27, 2024
1 parent 3730eb6 commit b338517
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 65 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def do_connect(
month : int32
"""

psycopg2.extras.register_default_json(loads=lambda x: x)
self.con = psycopg2.connect(
host=host,
port=port,
Expand Down
46 changes: 39 additions & 7 deletions ibis/backends/snowflake/converter.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,54 @@
from __future__ import annotations

import datetime
import json
from typing import TYPE_CHECKING

import pyarrow as pa

from ibis.formats.pandas import PandasData
from ibis.formats.pyarrow import PYARROW_JSON_TYPE, PyArrowData
from ibis.formats.pyarrow import PyArrowData

if TYPE_CHECKING:
import pyarrow as pa

import ibis.expr.datatypes as dt
from ibis.expr.schema import Schema


class JSONScalar(pa.ExtensionScalar):
    """Scalar for the ibis JSON extension type.

    The storage is a UTF-8 string scalar; ``as_py`` decodes it so callers
    receive the parsed Python object rather than raw JSON text.
    """

    def as_py(self):
        """Return the decoded JSON value, or ``None`` for a null scalar."""
        raw = self.value
        if raw is None:
            return None
        return json.loads(raw.as_py())


class JSONArray(pa.ExtensionArray):
    """Array container for the JSON extension type.

    All behavior is inherited from ``pa.ExtensionArray``; this subclass
    exists only so that ``JSONType.__arrow_ext_class__`` can hand pyarrow
    an array class whose scalars resolve to ``JSONScalar``.
    """

    pass


class JSONType(pa.ExtensionType):
    """PyArrow extension type tagging UTF-8 string storage as JSON text."""

    def __init__(self):
        # JSON values are stored as their serialized string form under the
        # registered extension name "ibis.json".
        super().__init__(pa.string(), "ibis.json")

    def __arrow_ext_serialize__(self):
        # The type carries no parameters, so there is nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Reconstruction needs no state; ignore storage_type/serialized.
        return cls()

    def __arrow_ext_class__(self):
        return JSONArray

    def __arrow_ext_scalar_class__(self):
        return JSONScalar


# Create and register the JSON extension type once at import time so that
# arrow data carrying the "ibis.json" extension name resolves back to
# JSONType (and thus JSONScalar/JSONArray) when deserialized.
PYARROW_JSON_TYPE = JSONType()
pa.register_extension_type(PYARROW_JSON_TYPE)


class SnowflakePandasData(PandasData):
@classmethod
def convert_Timestamp_element(cls, dtype):
Expand Down Expand Up @@ -50,16 +86,12 @@ def convert_Struct(cls, s, dtype, pandas_type):
class SnowflakePyArrowData(PyArrowData):
    @classmethod
    def convert_table(cls, table: pa.Table, schema: Schema) -> pa.Table:
        """Convert each column of *table* to match the ibis *schema*.

        The result is rebuilt from the converted columns so it carries
        exactly the names and order of the target schema.
        """
        # NOTE(review): this line appears to be a deletion-side diff line
        # (pa is imported at module scope in the new version) — confirm
        # against the applied file before relying on it.
        import pyarrow as pa

        columns = [cls.convert_column(table[name], typ) for name, typ in schema.items()]
        return pa.Table.from_arrays(columns, names=schema.names)

@classmethod
def convert_column(cls, column: pa.Array, dtype: dt.DataType) -> pa.Array:
if dtype.is_json() or dtype.is_array() or dtype.is_map() or dtype.is_struct():
import pyarrow as pa

if isinstance(column, pa.ChunkedArray):
column = column.combine_chunks()

Expand Down
36 changes: 19 additions & 17 deletions ibis/backends/snowflake/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,27 +119,29 @@ def add_catalog_and_schema(node):
def _load_data(self, **_: Any) -> None:
"""Load test data into a Snowflake backend instance."""

connect_args = {}

url = urlparse(_get_url())
db, schema = url.path[1:].split("/", 1)
(warehouse,) = parse_qs(url.query)["warehouse"]
connect_args = {
"user": url.username,
"password": url.password,
"account": url.hostname,
"warehouse": warehouse,
}

session_parameters = {
"MULTI_STATEMENT_COUNT": 0,
"JSON_INDENT": 0,
"PYTHON_CONNECTOR_QUERY_RESULT_FORMAT": "arrow_force",
}

if url.path:
db, schema = url.path[1:].split("/", 1)
(warehouse,) = parse_qs(url.query)["warehouse"]
connect_args.update(
{
"user": url.username,
"password": url.password,
"account": url.hostname,
"warehouse": warehouse,
}
)
else:
db = os.environ["SNOWFLAKE_DATABASE"]
schema = os.environ["SNOWFLAKE_SCHEMA"]

dbschema = f"{db}.{schema}"

with closing(
sc.connect(**connect_args, session_parameters=session_parameters)
) as con, closing(con.cursor()) as c:
with closing(sc.connect(**connect_args)) as con, closing(con.cursor()) as c:
c.execute("ALTER SESSION SET MULTI_STATEMENT_COUNT = 0 JSON_INDENT = 0")
c.execute(
f"""
CREATE DATABASE IF NOT EXISTS {db};
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/snowflake/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,16 @@ def test_no_argument_connection():

con = ibis.connect("snowflake://")
assert con.list_tables() is not None


def test_struct_of_json(con):
    """Structs with JSON fields should round-trip through ``to_pyarrow``."""
    expected = {"a": [1, 2, 3], "b": "456"}
    typed = ibis.struct(expected).cast("struct<a: array<int>, b: json>")

    nrows = 5
    column = con.tables.functional_alltypes.mutate(lit=typed).limit(nrows).lit
    result = con.to_pyarrow(column)

    assert len(result) == nrows
    for value in result.to_pylist():
        assert value == expected
44 changes: 44 additions & 0 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import importlib
import inspect
import json
import re
import string
import subprocess
Expand Down Expand Up @@ -1468,3 +1469,46 @@ def test_close_connection(con):
# DB-API states that subsequent execution attempt should raise
with pytest.raises(Exception): # noqa:B017
new_con.list_tables()


@pytest.mark.notyet(
    ["clickhouse"],
    raises=AttributeError,
    reason="JSON extension is experimental and not enabled by default in testing",
)
@pytest.mark.notyet(
    ["datafusion", "polars", "mssql", "druid", "oracle", "exasol", "impala"],
    raises=AttributeError,
    reason="JSON type not implemented",
)
@pytest.mark.notimpl(
    ["risingwave", "sqlite"],
    raises=pa.ArrowTypeError,
    reason="mismatch between output value and expected input type",
)
@pytest.mark.never(
    ["snowflake"],
    raises=TypeError,
    reason="snowflake uses a custom pyarrow extension type for JSON pretty printing",
)
def test_json_to_pyarrow(con):
    """``to_pyarrow`` on a JSON column should yield plain JSON strings."""
    js = con.tables.json_t.to_pyarrow()["js"]

    raw_values = [
        {"a": [1, 2, 3, 4], "b": 1},
        {"a": None, "b": 2},
        {"a": "foo", "c": None},
        None,
        [42, 47, 55],
        [],
    ]
    expected = {json.dumps(value) for value in raw_values}

    # Round-trip each result through loads/dumps so both sides compare the
    # same canonical string representation regardless of whitespace.
    actual = set()
    for value in js.to_pylist():
        actual.add(json.dumps(json.loads(value)))

    assert actual == expected
41 changes: 1 addition & 40 deletions ibis/formats/pyarrow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

import pyarrow as pa
Expand All @@ -14,37 +13,6 @@
from collections.abc import Sequence


class JSONScalar(pa.ExtensionScalar):
    # NOTE(review): deletion-side copy in this diff — the definition was
    # moved to the snowflake backend's converter module by this commit.
    """Scalar wrapper that decodes the stored JSON string in ``as_py``."""

    def as_py(self):
        # Null scalars pass through unchanged; non-null storage is a string
        # scalar whose text is parsed with json.loads.
        value = self.value
        if value is None:
            return value
        else:
            return json.loads(value.as_py())


class JSONArray(pa.ExtensionArray):
    # NOTE(review): deletion-side copy in this diff; behavior is entirely
    # inherited from pa.ExtensionArray.
    pass


class JSONType(pa.ExtensionType):
    # NOTE(review): deletion-side copy in this diff — moved to the
    # snowflake backend's converter module by this commit.
    """Extension type marking UTF-8 string storage as JSON text."""

    def __init__(self):
        # Storage is the serialized JSON string, registered as "ibis.json".
        super().__init__(pa.string(), "ibis.json")

    def __arrow_ext_serialize__(self):
        # Parameter-free type; nothing to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        return cls()

    def __arrow_ext_class__(self):
        return JSONArray

    def __arrow_ext_scalar_class__(self):
        return JSONScalar


_from_pyarrow_types = {
pa.int8(): dt.Int8,
pa.int16(): dt.Int16,
Expand Down Expand Up @@ -91,6 +59,7 @@ def __arrow_ext_scalar_class__(self):
dt.MACADDR: pa.string(),
dt.INET: pa.string(),
dt.UUID: pa.string(),
dt.JSON: pa.string(),
}


Expand Down Expand Up @@ -128,8 +97,6 @@ def to_ibis(cls, typ: pa.DataType, nullable=True) -> dt.DataType:
key_dtype = cls.to_ibis(typ.key_type, typ.key_field.nullable)
value_dtype = cls.to_ibis(typ.item_type, typ.item_field.nullable)
return dt.Map(key_dtype, value_dtype, nullable=nullable)
elif isinstance(typ, JSONType):
return dt.JSON()
elif pa.types.is_dictionary(typ):
return cls.to_ibis(typ.value_type)
else:
Expand Down Expand Up @@ -191,8 +158,6 @@ def from_ibis(cls, dtype: dt.DataType) -> pa.DataType:
nullable=dtype.value_type.nullable,
)
return pa.map_(key_field, value_field, keys_sorted=False)
elif dtype.is_json():
return PYARROW_JSON_TYPE
elif dtype.is_geospatial():
return pa.binary()
else:
Expand Down Expand Up @@ -303,7 +268,3 @@ def to_frame(self):

def to_pyarrow(self, schema: Schema) -> pa.Table:
return self.obj


# NOTE(review): deletion-side lines in this diff — instantiation and
# registration of the JSON extension type, relocated to the snowflake
# converter module by this commit.
PYARROW_JSON_TYPE = JSONType()
pa.register_extension_type(PYARROW_JSON_TYPE)

0 comments on commit b338517

Please sign in to comment.