Skip to content

Commit

Permalink
feat(postgres): add Map(string, string) support via the built-in `H…
Browse files Browse the repository at this point in the history
…STORE` extension
  • Loading branch information
cpcloud authored and kszucs committed Jan 23, 2023
1 parent 8b01f1b commit f968f8f
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 14 deletions.
7 changes: 7 additions & 0 deletions ci/schema/postgresql.sql
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
DROP SEQUENCE IF EXISTS test_sequence;
CREATE SEQUENCE IF NOT EXISTS test_sequence;

CREATE EXTENSION IF NOT EXISTS hstore;
CREATE EXTENSION IF NOT EXISTS postgis;
CREATE EXTENSION IF NOT EXISTS plpython3u;

Expand Down Expand Up @@ -204,3 +205,9 @@ INSERT INTO win VALUES
('a', 2, 0),
('a', 3, 1),
('a', 4, 1);

DROP TABLE IF EXISTS map CASCADE;
CREATE TABLE map (kv HSTORE);
INSERT INTO map VALUES
('a=>1,b=>2,c=>3'),
('d=>4,e=>5,c=>6');
5 changes: 2 additions & 3 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,9 @@ def _get_insert_method(self, expr):
return methodcaller("from_select", list(expr.columns), compiled)

def _columns_from_schema(self, name: str, schema: sch.Schema) -> list[sa.Column]:
dialect = self.con.dialect
return [
sa.Column(
colname, to_sqla_type(self.con.dialect, dtype), nullable=dtype.nullable
)
sa.Column(colname, to_sqla_type(dialect, dtype), nullable=dtype.nullable)
for colname, dtype in zip(schema.names, schema.types)
]

Expand Down
12 changes: 12 additions & 0 deletions ibis/backends/base/sql/alchemy/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,13 @@ def _pg_array(dialect, itype):
return sa.ARRAY(to_sqla_type(dialect, itype))


@to_sqla_type.register(PGDialect, dt.Map)
def _pg_map(dialect, itype):
if not (itype.key_type.is_string() and itype.value_type.is_string()):
raise TypeError(f"PostgreSQL only supports map<string, string>, got: {itype}")
return postgresql.HSTORE


@to_sqla_type.register(Dialect, dt.Struct)
def _struct(dialect, itype):
return StructType(
Expand Down Expand Up @@ -294,6 +301,11 @@ def sa_macaddr(_, satype, nullable=True):
return dt.MACADDR(nullable=nullable)


@dt.dtype.register(PGDialect, postgresql.HSTORE)
def sa_hstore(_, satype, nullable=True):
return dt.Map(dt.string, dt.string, nullable=nullable)


@dt.dtype.register(PGDialect, postgresql.INET)
def sa_inet(_, satype, nullable=True):
return dt.INET(nullable=nullable)
Expand Down
8 changes: 8 additions & 0 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,14 @@ def _struct_column(t, op):
ops.Translate,
# ibis.expr.operations.temporal
ops.TimestampDiff,
# ibis.expr.operations.maps
ops.MapGet,
ops.MapContains,
ops.MapKeys,
ops.MapValues,
ops.MapMerge,
ops.MapLength,
ops.Map,
}

operation_registry = {
Expand Down
16 changes: 15 additions & 1 deletion ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,11 +396,13 @@ def _literal(t, op):
elif dtype.is_geospatial():
# inline_metadata ex: 'SRID=4326;POINT( ... )'
return sa.literal_column(geo.translate_literal(op, inline_metadata=True))
elif isinstance(value, tuple):
elif dtype.is_array():
return sa.literal_column(
str(pg.array(value).compile(compile_kwargs=dict(literal_binds=True))),
type_=t.get_sqla_type(dtype),
)
elif dtype.is_map():
return pg.hstore(pg.array(list(value.keys())), pg.array(list(value.values())))
else:
return sa.literal(value)

Expand Down Expand Up @@ -585,5 +587,17 @@ def variance_compiler(t, op):
ops.TimestampNow: lambda t, op: sa.literal_column(
"CURRENT_TIMESTAMP", type_=t.get_sqla_type(op.output_dtype)
),
ops.MapGet: fixed_arity(
lambda arg, key, default: sa.case(
(arg.has_key(key), arg[key]), else_=default
),
3,
),
ops.MapContains: fixed_arity(pg.HSTORE.Comparator.has_key, 2),
ops.MapKeys: unary(pg.HSTORE.Comparator.keys),
ops.MapValues: unary(pg.HSTORE.Comparator.vals),
ops.MapMerge: fixed_arity(operator.add, 2),
ops.MapLength: unary(lambda arg: sa.func.cardinality(arg.keys())),
ops.Map: fixed_arity(pg.hstore, 2),
}
)
65 changes: 56 additions & 9 deletions ibis/backends/tests/test_map.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
import contextlib

import numpy as np
import pytest
from pytest import param

import ibis
import ibis.expr.datatypes as dt
from ibis.util import guid

pytestmark = [
pytest.mark.never(
["sqlite", "mysql", "mssql", "postgres"], reason="No map support"
["sqlite", "mysql", "mssql"], reason="Unlikely to ever add map support"
),
pytest.mark.notyet(
["bigquery", "impala"], reason="backend doesn't implement map types"
["bigquery", "impala"], reason="Backend doesn't yet implement map types"
),
pytest.mark.notimpl(
["duckdb", "datafusion", "pyspark", "polars"], reason="Not implemented yet"
["duckdb", "datafusion", "pyspark", "polars"],
reason="Not yet implemented in ibis",
),
]


