Skip to content

Commit

Permalink
refactor(ir): make schema annotable
Browse files (browse the repository at this point in the history)
BREAKING CHANGE: the Schema.names and Schema.types attributes are now tuples rather than lists
  • Loading branch information
kszucs committed Mar 29, 2022
1 parent c7a69cd commit b980903
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 69 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_get_schema(con):
def test_result_as_dataframe(con, alltypes):
expr = alltypes.limit(10)

ex_names = expr.schema().names
ex_names = list(expr.schema().names)
result = con.execute(expr)

assert isinstance(result, pd.DataFrame)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/dask/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def execute_and_reset(
if isinstance(result, dd.DataFrame):
schema = expr.schema()
df = result.reset_index()
return df[schema.names]
return df[list(schema.names)]
elif isinstance(result, dd.Series):
return result.reset_index(drop=True)
return result
2 changes: 1 addition & 1 deletion ibis/backends/impala/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_get_schema(con, test_data_db):
def test_result_as_dataframe(con, alltypes):
expr = alltypes.limit(10)

ex_names = expr.schema().names
ex_names = list(expr.schema().names)
result = con.execute(expr)

assert isinstance(result, pd.DataFrame)
Expand Down
77 changes: 35 additions & 42 deletions ibis/expr/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import abc
import collections
from typing import Sequence

from multipledispatch import Dispatcher

from ..common.exceptions import IntegrityError
from ..common.grounds import Comparable
from ..common.grounds import Annotable, Comparable
from ..common.validators import instance_of, tuple_of, validator
from ..expr import datatypes as dt
from ..util import indent
from ..util import UnnamedMarker, indent

convert = Dispatcher(
'convert',
Expand All @@ -33,7 +35,12 @@
)


class Schema(Comparable):
@validator
def datatype(arg, **kwargs):
return dt.dtype(arg)


class Schema(Annotable, Comparable):

"""An object for holding table schema information, i.e., column names and
types.
Expand All @@ -47,25 +54,26 @@ class Schema(Comparable):
representing type of each column.
"""

__slots__ = 'names', 'types', '_name_locs', '_hash'
__slots__ = ('_name_locs',)

def __init__(self, names, types):
if not isinstance(names, list):
names = list(names)
names: Sequence[str] = tuple_of(instance_of((str, UnnamedMarker)))
types: Sequence[dt.DataType] = tuple_of(datatype)

self.names = names
self.types = list(map(dt.dtype, types))
def __post_init__(self):
super().__post_init__()

self._name_locs = {v: i for i, v in enumerate(self.names)}

if len(self._name_locs) < len(self.names):
# validate unique field names
name_locs = {v: i for i, v in enumerate(self.names)}
if len(name_locs) < len(self.names):
duplicate_names = list(self.names)
for v in self._name_locs.keys():
for v in name_locs.keys():
duplicate_names.remove(v)
raise IntegrityError(
f'Duplicate column name(s): {duplicate_names}'
)
self._hash = None

# store field positions
object.__setattr__(self, '_name_locs', name_locs)

def __repr__(self):
space = 2 + max(map(len, self.names), default=0)
Expand All @@ -79,14 +87,6 @@ def __repr__(self):
)
)

def _make_hash(self) -> int:
return hash((type(self), tuple(self.names), tuple(self.types)))

def __hash__(self) -> int:
if (result := self._hash) is None:
result = self._hash = self._make_hash()
return result

def __len__(self):
return len(self.names)

Expand All @@ -99,12 +99,20 @@ def __contains__(self, name):
def __getitem__(self, name):
return self.types[self._name_locs[name]]

def __getstate__(self):
return {slot: getattr(self, slot) for slot in self.__class__.__slots__}
def __equals__(self, other):
return (
self._hash == other._hash
and self.names == other.names
and self.types == other.types
)

def __setstate__(self, instance_dict):
for key, value in instance_dict.items():
setattr(self, key, value)
def equals(self, other):
if not isinstance(other, Schema):
raise TypeError(
"invalid equality comparison between Schema and "
f"{type(other)}"
)
return self.__cached_equals__(other)

def delete(self, names_to_delete):
for name in names_to_delete:
Expand Down Expand Up @@ -133,21 +141,6 @@ def from_dict(cls, dictionary):
names, types = zip(*dictionary.items()) if dictionary else ([], [])
return Schema(names, types)

def __equals__(self, other: Schema) -> bool:
return (
self.names == other.names
and len(self.types) == len(other.types)
and all(a.equals(b) for a, b in zip(self.types, other.types))
)

def equals(self, other):
if not isinstance(other, Schema):
raise TypeError(
"invalid equality comparison between Schema and "
f"{type(other)}"
)
return self.__cached_equals__(other)

