diff --git a/ibis/backends/clickhouse/tests/test_client.py b/ibis/backends/clickhouse/tests/test_client.py index 830aeedd1f0b..5bab2d078695 100644 --- a/ibis/backends/clickhouse/tests/test_client.py +++ b/ibis/backends/clickhouse/tests/test_client.py @@ -31,7 +31,7 @@ def test_get_schema(con): def test_result_as_dataframe(con, alltypes): expr = alltypes.limit(10) - ex_names = expr.schema().names + ex_names = list(expr.schema().names) result = con.execute(expr) assert isinstance(result, pd.DataFrame) diff --git a/ibis/backends/dask/core.py b/ibis/backends/dask/core.py index 99b35807a9da..b9f00582d0cc 100644 --- a/ibis/backends/dask/core.py +++ b/ibis/backends/dask/core.py @@ -460,7 +460,7 @@ def execute_and_reset( if isinstance(result, dd.DataFrame): schema = expr.schema() df = result.reset_index() - return df[schema.names] + return df[list(schema.names)] elif isinstance(result, dd.Series): return result.reset_index(drop=True) return result diff --git a/ibis/backends/impala/tests/test_client.py b/ibis/backends/impala/tests/test_client.py index 782c9b1548f0..95b53aa8cc59 100644 --- a/ibis/backends/impala/tests/test_client.py +++ b/ibis/backends/impala/tests/test_client.py @@ -94,7 +94,7 @@ def test_get_schema(con, test_data_db): def test_result_as_dataframe(con, alltypes): expr = alltypes.limit(10) - ex_names = expr.schema().names + ex_names = list(expr.schema().names) result = con.execute(expr) assert isinstance(result, pd.DataFrame) diff --git a/ibis/expr/schema.py b/ibis/expr/schema.py index 93ed9a1e6d96..1fa7dddcfef1 100644 --- a/ibis/expr/schema.py +++ b/ibis/expr/schema.py @@ -2,13 +2,15 @@ import abc import collections +from typing import Sequence from multipledispatch import Dispatcher from ..common.exceptions import IntegrityError -from ..common.grounds import Comparable +from ..common.grounds import Annotable, Comparable +from ..common.validators import instance_of, tuple_of, validator from ..expr import datatypes as dt -from ..util import indent +from ..util import UnnamedMarker, indent convert = Dispatcher( 'convert', @@ -33,7 +35,12 @@ ) -class Schema(Comparable): +@validator +def datatype(arg, **kwargs): + return dt.dtype(arg) + + +class Schema(Annotable, Comparable): """An object for holding table schema information, i.e., column names and types. @@ -47,25 +54,26 @@ class Schema(Comparable): representing type of each column. """ - __slots__ = 'names', 'types', '_name_locs', '_hash' + __slots__ = ('_name_locs',) - def __init__(self, names, types): - if not isinstance(names, list): - names = list(names) + names: Sequence[str] = tuple_of(instance_of((str, UnnamedMarker))) + types: Sequence[dt.DataType] = tuple_of(datatype) - self.names = names - self.types = list(map(dt.dtype, types)) + def __post_init__(self): + super().__post_init__() - self._name_locs = {v: i for i, v in enumerate(self.names)} - - if len(self._name_locs) < len(self.names): + # validate unique field names + name_locs = {v: i for i, v in enumerate(self.names)} + if len(name_locs) < len(self.names): duplicate_names = list(self.names) - for v in self._name_locs.keys(): + for v in name_locs.keys(): duplicate_names.remove(v) raise IntegrityError( f'Duplicate column name(s): {duplicate_names}' ) - self._hash = None + + # store field positions + object.__setattr__(self, '_name_locs', name_locs) def __repr__(self): space = 2 + max(map(len, self.names), default=0) @@ -79,14 +87,6 @@ def __repr__(self): ) ) - def _make_hash(self) -> int: - return hash((type(self), tuple(self.names), tuple(self.types))) - - def __hash__(self) -> int: - if (result := self._hash) is None: - result = self._hash = self._make_hash() - return result - def __len__(self): return len(self.names) @@ -99,12 +99,20 @@ def __contains__(self, name): def __getitem__(self, name): return self.types[self._name_locs[name]] - def __getstate__(self): - return {slot: getattr(self, slot) for slot in self.__class__.__slots__} + def __equals__(self, other): + return ( + self._hash == other._hash + and self.names == other.names + and self.types == other.types + ) - def __setstate__(self, instance_dict): - for key, value in instance_dict.items(): - setattr(self, key, value) + def equals(self, other): + if not isinstance(other, Schema): + raise TypeError( + "invalid equality comparison between Schema and " + f"{type(other)}" + ) + return self.__cached_equals__(other) def delete(self, names_to_delete): for name in names_to_delete: @@ -133,21 +141,6 @@ def from_dict(cls, dictionary): names, types = zip(*dictionary.items()) if dictionary else ([], []) return Schema(names, types) - def __equals__(self, other: Schema) -> bool: - return ( - self.names == other.names - and len(self.types) == len(other.types) - and all(a.equals(b) for a, b in zip(self.types, other.types)) - ) - - def equals(self, other): - if not isinstance(other, Schema): - raise TypeError( - "invalid equality comparison between Schema and " - f"{type(other)}" - ) - return self.__cached_equals__(other) - def __gt__(self, other): return set(self.items()) > set(other.items()) diff --git a/ibis/expr/types/core.py b/ibis/expr/types/core.py index f92fb5f57b92..b30d857171b4 100644 --- a/ibis/expr/types/core.py +++ b/ibis/expr/types/core.py @@ -7,12 +7,15 @@ from cached_property import cached_property from public import public -import ibis -import ibis.common.exceptions as com -import ibis.config as config -import ibis.util as util -from ibis.config import options -from ibis.expr.typing import TimeContext +from ... import config +from ...common.exceptions import ( + ExpressionError, + IbisError, + IbisTypeError, + TranslationError, +) +from ...expr.typing import TimeContext +from ...util import UnnamedMarker, deprecated if TYPE_CHECKING: from ...backends.base import BaseBackend @@ -35,7 +38,7 @@ def __repr__(self) -> str: try: result = self.execute() - except com.TranslationError as e: + except TranslationError as e: lines = [ "Translation to backend failed", f"Error message: {e.args[0]}", @@ -79,7 +82,7 @@ def _safe_name(self) -> str | None: """ try: return self.get_name() - except (com.ExpressionError, AttributeError): + except (ExpressionError, AttributeError): return None @property @@ -94,7 +97,7 @@ def _key(self) -> tuple[Hashable, ...]: return type(self), self._safe_name, self.op() def _repr_png_(self) -> bytes | None: - if config.options.interactive or not ibis.options.graphviz_repr: + if config.options.interactive or not config.options.graphviz_repr: return None try: import ibis.expr.visualize as viz @@ -220,9 +223,9 @@ def _find_backend(self) -> BaseBackend: backends = self._find_backends() if not backends: - default = options.default_backend + default = config.options.default_backend if default is None: - raise com.IbisError( + raise IbisError( 'Expression depends on no backends, and found no default' ) return default @@ -289,7 +292,7 @@ def compile( self, limit=limit, timecontext=timecontext, params=params ) - @util.deprecated( + @deprecated( version='2.0', instead=( "call [`Expr.compile`][ibis.expr.types.core.Expr.compile] and " @@ -306,10 +309,6 @@ def verify(self): return True -class UnnamedMarker: - pass - - unnamed = UnnamedMarker() @@ -346,7 +345,7 @@ def _binop( """ try: node = op_class(left, right) - except (com.IbisTypeError, NotImplementedError): + except (IbisTypeError, NotImplementedError): return NotImplemented else: return node.to_expr() diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index a1e1e8025195..b2c7ed8bc5dd 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -222,7 +222,7 @@ def get_column(self, name: str) -> ColumnExpr: @cached_property def columns(self): - return self.schema().names + return list(self.schema().names) def schema(self) -> sch.Schema: """Return the table's schema. diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index dc5b8b826a12..e4fc81661c13 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -112,7 +112,7 @@ def test_projection(table): assert isinstance(proj, TableExpr) assert isinstance(proj.op(), ops.Selection) - assert proj.schema().names == cols + assert proj.schema().names == tuple(cols) for c in cols: expr = proj[c] assert isinstance(expr, type(table[c])) @@ -133,8 +133,8 @@ def test_projection_with_exprs(table): proj = table[col_exprs + ['g']] schema = proj.schema() - assert schema.names == ['log_b', 'mean_diff', 'g'] - assert schema.types == [dt.double, dt.double, dt.string] + assert schema.names == ('log_b', 'mean_diff', 'g') + assert schema.types == (dt.double, dt.double, dt.string) # Test with unnamed expr with pytest.raises(ExpressionError): @@ -178,7 +178,7 @@ def test_projection_with_star_expr(table): proj = t[t, new_expr] repr(proj) - ex_names = table.schema().names + ['bigger_a'] + ex_names = table.schema().names + ('bigger_a',) assert proj.schema().names == ex_names # cannot pass an invalid table expression @@ -918,10 +918,10 @@ def test_join_project_after(table): joined = table1.left_join(table2, [pred]) projected = joined.projection([table1, table2['stuff']]) - assert projected.schema().names == ['key1', 'value1', 'stuff'] + assert projected.schema().names == ('key1', 'value1', 'stuff') projected = joined.projection([table2, table1['key1']]) - assert projected.schema().names == ['key2', 'stuff', 'key1'] + assert projected.schema().names == ('key2', 'stuff', 'key1') def test_semi_join_schema(table): diff --git a/ibis/util.py b/ibis/util.py index 7f4642675f68..c716bbe8f8cd 100644 --- a/ibis/util.py +++ b/ibis/util.py @@ -70,6 +70,10 @@ def __hash__(self): return self._hash +class UnnamedMarker: + pass + + def guid() -> str: """Return a uuid4 hexadecimal value.