@pytest.mark.notimpl(["pandas", "dask"])
def test_map_table(con):
table = con.table("map")
assert not table.execute().empty
assert table.kv.type().is_map()
assert not table.limit(1).execute().empty


def test_literal_map_keys(con):
Expand All @@ -42,7 +48,7 @@ def test_literal_map_values(con):
assert np.array_equal(result, ['a', 'b'])


@pytest.mark.notimpl(["trino"])
@pytest.mark.notimpl(["trino", "postgres"])
@pytest.mark.notyet(["snowflake"])
def test_scalar_isin_literal_map_keys(con):
mapping = ibis.literal({'a': 1, 'b': 2})
Expand All @@ -54,6 +60,7 @@ def test_scalar_isin_literal_map_keys(con):
assert con.execute(false) == False # noqa: E712


@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string")
def test_map_scalar_contains_key_scalar(con):
mapping = ibis.literal({'a': 1, 'b': 2})
a = ibis.literal('a')
Expand All @@ -74,6 +81,7 @@ def test_map_scalar_contains_key_column(backend, alltypes, df):


@pytest.mark.notyet(["snowflake"])
@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string")
def test_map_column_contains_key_scalar(backend, alltypes, df):
expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
series = df.apply(lambda row: {row['string_col']: row['int_col']}, axis=1)
Expand All @@ -85,13 +93,15 @@ def test_map_column_contains_key_scalar(backend, alltypes, df):


@pytest.mark.notyet(["snowflake"])
def test_map_column_contains_key_column(backend, alltypes, df):
@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string")
def test_map_column_contains_key_column(alltypes):
expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
result = expr.contains(alltypes.string_col).name('tmp').execute()
assert result.all()


@pytest.mark.notyet(["snowflake"])
@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string")
def test_literal_map_merge(con):
a = ibis.literal({'a': 0, 'b': 2})
b = ibis.literal({'a': 1, 'c': 3})
Expand Down Expand Up @@ -124,15 +134,30 @@ def test_literal_map_get_broadcast(backend, alltypes, df):
backend.assert_series_equal(result, expected)


def test_map_construct_dict(con):
expr = ibis.map(['a', 'b'], [1, 2])
@pytest.mark.parametrize(
("keys", "values"),
[
param(
["a", "b"],
[1, 2],
id="string",
marks=pytest.mark.notyet(
["postgres"], reason="only support maps of string -> string"
),
),
param(["a", "b"], ["1", "2"], id="int"),
],
)
def test_map_construct_dict(con, keys, values):
expr = ibis.map(keys, values)
result = con.execute(expr.name('tmp'))
assert result == {'a': 1, 'b': 2}
assert result == dict(zip(keys, values))


@pytest.mark.notimpl(
["snowflake"], reason="unclear how to implement two arrays -> object construction"
)
@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string")
def test_map_construct_array_column(con, alltypes, df):
expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
result = con.execute(expr)
Expand All @@ -141,25 +166,29 @@ def test_map_construct_array_column(con, alltypes, df):
assert result.to_list() == expected.to_list()


@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string")
def test_map_get_with_compatible_value_smaller(con):
value = ibis.literal({'A': 1000, 'B': 2000})
expr = value.get('C', 3)
assert con.execute(expr) == 3


@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string")
def test_map_get_with_compatible_value_bigger(con):
value = ibis.literal({'A': 1, 'B': 2})
expr = value.get('C', 3000)
assert con.execute(expr) == 3000


@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string")
def test_map_get_with_incompatible_value_different_kind(con):
value = ibis.literal({'A': 1000, 'B': 2000})
expr = value.get('C', 3.0)
assert con.execute(expr) == 3.0


@pytest.mark.parametrize('null_value', [None, ibis.NA])
@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string")
def test_map_get_with_null_on_not_nullable(con, null_value):
map_type = dt.Map(dt.string, dt.Int16(nullable=False))
value = ibis.literal({'A': 1000, 'B': 2000}).cast(map_type)
Expand All @@ -174,7 +203,25 @@ def test_map_get_with_null_on_null_type_with_null(con, null_value):
assert con.execute(expr) is None


@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string")
def test_map_get_with_null_on_null_type_with_non_null(con):
value = ibis.literal({'A': None, 'B': None})
expr = value.get('C', 1)
assert con.execute(expr) == 1


@pytest.fixture
def tmptable(con):
name = guid()
yield name

# some backends don't implement drop
with contextlib.suppress(NotImplementedError):
con.drop_table(name)


@pytest.mark.notimpl(["clickhouse"], reason=".create_table not yet implemented in ibis")
def test_map_create_table(con, tmptable):
con.create_table(tmptable, schema=ibis.schema(dict(xyz="map<string, string>")))
t = con.table(tmptable)
assert t.schema()["xyz"].is_map()
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_scalar_param_struct(con):
["mysql", "sqlite", "mssql"],
reason="mysql and sqlite will never implement map types",
)
@pytest.mark.notyet(["bigquery", "postgres"])
@pytest.mark.notyet(["bigquery"])
def test_scalar_param_map(con):
value = {'a': 'ghi', 'b': 'def', 'c': 'abc'}
param = ibis.param(dt.Map(dt.string, dt.string))
Expand Down

0 comments on commit f968f8f

Please sign in to comment.