Skip to content

Commit

Permalink
feat(common): add abstract mapping collection with support for set op…
Browse files Browse the repository at this point in the history
…erations
  • Loading branch information
kszucs authored and cpcloud committed Feb 27, 2023
1 parent d6340b6 commit 7d4aa0f
Show file tree
Hide file tree
Showing 12 changed files with 291 additions and 24 deletions.
3 changes: 3 additions & 0 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,9 @@ def __getstate__(self):
_con_kwargs=self._con_kwargs,
)

def __rich_repr__(self):
yield "name", self.name

def __hash__(self):
return hash(self.db_identity)

Expand Down
73 changes: 73 additions & 0 deletions ibis/common/collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations

from collections.abc import Mapping

from typing_extensions import Self


class MapSet(Mapping):
    """Abstract mapping that additionally supports set-like operations.

    Subclasses supply the ``Mapping`` primitives (``__getitem__``,
    ``__iter__`` and ``__len__``) and must be constructible from a single
    ``dict`` of items, because every operator below builds its result with
    ``self.__class__(<dict>)``.

    Keys play the role of set elements. Whenever both operands contain the
    same key, the two associated values must compare equal; otherwise the
    operation raises ``ValueError`` instead of silently picking one side.
    """

    def _check_conflict(self, other):
        """Return the set of keys shared with ``other`` after validating them.

        Raises
        ------
        ValueError
            If any shared key maps to different values in the two operands.
        """
        common_keys = self.keys() & other.keys()
        for key in common_keys:
            left, right = self[key], other[key]
            if left != right:
                raise ValueError(
                    f"Conflicting values for key `{key}`: {left} != {right}"
                )
        return common_keys

    def __ge__(self, other: Self) -> bool:
        """Return whether every key of ``other`` is also present in ``self``.

        Raises ``ValueError`` if a shared key has conflicting values.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        common_keys = self._check_conflict(other)
        return other.keys() == common_keys

    def __gt__(self, other: Self) -> bool:
        """Return whether ``self`` is a strict superset of ``other``.

        The length test short-circuits, so conflicting values are only
        detected (and ``ValueError`` raised) when ``self`` is longer.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        return len(self) > len(other) and self.__ge__(other)

    def __le__(self, other: Self) -> bool:
        """Return whether every key of ``self`` is also present in ``other``.

        Raises ``ValueError`` if a shared key has conflicting values.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        common_keys = self._check_conflict(other)
        return self.keys() == common_keys

    def __lt__(self, other: Self) -> bool:
        """Return whether ``self`` is a strict subset of ``other``.

        The length test short-circuits, so conflicting values are only
        detected (and ``ValueError`` raised) when ``self`` is shorter.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        return len(self) < len(other) and self.__le__(other)

    def __and__(self, other: Self) -> Self:
        """Return the items whose keys appear in both operands.

        Items keep the iteration order of ``self``. Raises ``ValueError`` if
        a shared key has conflicting values.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        common_keys = self._check_conflict(other)
        intersection = {k: v for k, v in self.items() if k in common_keys}
        return self.__class__(intersection)

    def __sub__(self, other: Self) -> Self:
        """Return the items of ``self`` whose keys are absent from ``other``.

        Raises ``ValueError`` if a shared key has conflicting values.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        common_keys = self._check_conflict(other)
        difference = {k: v for k, v in self.items() if k not in common_keys}
        return self.__class__(difference)

    def __or__(self, other: Self) -> Self:
        """Return the items of both operands merged together.

        Items of ``self`` come first, followed by the new keys of ``other``.
        Raises ``ValueError`` if a shared key has conflicting values.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        self._check_conflict(other)
        union = {**self, **other}
        return self.__class__(union)

    def __xor__(self, other: Self) -> Self:
        """Return the items whose keys appear in exactly one operand.

        Items exclusive to ``self`` come first, followed by the items
        exclusive to ``other``. Raises ``ValueError`` if a shared key has
        conflicting values.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        # Built directly from the shared keys rather than as
        # `(self - other)` merged with `(other - self)`: when `other` is an
        # instance of a subclass, `other - self` returns NotImplemented (and
        # thus raises TypeError) even though the check above accepted it.
        # This also validates the shared keys once instead of three times.
        common_keys = self._check_conflict(other)
        result = {k: v for k, v in self.items() if k not in common_keys}
        result.update((k, v) for k, v in other.items() if k not in common_keys)
        return self.__class__(result)

    def isdisjoint(self, other: Self) -> bool:
        """Return whether the operands have no key in common.

        Raises ``ValueError`` if a shared key has conflicting values.
        """
        common_keys = self._check_conflict(other)
        return not common_keys
