Skip to content

Commit

Permalink
feat: support all dtypes in MapGet and MapContains
Browse files Browse the repository at this point in the history
  • Loading branch information
NickCrews committed Mar 14, 2024
1 parent 461293b commit 6678570
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 2 deletions.
176 changes: 176 additions & 0 deletions ibis/backends/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,182 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df):
backend.assert_series_equal(result, expected)


@pytest.fixture(
params=[
pytest.param(["a", "b"], id="string"),
pytest.param(
[1, 2],
marks=pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
id="int",
),
pytest.param(
[True, False],
marks=pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
id="bool",
),
pytest.param(
[1.0, 2.0],
marks=[
pytest.mark.notyet(
"clickhouse", reason="only supports str,int,bool,timestamp keys"
),
pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
],
id="float",
),
pytest.param(
[ibis.timestamp("2021-01-01"), ibis.timestamp("2021-01-02")],
marks=pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
id="timestamp",
),
pytest.param(
[ibis.date(1, 2, 3), ibis.date(4, 5, 6)],
marks=[
pytest.mark.notyet(
"clickhouse", reason="only supports str,int,bool,timestamp keys"
),
pytest.mark.notimpl(
["pandas", "dask"], reason="DateFromYMD isn't implemented"
),
pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
],
id="date",
),
pytest.param(
[[1, 2], [3, 4]],
marks=[
pytest.mark.notyet(
"clickhouse", reason="only supports str,int,bool,timestamp keys"
),
pytest.mark.notyet(["pandas", "dask"]),
pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
],
id="array",
),
pytest.param(
[ibis.struct(dict(a=1)), ibis.struct(dict(a=2))],
marks=[
pytest.mark.notyet(
"clickhouse", reason="only supports str,int,bool,timestamp keys"
),
pytest.mark.notyet(["pandas", "dask"]),
pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
],
id="struct",
),
]
)
def keys(request):
return request.param


@pytest.fixture(
params=[
pytest.param(["a", "b"], id="string"),
pytest.param(
[1, 2],
marks=pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
id="int",
),
pytest.param(
[True, False],
marks=pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
id="bool",
),
pytest.param(
[1.0, 2.0],
marks=pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
id="float",
),
pytest.param(
[ibis.timestamp("2021-01-01"), ibis.timestamp("2021-01-02")],
marks=pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
id="timestamp",
),
pytest.param(
[ibis.date(2021, 1, 1), ibis.date(2022, 2, 2)],
marks=[
pytest.mark.notimpl(
["pandas", "dask"], reason="DateFromYMD isn't implemented"
),
pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
],
id="date",
),
pytest.param(
[[1, 2], [3, 4]],
marks=[
pytest.mark.notyet("clickhouse", reason="nested types can't be null"),
pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
],
id="array",
),
pytest.param(
[ibis.struct(dict(a=1)), ibis.struct(dict(a=2))],
marks=[
pytest.mark.notyet("clickhouse", reason="nested types can't be null"),
pytest.mark.notyet(
["postgres", "risingwave"], reason="only supports string -> string"
),
],
id="struct",
),
]
)
def values(request):
return request.param


def test_map_get_all_types(con, keys, values):
m = ibis.map(ibis.array(keys), ibis.array(values))
for key, val in zip(keys, values):
if isinstance(val, ibis.Expr):
val = con.execute(val)
assert con.execute(m[key]) == val


def test_map_contains_all_types(con, keys):
m = ibis.map(ibis.array(keys), ibis.array(keys))
for key in keys:
assert con.execute(m.contains(key))


# TODO: this should actually error: https://github.com/ibis-project/ibis/issues/8605
def test_map_get_bad_key_type():
ibis.map({"1": "a"})[2]

Check warning on line 357 in ibis/backends/tests/test_map.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/tests/test_map.py#L357

Added line #L357 was not covered by tests


# TODO: this should actually error: https://github.com/ibis-project/ibis/issues/8605
def test_map_contains_bad_key_type():
ibis.map({"1": "a"}).contains(2)

Check warning on line 362 in ibis/backends/tests/test_map.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/tests/test_map.py#L362

Added line #L362 was not covered by tests


@pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/operations/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class MapLength(Unary):
@public
class MapGet(Value):
arg: Value[dt.Map]
key: Value[dt.String | dt.Integer]
key: Value
default: Value = None

shape = rlz.shape_like("args")
Expand All @@ -45,7 +45,7 @@ def dtype(self):
@public
class MapContains(Value):
arg: Value[dt.Map]
key: Value[dt.String | dt.Integer]
key: Value

shape = rlz.shape_like("args")
dtype = dt.bool
Expand Down

0 comments on commit 6678570

Please sign in to comment.