From 5657dd4b9ec62c1fe2258cb6c0229e4a8b7c0ec8 Mon Sep 17 00:00:00 2001 From: Kit Yan Choi Date: Tue, 12 May 2020 12:35:26 +0100 Subject: [PATCH] Restructuring Expression, not yet tested with recursion --- traits/observers/expressions.py | 304 +++++++-------------- traits/observers/tests/test_expressions.py | 46 ---- 2 files changed, 99 insertions(+), 251 deletions(-) diff --git a/traits/observers/expressions.py b/traits/observers/expressions.py index 15ea0ba5..73e72a56 100644 --- a/traits/observers/expressions.py +++ b/traits/observers/expressions.py @@ -9,7 +9,6 @@ # # Thanks for using Enthought open source! - import functools as _functools from traits.trait_base import ( @@ -38,27 +37,24 @@ from traits.observers._expressions_info import ( ) -class Expression: +class _BaseExpression: + """ Base class that defines and provides user-facing + methods that should be supported by all types of expressions. """ - Expression is an object for describing what traits are being observed - for change notifications. It can be passed directly to - ``HasTraits.observe`` method or the ``observe`` decorator. - Typically one creates an instance of ``Expression`` via a number of top - level functions, such as ``trait`` and ``join_``. - """ - def __init__(self): - # ``_levels`` is a list of list of IObserver. - # Each item corresponds to a layer of branches in the ObserverGraph. - # The last item is the most nested level. - # When graphs are constructured from this expression, one starts - # from the end of this list, to the top, and then continues to - # the prior_expressions - self._levels = [] - - # Represent prior expressions to be combined in series (JOIN) - # or in parallel (OR) - self._prior_expression = None + def _as_graphs(self, children=None): + """ Return all the ObserverGraph for the observer framework to attach + notifiers. 
+ + This is considered private to the users and to modules outside of the + ``observers`` subpackage, but public to modules within the + ``observers`` subpackage. + + Returns + ------- + graphs : list of ObserverGraph + """ + raise NotImplementedError("_as_graphs must be implemented.") def __eq__(self, other): """ Return true if the other value is an Expression with equivalent @@ -68,7 +64,7 @@ class Expression: ------- boolean """ - if type(other) is not type(self): + if not isinstance(other, _BaseExpression): return False self_graphs = self._as_graphs() other_graphs = other._as_graphs() @@ -86,17 +82,15 @@ class Expression: Parameters ---------- - expression : traits.observers.expressions.Expression + expression : traits.observers.expressions._BaseExpression Returns ------- - new_expression : traits.observers.expressions.Expression + new_expression : traits.observers.expressions._BaseExpression """ if self == expression: - return self.copy() - new = Expression() - new._prior_expression = _ParallelExpression([self, expression]) - return new + return self + return _ParallelExpression(self, expression) def then(self, expression): """ Create a new expression by extending this expression with @@ -107,34 +101,13 @@ class Expression: Parameters ---------- - expression : traits.observers.expressions.Expression - - Returns - ------- - new_expression : traits.observers.expressions.Expression - """ - - if self._prior_expression is None and not self._levels: - # this expression is empty... - new = expression.copy() - else: - new = Expression() - new._prior_expression = _SeriesExpression([self, expression]) - return new - - def _as_graphs(self): - """ Return all the ObserverGraph for the observer framework to attach - notifiers. - - This is considered private to the users and to modules outside of the - ``observers`` subpackage, but public to modules within the - ``observers`` subpackage. 
+ expression : traits.observers.expressions._BaseExpression Returns ------- - graphs : list of ObserverGraph + new_expression : traits.observers.expressions._BaseExpression """ - return _create_graphs(self) + return _SeriesExpression(self, expression) def trait(self, name, notify=True, optional=False): """ Create a new expression for observing a trait with the exact @@ -157,14 +130,11 @@ class Expression: Returns ------- - new_expression : traits.observers.expressions.Expression + new_expression : traits.observers.expressions._BaseExpression """ - return self._new_with_branches( - nodes=[ - _NamedTraitObserver( - name=name, notify=notify, optional=optional) - ], - ) + observer = _NamedTraitObserver( + name=name, notify=notify, optional=optional) + return _SeriesExpression(self, _SingleObserverExpression(observer)) def list_items(self, notify=True, optional=False): """ Create a new expression for observing items inside a list. @@ -186,11 +156,10 @@ class Expression: Returns ------- - new_expression : traits.observers.expressions.Expression + new_expression : traits.observers.expressions._BaseExpression """ - return self._new_with_branches( - nodes=[_ListItemObserver(notify=notify, optional=optional)], - ) + observer = _ListItemObserver(notify=notify, optional=optional) + return _SeriesExpression(self, _SingleObserverExpression(observer)) def dict_items(self, notify=True, optional=False): """ Create a new expression for observing items inside a dict. 
@@ -213,12 +182,10 @@ class Expression: Returns ------- - new_expression : traits.observers.expressions.Expression + new_expression : traits.observers.expressions._BaseExpression """ - # Should be similar to list_items but for dict - return self._new_with_branches( - nodes=[_DictItemObserver(notify=notify, optional=optional)], - ) + observer = _DictItemObserver(notify=notify, optional=optional) + return _SeriesExpression(self, _SingleObserverExpression(observer)) def set_items(self, notify=True, optional=False): """ Create a new expression for observing items inside a set. @@ -234,12 +201,10 @@ class Expression: Returns ------- - new_expression : traits.observers.expressions.Expression + new_expression : traits.observers.expressions._BaseExpression """ - # Should be similar to list_items but for set - return self._new_with_branches( - nodes=[_SetItemObserver(notify=notify, optional=optional)], - ) + observer = _SetItemObserver(notify=notify, optional=optional) + return _SeriesExpression(self, _SingleObserverExpression(observer)) def filter_(self, filter, notify=True): """ Create a new expression for observing traits using the @@ -259,15 +224,10 @@ class Expression: Returns ------- - new_expression : traits.observers.expressions.Expression + new_expression : traits.observers.expressions._BaseExpression """ - return self._new_with_branches( - nodes=[ - _FilteredTraitObserver( - notify=notify, filter=filter, - ) - ], - ) + observer = _FilteredTraitObserver(notify=notify, filter=filter) + return _SeriesExpression(self, _SingleObserverExpression(observer)) def metadata(self, metadata_name, value=_not_none, notify=True): """ Return a new expression that matches traits based on @@ -293,7 +253,7 @@ class Expression: Returns ------- - new_expression : traits.observers.expressions.Expression + new_expression : traits.observers.expressions._BaseExpression """ # sanity check if not callable(value): @@ -307,34 +267,6 @@ class Expression: notify=notify, ) - def 
_new_with_branches(self, nodes): - """ Create a new Expression with a new leaf nodes. - - Parameters - ---------- - nodes : list of IObserver - - Returns - ------- - new_expression : traits.observers.expressions.Expression - """ - expression = self.copy() - expression._levels.append(nodes) - return expression - - def copy(self): - """ Return a copy of this expression. - - Returns - ------- - new_expression : traits.observers.expressions.Expression - """ - expression = Expression() - expression._levels = self._levels.copy() - if self._prior_expression is not None: - expression._prior_expression = self._prior_expression.copy() - return expression - def info(self): """ Return a list of user-friendly texts containing descriptive information about this expression. @@ -357,128 +289,90 @@ class Expression: print(*self.info(), sep="\n") -def _create_graphs(expression, graphs=None): - """ Create ObserverGraphs from a given expression. - - Parameters - ---------- - expression : traits.observers.expressions.Expression - graphs : collection of ObserverGraph - Leaf graphs to be added. - Needed when this function is called recursively. +class Expression(_BaseExpression): + """ + Expression is an object for describing what traits are being observed + for change notifications. It can be passed directly to + ``HasTraits.observe`` method or the ``observe`` decorator. - Returns - ------- - graphs : list of ObserverGraph - New graphs + Typically one creates an instance of ``Expression`` via a number of top + level functions, such as ``trait`` and ``join_``. """ - if graphs is None: - graphs = [] - for nodes in expression._levels[::-1]: - graphs = [ - _ObserverGraph(node=node, children=graphs) for node in nodes - ] + # This class serves as an entry point for top-level function. 
- if expression._prior_expression is not None: - graphs = expression._prior_expression._create_graphs( - graphs=graphs, - ) - return graphs + def _as_graphs(self, children=None): + """ Return all the ObserverGraph for the observer framework to attach + notifiers. + + This is considered private to the users and to modules outside of the + ``observers`` subpackage, but public to modules within the + ``observers`` subpackage. + + Returns + ------- + graphs : list of ObserverGraph + """ + if children is None: + return [] + return children -class _SeriesExpression: +class _SeriesExpression(_BaseExpression): """ Container of Expression for joining expressions in series. Used internally in this module. Parameters ---------- - expressions : list of Expression - List of Expression to be combined in series. + first : _BaseExpression + Left expression to be joined in series. + second : _BaseExpression + Right expression to be joined in series. """ - def __init__(self, expressions): - self.expressions = expressions.copy() - - def copy(self): - """ Return a copy of this instance. - The internal ``expressions`` list is copied so it can be mutated. - - Returns - ------- - series_expression : _SeriesExpression - """ - return _SeriesExpression(self.expressions) + def __init__(self, first, second): + self._first = first + self._second = second - def _create_graphs(self, graphs): - """ - Create new ObserverGraph(s) from the joined expressions. + def _as_graphs(self, children=None): + children = self._second._as_graphs(children) + return self._first._as_graphs(children) - Parameters - ---------- - graphs : collection of ObserverGraph - Leaf graphs to be added. - Needed when this function is called recursively.
- Returns - ------- - graphs : list of ObserverGraph - New graphs - """ - for expr in self.expressions[::-1]: - graphs = _create_graphs( - expr, - graphs=graphs, - ) - return graphs - - -class _ParallelExpression: +class _ParallelExpression(_BaseExpression): """ Container of Expression for joining expressions in parallel. Used internally in this module. Parameters ---------- - expressions : list of Expression - List of Expression to be combined in parallel. + left : _BaseExpression + Left expression to be joined in parallel. + right : _BaseExpression + Right expression to be joined in parallel. """ - def __init__(self, expressions): - self.expressions = expressions.copy() + def __init__(self, left, right): + self._left = left + self._right = right - def copy(self): - """ Return a copy of this instance. - The internal ``expressions`` list is copied so it can be mutated. + def _as_graphs(self, children=None): + left_graphs = self._left._as_graphs(children) + right_graphs = self._right._as_graphs(children) + return left_graphs + right_graphs - Returns - ------- - parallel_expression : _ParallelExpression - """ - return _ParallelExpression(self.expressions) - def _create_graphs(self, graphs): - """ - Create new ObserverGraph(s) from the joined expressions. +class _SingleObserverExpression(_BaseExpression): + """ Container of Expression for wrapping a single observer. + Used internally in this module. + """ - Parameters - ---------- - graphs : collection of ObserverGraph - Leaf graphs to be added. - Needed when this function is called recursively. 
+ def __init__(self, observer): + self.observer = observer - Returns - ------- - graphs : list of ObserverGraph - New graphs - """ - new_graphs = [] - for expr in self.expressions: - or_graphs = _create_graphs( - expr, - graphs=graphs, - ) - new_graphs.extend(or_graphs) - return new_graphs + def _as_graphs(self, children=None): + return [ + _ObserverGraph(node=self.observer, children=children) + ] class _MetadataFilter: @@ -560,11 +454,11 @@ def join_(*expressions): Parameters ---------- - *expressions : iterable of traits.observers.expressions.Expression + *expressions : iterable of traits.observers.expressions._BaseExpression Returns ------- - new_expression : traits.observers.expressions.Expression + new_expression : traits.observers.expressions._BaseExpression Joined expression. """ return _functools.reduce(lambda e1, e2: e1.then(e2), expressions) diff --git a/traits/observers/tests/test_expressions.py b/traits/observers/tests/test_expressions.py index a0f78a46..74236184 100644 --- a/traits/observers/tests/test_expressions.py +++ b/traits/observers/tests/test_expressions.py @@ -310,52 +310,6 @@ class TestExpressionEquality(unittest.TestCase): self.assertNotEqual(expr, "1") -class TestExpressionCopy(unittest.TestCase): - """ Test the Expression.copy method.""" - - def test_expression_copy_current_levels(self): - - expr = expressions.trait("name") - copied = expr.copy() - self.assertEqual(expr._levels, copied._levels) - self.assertIsNot(copied._levels, expr._levels) - self.assertEqual(copied._as_graphs(), expr._as_graphs()) - - def test_expression_copy_prior_expression_parallel(self): - expr = expressions.trait("name") | expressions.trait("age") - self.assertIsNotNone(expr._prior_expression) - - copied = expr.copy() - self.assertEqual(copied._as_graphs(), expr._as_graphs()) - self.assertIsNotNone(copied._prior_expression) - self.assertIsNot(copied._prior_expression, expr._prior_expression) - self.assertEqual( - copied._prior_expression.expressions, - 
expr._prior_expression.expressions, - ) - self.assertIsNot( - copied._prior_expression.expressions, - expr._prior_expression.expressions, - ) - - def test_expression_copy_prior_expression_serial(self): - expr = expressions.trait("name").then(expressions.trait("age")) - self.assertIsNotNone(expr._prior_expression) - - copied = expr.copy() - self.assertEqual(copied._as_graphs(), expr._as_graphs()) - self.assertIsNotNone(copied._prior_expression) - self.assertIsNot(copied._prior_expression, expr._prior_expression) - self.assertEqual( - copied._prior_expression.expressions, - expr._prior_expression.expressions, - ) - self.assertIsNot( - copied._prior_expression.expressions, - expr._prior_expression.expressions, - ) - - class TestExpressionInfoPrint(unittest.TestCase): """ Integration test the Expression.info and Expression.print methods.""" -- 2.24.2 (Apple Git-127)