Skip to content

Commit

Permalink
Merge pull request #2349 from devitocodes/fix-elementary-dtype
Browse files Browse the repository at this point in the history
compiler: Make code gen of elementary funcs dtype-aware
  • Loading branch information
FabioLuporini authored Apr 12, 2024
2 parents 3a8c46e + 3c8edea commit 42cce7e
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 42 deletions.
2 changes: 1 addition & 1 deletion devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ def lower_schedule(schedule, meta, sregistry, ftemps):
# This prevents cases such as `floor(a*b)` with `a` and `b` floats
# that would creat a temporary `int r = b` leading to erronous
# numerical results
dtype = sympy_dtype(pivot, meta.dtype)
dtype = sympy_dtype(pivot, base=meta.dtype)

if writeto:
# The Dimensions defining the shape of Array
Expand Down
22 changes: 7 additions & 15 deletions devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
from sympy import (Function, Indexed, Integer, Mul, Number,
Pow, S, Symbol, Tuple)
from sympy.core.operations import AssocOp

from devito.finite_differences import Derivative
from devito.finite_differences.differentiable import IndexDerivative
Expand Down Expand Up @@ -291,21 +290,14 @@ def has_integer_args(*args):
return res


def sympy_dtype(expr, default):
def sympy_dtype(expr, base=None):
"""
Infer the dtype of the expression
or default if could not be determined.
Infer the dtype of the expression.
"""
# Symbol/... without argument, check its dtype
if len(expr.args) == 0:
dtypes = {base} - {None}
for i in expr.free_symbols:
try:
return expr.dtype
dtypes.add(i.dtype)
except AttributeError:
return default
else:
if not (isinstance(expr.func, AssocOp) or expr.is_Pow):
return default
else:
# Infer expression dtype from its arguments
dtype = infer_dtype([sympy_dtype(a, default) for a in expr.args])
return dtype or default
pass
return infer_dtype(dtypes)
26 changes: 25 additions & 1 deletion devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sympy.printing.c import C99CodePrinter

from devito.arch.compiler import AOMPCompiler
from devito.symbolics.inspection import has_integer_args
from devito.symbolics.inspection import has_integer_args, sympy_dtype
from devito.types.basic import AbstractFunction

__all__ = ['ccode']
Expand Down Expand Up @@ -89,6 +89,25 @@ def _print_Rational(self, expr):
else:
return '%d.0F/%d.0F' % (p, q)

def _print_math_func(self, expr, nest=False, known=None):
cls = type(expr)
name = cls.__name__
if name not in self._prec_funcs:
return super()._print_math_func(expr, nest=nest, known=known)

try:
cname = self.known_functions[name]
except KeyError:
return super()._print_math_func(expr, nest=nest, known=known)

dtype = sympy_dtype(expr)
if dtype is np.float32:
cname += 'f'

args = ', '.join((self._print(arg) for arg in expr.args))

return '%s(%s)' % (cname, args)

def _print_Pow(self, expr):
# Need to override because of issue #1627
# E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x
Expand Down Expand Up @@ -255,6 +274,11 @@ def _print_Fallback(self, expr):
_print_Basic = _print_Fallback


# Lifted from SymPy so that we go through our own `_print_math_func`
for k in ('exp log sin cos tan ceiling floor').split():
setattr(CodePrinter, '_print_%s' % k, CodePrinter._print_math_func)


# Always parenthesize IntDiv and InlineIf within expressions
PRECEDENCE_VALUES['IntDiv'] = 1
PRECEDENCE_VALUES['InlineIf'] = 1
Expand Down
Loading

0 comments on commit 42cce7e

Please sign in to comment.