Skip to content

Commit

Permalink
refactor(common): enforce slots definitions for Base subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Oct 3, 2022
1 parent e4ab127 commit 6c3df91
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 41 deletions.
9 changes: 5 additions & 4 deletions ibis/common/grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ class BaseMeta(ABCMeta):

__slots__ = ()

def __new__(metacls, clsname, bases, dct, **kwargs):
# enforce slot definitions
dct.setdefault("__slots__", ())
return super().__new__(metacls, clsname, bases, dct, **kwargs)

def __call__(cls, *args, **kwargs):
return cls.__create__(*args, **kwargs)

Expand Down Expand Up @@ -147,8 +152,6 @@ def copy(self, **overrides):


class Immutable(Base):
__slots__ = ()

def __setattr__(self, name: str, _: Any) -> None:
raise TypeError(
f"Attribute {name!r} cannot be assigned to immutable instance of "
Expand All @@ -158,7 +161,6 @@ def __setattr__(self, name: str, _: Any) -> None:

class Singleton(Base):

__slots__ = ()
__instances__ = WeakValueDictionary()

@classmethod
Expand All @@ -174,7 +176,6 @@ def __create__(cls, *args, **kwargs):

class Comparable(Base):

__slots__ = ()
__cache__ = WeakCache()

def __eq__(self, other):
Expand Down
2 changes: 2 additions & 0 deletions ibis/common/tests/test_grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ class Node(Comparable):

# override the default cache object
__cache__ = WeakCache()
__slots__ = ('name',)
num_equal_calls = 0

def __init__(self, name):
Expand Down Expand Up @@ -637,6 +638,7 @@ class OneAndOnly(Singleton):


class DataType(Singleton):
__slots__ = ('nullable',)
__instances__ = weakref.WeakValueDictionary()

def __init__(self, nullable=True):
Expand Down
5 changes: 2 additions & 3 deletions ibis/expr/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ def fmt(expr: ir.Expr) -> str:
formatted = fmt_table_op(node, aliases=aliases, deps=deps)
pieces.append(f"{alias} := {formatted}")

pieces.append(
fmt_root(root, name=expr._safe_name, aliases=aliases, deps=deps)
)
name = expr.get_name() if expr.has_name() else None
pieces.append(fmt_root(root, name=name, aliases=aliases, deps=deps))
depth = ibis.options.repr.depth or 0
if depth and depth < len(pieces):
return fmt_truncated(pieces, depth=depth)
Expand Down
27 changes: 3 additions & 24 deletions ibis/expr/types/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,14 @@

import os
import webbrowser
from functools import cached_property
from typing import TYPE_CHECKING, Any, Mapping

import toolz
from public import public

import ibis.common.graph as g
import ibis.expr.operations as ops
from ibis.common.exceptions import (
ExpressionError,
IbisError,
IbisTypeError,
TranslationError,
)
from ibis.common.exceptions import IbisError, IbisTypeError, TranslationError
from ibis.common.grounds import Immutable
from ibis.common.pretty import console
from ibis.config import _default_backend, options
Expand Down Expand Up @@ -81,25 +75,10 @@ def __bool__(self) -> bool:
__nonzero__ = __bool__

def has_name(self):
return isinstance(self.op(), ops.Named)
return isinstance(self._arg, ops.Named)

def get_name(self):
return self.op().name

# TODO(kszucs): remove it entirely
@cached_property
def _safe_name(self) -> str | None:
"""Get the name of an expression `expr` if one exists.
Returns
-------
str | None
`str` if the Expr has a name, otherwise `None`
"""
try:
return self.get_name()
except (ExpressionError, AttributeError):
return None
return self._arg.name

def _repr_png_(self) -> bytes | None:
if options.interactive or not options.graphviz_repr:
Expand Down
7 changes: 3 additions & 4 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import operator
import sys
import warnings
from functools import cached_property
from typing import (
IO,
TYPE_CHECKING,
Expand Down Expand Up @@ -307,9 +306,9 @@ def get_column(self, name: str) -> Column:
"""
return ops.TableColumn(self, name).to_expr()

@cached_property
@property
def columns(self):
return list(self.schema().names)
return list(self._arg.schema.names)

def schema(self) -> sch.Schema:
"""Get the schema for this table (if one is known)
Expand Down Expand Up @@ -1108,7 +1107,7 @@ def set_column(self, name: str, expr: ir.Value) -> Table:
"""
expr = self._ensure_expr(expr)

if expr._safe_name != name:
if expr.get_name() != name:
expr = expr.name(name)

if name not in self:
Expand Down
7 changes: 3 additions & 4 deletions ibis/expr/types/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import collections
import itertools
from functools import cached_property
from typing import TYPE_CHECKING, Iterable, Mapping, Sequence

from public import public
Expand Down Expand Up @@ -96,17 +95,17 @@ def __getattr__(self, name: str) -> ir.Value:
return self.__getitem__(name)
raise AttributeError(name)

@cached_property
@property
def names(self) -> Sequence[str]:
"""Return the field names of the struct."""
return self.type().names

@cached_property
@property
def types(self) -> Sequence[dt.DataType]:
"""Return the field types of the struct."""
return self.type().types

@cached_property
@property
def fields(self) -> Mapping[str, dt.DataType]:
"""Return a mapping from field name to field type of the struct."""
return util.frozendict(self.type().pairs)
Expand Down
14 changes: 12 additions & 2 deletions ibis/expr/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ class Window(Comparable):
Use 0 for `CURRENT ROW`.
"""

__slots__ = (
'_group_by',
'_order_by',
'_hash',
'preceding',
'following',
'max_lookback',
'how',
)

def __init__(
self,
group_by=None,
Expand Down Expand Up @@ -133,9 +143,9 @@ def __init__(
self.how = how

self._validate_frame()
self._hash = self._compute_hash()

@functools.cached_property
def _hash(self) -> int:
def _compute_hash(self) -> int:
return hash(
(
*self._group_by,
Expand Down

0 comments on commit 6c3df91

Please sign in to comment.