diff --git a/devito/ir/support/utils.py b/devito/ir/support/utils.py index a2dc2b4c73..acb851dbe7 100644 --- a/devito/ir/support/utils.py +++ b/devito/ir/support/utils.py @@ -267,7 +267,7 @@ def pull_dims(exprs, flag=True): """ dims = set() for e in as_tuple(exprs): - dims.update({i for i in e.free_symbols if i.is_Dimension}) + dims.update({i for i in e.free_symbols if isinstance(i, Dimension)}) if flag: return set().union(*[d._defines for d in dims]) else: diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index bd659cf6ab..640d2e5638 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -219,7 +219,7 @@ def abridge_dim_names(iet): # Find SubDimensions or SubDimension-derived dimensions used as indices in # the expression in the innermost loop indexeds = FindSymbols('indexeds').visit(tree.inner) - dims = set().union(*[pull_dims(i, flag=False) for i in indexeds]) + dims = pull_dims(indexeds, flag=False) dims = [d for d in dims if any([dim.is_Sub for dim in d._defines])] dims = [d for d in dims if not d.is_SubIterator] names = [d.root.name for d in dims] diff --git a/devito/types/basic.py b/devito/types/basic.py index a06e9c9a09..6ec4cfab28 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -8,6 +8,7 @@ import sympy from sympy.core.assumptions import _assume_rules +from sympy.core.decorators import call_highest_priority from cached_property import cached_property from devito.data import default_allocator @@ -704,6 +705,15 @@ def adjoint(self, inner=True): # Real valued adjoint is transpose return self.transpose(inner=inner) + @call_highest_priority('__radd__') + def __add__(self, other): + try: + # Most case support sympy add + return super().__add__(other) + except TypeError: + # Sympy doesn't support add with scalars + return self.applyfunc(lambda x: x + other) + def _eval_matrix_mul(self, other): """ Copy paste from sympy to avoid explicit call to sympy.Add diff --git a/devito/types/equation.py b/devito/types/equation.py index 9c716a7e73..d3a16311ff 100644 --- a/devito/types/equation.py +++ b/devito/types/equation.py @@ -106,7 +106,12 @@ def _flatten(self): """ if self.lhs.is_Matrix: # Maps the Equations to retrieve the rhs from relevant lhs - eqs = dict(zip(as_tuple(self.lhs), as_tuple(self.rhs))) + try: + eqs = dict(zip(self.lhs, self.rhs)) + except TypeError: + # Same rhs for all lhs + assert not self.rhs.is_Matrix + eqs = {i: self.rhs for i in self.lhs} # Get the relevant equations from the lhs structure. .values removes # the symmetric duplicates and off-diagonal zeros. lhss = self.lhs.values()