Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: Minor fixes to arithmetic operations with scalar and tensors #2276

Merged
merged 2 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def pull_dims(exprs, flag=True):
"""
dims = set()
for e in as_tuple(exprs):
dims.update({i for i in e.free_symbols if i.is_Dimension})
dims.update({i for i in e.free_symbols if isinstance(i, Dimension)})
if flag:
return set().union(*[d._defines for d in dims])
else:
Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def abridge_dim_names(iet):
# Find SubDimensions or SubDimension-derived dimensions used as indices in
# the expression in the innermost loop
indexeds = FindSymbols('indexeds').visit(tree.inner)
dims = set().union(*[pull_dims(i, flag=False) for i in indexeds])
dims = pull_dims(indexeds, flag=False)
dims = [d for d in dims if any([dim.is_Sub for dim in d._defines])]
dims = [d for d in dims if not d.is_SubIterator]
names = [d.root.name for d in dims]
Expand Down
10 changes: 10 additions & 0 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sympy

from sympy.core.assumptions import _assume_rules
from sympy.core.decorators import call_highest_priority
from cached_property import cached_property

from devito.data import default_allocator
Expand Down Expand Up @@ -704,6 +705,15 @@ def adjoint(self, inner=True):
# Real valued adjoint is transpose
return self.transpose(inner=inner)

@call_highest_priority('__radd__')
def __add__(self, other):
try:
# Most case support sympy add
return super().__add__(other)
except TypeError:
# Sympy doesn't support add with scalars
return self.applyfunc(lambda x: x + other)

def _eval_matrix_mul(self, other):
"""
Copy paste from sympy to avoid explicit call to sympy.Add
Expand Down
7 changes: 6 additions & 1 deletion devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,12 @@ def _flatten(self):
"""
if self.lhs.is_Matrix:
# Maps the Equations to retrieve the rhs from relevant lhs
eqs = dict(zip(as_tuple(self.lhs), as_tuple(self.rhs)))
try:
eqs = dict(zip(self.lhs, self.rhs))
except TypeError:
# Same rhs for all lhs
assert not self.rhs.is_Matrix
eqs = {i: self.rhs for i in self.lhs}
# Get the relevant equations from the lhs structure. .values removes
# the symmetric duplicates and off-diagonal zeros.
lhss = self.lhs.values()
Expand Down
Loading