130 changes: 130 additions & 0 deletions ibis/common/tests/test_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from collections.abc import ItemsView, Iterator, KeysView, ValuesView

import pytest

from ibis.common.collections import MapSet


class MySchema(MapSet):
    """Minimal concrete ``MapSet`` backed by a plain ``dict``, for the tests."""

    def __init__(self, dct=None, **kwargs):
        # A falsy `dct` (None or an empty mapping) falls back to the keyword
        # arguments; the chosen source is always copied into a fresh dict.
        source = dct or kwargs
        self._fields = dict(source)

    def __repr__(self):
        return f'{self.__class__.__name__}({self._fields})'

    def __getitem__(self, key):
        return self._fields[key]

    def __iter__(self):
        return iter(self._fields)

    def __len__(self):
        return len(self._fields)

    def identical(self, other):
        """Return whether ``other`` has the same exact type and the same items in the same order."""
        if type(self) != type(other):
            return False
        own_items = tuple(self.items())
        return own_items == tuple(other.items())


def test_myschema_identical():
    """``identical`` demands the same type, the same items and the same order."""
    reference = MySchema(a=1, b=2)
    twin = MySchema(a=1, b=2)
    reordered = MySchema(b=2, a=1)
    extended = MySchema(a=1, b=2, c=3)
    plain_dict = {}

    assert reference.identical(twin)
    # Different key order, an extra key, or a different type all break identity.
    for candidate in (reordered, extended, plain_dict):
        assert not reference.identical(candidate)


def test_mapset_mapping_api():
    """A ``MapSet`` subclass exposes the regular read-only ``Mapping`` API."""
    schema = MySchema(a=1, b=2)

    # item access and size
    assert schema['a'] == 1
    assert schema['b'] == 2
    assert len(schema) == 2

    # iteration yields the keys in insertion order
    assert isinstance(iter(schema), Iterator)
    assert list(schema) == ['a', 'b']

    # the three standard views
    view_cases = (
        (schema.keys(), KeysView, ['a', 'b']),
        (schema.values(), ValuesView, [1, 2]),
        (schema.items(), ItemsView, [('a', 1), ('b', 2)]),
    )
    for view, view_type, expected in view_cases:
        assert isinstance(view, view_type)
        assert list(view) == expected

    # lookups with and without a default
    assert schema.get('a') == 1
    assert schema.get('c') is None
    assert schema.get('c', 3) == 3

    # membership
    assert 'a' in schema
    assert 'c' not in schema

    # equality
    assert schema == schema
    assert schema != MySchema(a=1, b=2, c=3)


