Skip to content

Commit

Permalink
CI: add test for and fixes #920
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Aug 2, 2023
1 parent 42b9f57 commit 48024ff
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 9 deletions.
8 changes: 1 addition & 7 deletions devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import singledispatch

import numpy as np
from sympy import Function, Indexed, Integer, Mul, Number, Pow, S, Symbol, Tuple, Add
from sympy import Function, Indexed, Integer, Mul, Number, Pow, S, Symbol, Tuple

from devito.finite_differences import Derivative
from devito.finite_differences.differentiable import IndexDerivative
Expand Down Expand Up @@ -269,12 +269,6 @@ def sympy_dtype(expr, default):
returns the default if non is found
"""
args = expr.args
# We can only infer the dtype for addition/multiplication or Symbols
# For other case the epxression function may modify the infered dtype
if not (isinstance(expr.func, Add) or isinstance(expr.func, Mul)) or \
not expr.is_Symbol:
return default

# Symbol/... without argument, check its dtype
if len(args) == 0:
try:
Expand Down
5 changes: 4 additions & 1 deletion devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,10 @@ def _dist_data_gather(self, data):
return

# Compute dist map only once
data = self._C_as_ndarray(data)
try:
data = self._C_as_ndarray(data)
except AttributeError:
pass
dmap = self._dist_datamap
mask = self._dist_scatter_mask(dmap=dmap)

Expand Down
15 changes: 15 additions & 0 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2627,6 +2627,21 @@ def test_issue_2163(self):
subdomain=grid.interior))
assert_structure(op, ['t,i0x,i0y'], 'ti0xi0y')

def test_dtype_aliases(self):
a = np.arange(64).reshape((8, 8))
grid = Grid(shape=a.shape, extent=(8, 8))

so = 2
f = Function(name='f', grid=grid, space_order=so, dtype=np.int32)
f.data[:] = a

fo = Function(name='fo', grid=grid, space_order=so, dtype=np.int32)
op = Operator(Eq(fo, f.dx))
op.apply()

assert FindNodes(Expression).visit(op)[0].dtype == np.float32
assert np.all(fo.data[:-1, :-1] == 6)


class TestIsoAcoustic(object):

Expand Down
2 changes: 1 addition & 1 deletion tests/test_gpu_openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_tile_insteadof_collapse(self, par_tile):
opt=('advanced', {'par-tile': par_tile}))

trees = retrieve_iteration_tree(op)
assert len(trees) == 4
assert len(trees) == 6

assert trees[0][1].pragmas[0].value ==\
'acc parallel loop tile(32,4,4) present(u)'
Expand Down

0 comments on commit 48024ff

Please sign in to comment.