Skip to content

Commit

Permalink
refactor(tests): simplify pattern matching tests on Value operations
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Aug 11, 2023
1 parent 3cbe2f3 commit d87e65a
Showing 1 changed file with 26 additions and 44 deletions.
70 changes: 26 additions & 44 deletions ibis/expr/operations/tests/test_generic.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from typing import Union

import pytest

import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.patterns import CoercedTo, GenericCoercedTo, NoMatch, Pattern
from ibis.common.patterns import NoMatch, match


@pytest.mark.parametrize(
Expand All @@ -25,72 +27,52 @@ def test_literal_coercion_type_inference(value, dtype):


def test_coerced_to_literal():
p = CoercedTo(ops.Literal)
one = ops.Literal(1, dt.int8)
assert p.match(ops.Literal(1, dt.int8), {}) == one
assert p.match(1, {}) == one
assert p.match(False, {}) == ops.Literal(False, dt.boolean)

p = GenericCoercedTo(ops.Literal[dt.Int8])
assert p.match(ops.Literal(1, dt.int8), {}) == one
assert match(ops.Literal, 1) == one
assert match(ops.Literal, one) == one
assert match(ops.Literal, False) == ops.Literal(False, dt.boolean)

p = Pattern.from_typehint(ops.Literal[dt.Int8])
assert p == GenericCoercedTo(ops.Literal[dt.Int8])
assert match(ops.Literal[dt.Int8], 1) == one
assert match(ops.Literal[dt.Int16], 1) == ops.Literal(1, dt.int16)

one = ops.Literal(1, dt.int16)
assert p.match(one, {}) is NoMatch
assert match(ops.Literal[dt.Int8], ops.Literal(1, dt.int16)) is NoMatch


def test_coerced_to_value():
one = ops.Literal(1, dt.int8)

p = Pattern.from_typehint(ops.Value)
assert p.match(1, {}) == one

p = Pattern.from_typehint(ops.Value[dt.Int8, ds.Any])
assert p.match(1, {}) == one

p = Pattern.from_typehint(ops.Value[dt.Int8, ds.Scalar])
assert p.match(1, {}) == one

p = Pattern.from_typehint(ops.Value[dt.Int8, ds.Columnar])
assert p.match(1, {}) is NoMatch
assert match(ops.Value, 1) == one
assert match(ops.Value[dt.Int8], 1) == one
assert match(ops.Value[dt.Int8, ds.Any], 1) == one
assert match(ops.Value[dt.Int8, ds.Scalar], 1) == one
assert match(ops.Value[dt.Int8, ds.Columnar], 1) is NoMatch

# dt.Integer is not instantiable so it will be only used for checking
# that the produced literal has any integer datatype
p = Pattern.from_typehint(ops.Value[dt.Integer, ds.Any])
assert p.match(1, {}) == one
assert match(ops.Value[dt.Integer], 1) == one

# same applies here, the coercion itself will use only the inferred datatype
# but then the result is checked against the given typehint
p = Pattern.from_typehint(ops.Value[dt.Int8 | dt.Int16, ds.Any])
assert p.match(1, {}) == one
assert p.match(128, {}) == ops.Literal(128, dt.int16)
assert match(ops.Value[dt.Int8 | dt.Int16], 1) == one
assert match(ops.Value[dt.Int8 | dt.Int16], 128) == ops.Literal(128, dt.int16)
assert match(ops.Value[dt.Int8], 128) is NoMatch

p1 = Pattern.from_typehint(ops.Value[dt.Int8, ds.Any])
p2 = Pattern.from_typehint(ops.Value[dt.Int16, ds.Scalar])
assert p1.match(1, {}) == one
# this is actually supported by creating an explicit dtype
# in Value.__coerce__ based on the `T` keyword argument
assert p2.match(1, {}) == ops.Literal(1, dt.int16)
assert p2.match(128, {}) == ops.Literal(128, dt.int16)
assert match(ops.Value[dt.Int16, ds.Scalar], 1) == ops.Literal(1, dt.int16)
assert match(ops.Value[dt.Int16, ds.Scalar], 128) == ops.Literal(128, dt.int16)

p = p1 | p2
assert p.match(1, {}) == one
# equivalent with ops.Value[dt.Int8 | dt.Int16]
assert match(Union[ops.Value[dt.Int8], ops.Value[dt.Int16]], 1) == one


@pytest.mark.pandas
def test_coerced_to_interval_value():
import pandas as pd

p = Pattern.from_typehint(ops.Value[dt.Interval, ds.Any])

value = pd.Timedelta("1s")
result = p.match(value, {})
assert result.value == 1
assert result.dtype == dt.Interval("s")
expected = ops.Literal(1, dt.Interval("s"))
assert match(ops.Value[dt.Interval], pd.Timedelta("1s")) == expected

value = pd.Timedelta("1h 1m 1s")
result = p.match(value, {})
assert result.value == 3661
assert result.dtype == dt.Interval("s")
expected = ops.Literal(3661, dt.Interval("s"))
assert match(ops.Value[dt.Interval], pd.Timedelta("1h 1m 1s")) == expected

0 comments on commit d87e65a

Please sign in to comment.