def test_mapset_set_api():
    """Exercise the set-like operators of ``MapSet`` through ``MySchema``."""

    def conflicting():
        # Fresh context manager expecting the value-conflict error.
        return pytest.raises(ValueError, match="Conflicting values")

    ab = MySchema(a=1, b=2)
    ab_conflict = MySchema(a=1, b=-2)
    abc = MySchema(a=1, b=2, c=3)
    abc_conflict = MySchema(a=1, b=2, c=-3)
    de = MySchema(d=4, e=5)

    # disjoint
    assert not ab.isdisjoint(abc)
    assert ab.isdisjoint(de)

    # __eq__, __ne__
    for schema in (ab, abc):
        assert schema == schema
    assert ab != ab_conflict
    assert abc != abc_conflict

    # __le__, __lt__
    assert ab < abc
    assert ab <= abc
    assert ab <= ab
    assert not abc <= ab
    assert not abc < ab
    with conflicting():
        # duplicate keys with different values
        ab <= ab_conflict  # noqa: B015

    # __gt__, __ge__
    assert abc > ab
    assert abc >= ab
    assert ab >= ab
    assert not ab >= abc
    assert not ab > abc
    assert not ab_conflict > ab
    with conflicting():
        ab_conflict >= ab  # noqa: B015

    # __and__
    with conflicting():
        ab & ab_conflict
    with conflicting():
        abc & abc_conflict
    assert (ab & abc).identical(ab)
    assert (ab & de).identical(MySchema())

    # __or__
    assert (ab | ab).identical(ab)
    assert (ab | abc).identical(abc)
    assert (ab | de).identical(MySchema(a=1, b=2, d=4, e=5))
    with conflicting():
        ab | ab_conflict

    # __sub__
    with conflicting():
        ab - ab_conflict
    assert (ab - abc).identical(MySchema())
    assert (abc - ab).identical(MySchema(c=3))
    assert (ab - de).identical(ab)
    assert (de - ab).identical(de)

    # __xor__
    with conflicting():
        ab ^ ab_conflict

    only_c = MySchema(c=3)
    assert (ab ^ abc).identical(only_c)
    assert (abc ^ ab).identical(only_c)
    # The left operand's exclusive items come first in the result.
    assert (ab ^ de).identical(MySchema(a=1, b=2, d=4, e=5))
    assert (de ^ ab).identical(MySchema(d=4, e=5, a=1, b=2))