def __gt__(self, other):
return set(self.items()) > set(other.items())

Expand Down
33 changes: 16 additions & 17 deletions ibis/expr/types/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
from cached_property import cached_property
from public import public

import ibis
import ibis.common.exceptions as com
import ibis.config as config
import ibis.util as util
from ibis.config import options
from ibis.expr.typing import TimeContext
from ... import config
from ...common.exceptions import (
ExpressionError,
IbisError,
IbisTypeError,
TranslationError,
)
from ...expr.typing import TimeContext
from ...util import UnnamedMarker, deprecated

if TYPE_CHECKING:
from ...backends.base import BaseBackend
Expand All @@ -35,7 +38,7 @@ def __repr__(self) -> str:

try:
result = self.execute()
except com.TranslationError as e:
except TranslationError as e:
lines = [
"Translation to backend failed",
f"Error message: {e.args[0]}",
Expand Down Expand Up @@ -79,7 +82,7 @@ def _safe_name(self) -> str | None:
"""
try:
return self.get_name()
except (com.ExpressionError, AttributeError):
except (ExpressionError, AttributeError):
return None

@property
Expand All @@ -94,7 +97,7 @@ def _key(self) -> tuple[Hashable, ...]:
return type(self), self._safe_name, self.op()

def _repr_png_(self) -> bytes | None:
if config.options.interactive or not ibis.options.graphviz_repr:
if config.options.interactive or not config.options.graphviz_repr:
return None
try:
import ibis.expr.visualize as viz
Expand Down Expand Up @@ -220,9 +223,9 @@ def _find_backend(self) -> BaseBackend:
backends = self._find_backends()

if not backends:
default = options.default_backend
default = config.options.default_backend
if default is None:
raise com.IbisError(
raise IbisError(
'Expression depends on no backends, and found no default'
)
return default
Expand Down Expand Up @@ -289,7 +292,7 @@ def compile(
self, limit=limit, timecontext=timecontext, params=params
)

@util.deprecated(
@deprecated(
version='2.0',
instead=(
"call [`Expr.compile`][ibis.expr.types.core.Expr.compile] and "
Expand All @@ -306,10 +309,6 @@ def verify(self):
return True


class UnnamedMarker:
pass


unnamed = UnnamedMarker()


Expand Down Expand Up @@ -346,7 +345,7 @@ def _binop(
"""
try:
node = op_class(left, right)
except (com.IbisTypeError, NotImplementedError):
except (IbisTypeError, NotImplementedError):
return NotImplemented
else:
return node.to_expr()
2 changes: 1 addition & 1 deletion ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def get_column(self, name: str) -> ColumnExpr:

@cached_property
def columns(self):
return self.schema().names
return list(self.schema().names)

def schema(self) -> sch.Schema:
"""Return the table's schema.
Expand Down
12 changes: 6 additions & 6 deletions ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_projection(table):
assert isinstance(proj, TableExpr)
assert isinstance(proj.op(), ops.Selection)

assert proj.schema().names == cols
assert proj.schema().names == tuple(cols)
for c in cols:
expr = proj[c]
assert isinstance(expr, type(table[c]))
Expand All @@ -133,8 +133,8 @@ def test_projection_with_exprs(table):

proj = table[col_exprs + ['g']]
schema = proj.schema()
assert schema.names == ['log_b', 'mean_diff', 'g']
assert schema.types == [dt.double, dt.double, dt.string]
assert schema.names == ('log_b', 'mean_diff', 'g')
assert schema.types == (dt.double, dt.double, dt.string)

# Test with unnamed expr
with pytest.raises(ExpressionError):
Expand Down Expand Up @@ -178,7 +178,7 @@ def test_projection_with_star_expr(table):
proj = t[t, new_expr]
repr(proj)

ex_names = table.schema().names + ['bigger_a']
ex_names = table.schema().names + ('bigger_a',)
assert proj.schema().names == ex_names

# cannot pass an invalid table expression
Expand Down Expand Up @@ -918,10 +918,10 @@ def test_join_project_after(table):

joined = table1.left_join(table2, [pred])
projected = joined.projection([table1, table2['stuff']])
assert projected.schema().names == ['key1', 'value1', 'stuff']
assert projected.schema().names == ('key1', 'value1', 'stuff')

projected = joined.projection([table2, table1['key1']])
assert projected.schema().names == ['key2', 'stuff', 'key1']
assert projected.schema().names == ('key2', 'stuff', 'key1')


def test_semi_join_schema(table):
Expand Down
4 changes: 4 additions & 0 deletions ibis/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def __hash__(self):
return self._hash


class UnnamedMarker:
pass


def guid() -> str:
"""Return a uuid4 hexadecimal value.
Expand Down

0 comments on commit b980903

Please sign in to comment.