Skip to content

Commit

Permalink
Fix type check in local_pow_specialize
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 4, 2023
1 parent 7367e8d commit 5c87d74
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
5 changes: 4 additions & 1 deletion pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,7 +2071,10 @@ def local_pow_specialize(fgraph, node):
rval = [reciprocal(sqr(xsym))]
if rval:
rval[0] = cast(rval[0], odtype)
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
assert rval[0].type.is_super(node.outputs[0].type), (
rval[0].type,
node.outputs[0].type,
)
return rval
else:
return False
Expand Down
14 changes: 13 additions & 1 deletion tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
perform_sigm_times_exp,
simplify_mul,
)
from pytensor.tensor.shape import Reshape, Shape_i
from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape
from pytensor.tensor.type import (
TensorType,
cmatrix,
Expand Down Expand Up @@ -1671,6 +1671,18 @@ def test_local_pow_specialize():
assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal)
utt.assert_allclose(f(val_no0), val_no0 ** (-0.5))

twos = np.full(shape=(10,), fill_value=2.0).astype(config.floatX)
f = function([v], v**twos, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
# Depending on the mode the SpecifyShape is lifted or not
if topo[0].op == sqr:
assert isinstance(topo[1].op, SpecifyShape)
else:
assert isinstance(topo[0].op, SpecifyShape)
assert topo[1].op == sqr
utt.assert_allclose(f(val), val**twos)


def test_local_pow_to_nested_squaring():
mode = config.mode
Expand Down

0 comments on commit 5c87d74

Please sign in to comment.