3 changes: 2 additions & 1 deletion ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import ibis.expr.types as ir
from ibis.common.annotations import attribute, optional
from ibis.common.collections import MapSet
from ibis.common.grounds import Concrete, Singleton
from ibis.common.validators import (
all_of,
Expand Down Expand Up @@ -675,7 +676,7 @@ def _pretty_piece(self) -> str:


@public
class Struct(Parametric, Mapping):
class Struct(Parametric, MapSet):
"""Structured values."""

fields = frozendict_of(instance_of(str), datatype)
Expand Down
31 changes: 31 additions & 0 deletions ibis/expr/datatypes/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,37 @@ def test_struct_mapping_api():
s['e'] = dt.int8


def test_struct_set_operations():
    """``dt.Struct`` supports the set operators and subset/superset comparisons."""
    abc = dt.Struct({'a': dt.string, 'b': dt.int64, 'c': dt.float64})
    acde = dt.Struct(
        {'a': dt.string, 'c': dt.float64, 'd': dt.boolean, 'e': dt.date}
    )
    ijk = dt.Struct({'i': dt.int64, 'j': dt.float64, 'k': dt.string})
    ijkl = dt.Struct(
        {'i': dt.int64, 'j': dt.float64, 'k': dt.string, 'l': dt.boolean}
    )

    # intersection, union, difference and symmetric difference
    assert (abc & acde) == dt.Struct({'a': dt.string, 'c': dt.float64})
    assert (abc | acde) == dt.Struct(
        {'a': dt.string, 'b': dt.int64, 'c': dt.float64, 'd': dt.boolean, 'e': dt.date}
    )
    assert (abc - acde) == dt.Struct({'b': dt.int64})
    assert (acde - abc) == dt.Struct({'d': dt.boolean, 'e': dt.date})
    assert (abc ^ acde) == dt.Struct(
        {'b': dt.int64, 'd': dt.boolean, 'e': dt.date}
    )

    # disjointness
    assert not abc.isdisjoint(acde)
    assert abc.isdisjoint(ijk)

    # a struct is a non-strict subset and superset of itself
    assert abc <= abc
    assert abc >= abc
    assert not abc < abc
    assert not abc > abc

    # overlapping or unrelated structs are ordered in neither direction
    assert not abc <= acde
    assert not abc >= acde
    assert not abc >= ijk
    assert not abc <= ijk

    # proper subset / superset
    assert ijk <= ijkl
    assert ijk < ijkl
    assert ijkl >= ijk
    assert ijkl > ijk


def test_singleton_null():
assert dt.null is dt.Null()

Expand Down
3 changes: 3 additions & 0 deletions ibis/expr/operations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def to_expr(self):
# Avoid custom repr for performance reasons
__repr__ = object.__repr__

def __rich_repr__(self):
return zip(self.__argnames__, self.__args__)


@public
class Named(ABC):
Expand Down
15 changes: 7 additions & 8 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,6 @@ class InMemoryTable(PhysicalTable):
def data(self) -> util.ToFrame:
"""Return the data of an in-memory table."""

def has_resolved_name(self):
return True

def resolve_name(self):
return self.name


# TODO(kszucs): desperately need to clean this up, the majority of this
# functionality should be handled by input rules for the Join class
Expand Down Expand Up @@ -176,8 +170,13 @@ def __init__(self, left, right, predicates, **kwargs):

@property
def schema(self):
# For joins retaining both table schemas, merge them together here
return self.left.schema.merge(self.right.schema)
# TODO(kszucs): use `return self.left.schema | self.right.schema` instead, which
# eliminates the unnecessary projection over the join, but currently breaks the
# pandas backend
left, right = self.left.schema, self.right.schema
if duplicates := left.keys() & right.keys():
raise com.IntegrityError(f'Duplicate column name(s): {duplicates}')
return sch.Schema({**left, **right})


@public
Expand Down
3 changes: 0 additions & 3 deletions ibis/expr/operations/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ def following(self) -> bool:
return not self.preceding


# perhaps have separate window frames, RowsWindowFrame and RangeWindowFrame


@public
class WindowFrame(Value):
"""A window frame operation bound to a table."""
Expand Down
18 changes: 8 additions & 10 deletions ibis/expr/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

import ibis.expr.datatypes as dt
from ibis.common.annotations import attribute
from ibis.common.collections import MapSet
from ibis.common.exceptions import IntegrityError
from ibis.common.grounds import Concrete
from ibis.common.validators import Coercible, frozendict_of, instance_of, validator
from ibis.util import indent
from ibis.util import deprecated, indent

if TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -43,7 +44,7 @@ def datatype(arg, **kwargs):
return dt.dtype(arg)


class Schema(Concrete, Mapping, Coercible):
class Schema(Concrete, Coercible, MapSet):
"""An object for holding table schema information."""

fields = frozendict_of(instance_of(str), datatype)
Expand All @@ -61,6 +62,10 @@ def __repr__(self) -> str:
)
)

def __rich_repr__(self):
for name, dtype in self.items():
yield name, str(dtype)

def __len__(self) -> int:
return len(self.fields)

Expand Down Expand Up @@ -152,14 +157,7 @@ def to_pyarrow(self):
def as_struct(self) -> dt.Struct:
return dt.Struct(self)

def __gt__(self, other: Schema) -> bool:
"""Return whether `self` is a strict superset of `other`."""
return set(self.items()) > set(other.items())

def __ge__(self, other: Schema) -> bool:
"""Return whether `self` is a superset of or equal to `other`."""
return set(self.items()) >= set(other.items())

@deprecated(as_of="5.0", removed_in="6.0", instead="use union operator instead")
def merge(self, other: Schema) -> Schema:
"""Merge `other` to `self`.
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/tests/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def test_table_with_schema(table):
)
def test_table_with_schema_invalid(table):
validator = rlz.table(schema=[('group', dt.double), ('value', dt.timestamp)])
with pytest.raises(IbisTypeError):
with pytest.raises(ValueError):
validator(table)


Expand Down
Loading

0 comments on commit 7d4aa0f

Please sign in to comment.