Fix too strict type check in _sum_grad_over_bcasted_dims
ricardoV94 committed Oct 26, 2024
1 parent 4c7b494 commit 6e57a08
Showing 2 changed files with 30 additions and 2 deletions.
10 changes: 8 additions & 2 deletions pytensor/tensor/subtensor.py
@@ -2027,7 +2027,6 @@ def _sum_grad_over_bcasted_dims(x, gx):
     if gx.broadcastable != x.broadcastable:
         x_dim_added = gx.ndim - x.ndim
         x_broad = (True,) * x_dim_added + x.broadcastable
-        assert sum(gx.broadcastable) <= sum(x_broad)
         axis_to_sum = []
         for i in range(gx.ndim):
             if gx.broadcastable[i] is False and x_broad[i] is True:
@@ -2045,7 +2044,14 @@ def _sum_grad_over_bcasted_dims(x, gx):
             for i in range(x_dim_added):
                 assert gx.broadcastable[i]
             gx = gx.dimshuffle(*range(x_dim_added, gx.ndim))
-        assert gx.broadcastable == x.broadcastable
+        # Broadcastable flags of gx can be the same or more specific than x.
+        # Only unallowed case is x_dim_b == True and gx_dim_b == False.
+        assert not any(
+            x_dim_b and not gx_dim_b
+            for x_dim_b, gx_dim_b in zip(
+                x.type.broadcastable, gx.type.broadcastable, strict=True
+            )
+        ), (x.type, gx.type)
     return gx
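
In words: the old assert gx.broadcastable == x.broadcastable demanded identical static flags, which failed when the gradient graph had inferred a more specific shape than x (e.g. after specify_shape). The relaxed check rejects only the one invalid direction. Below is a minimal standalone sketch of that predicate; the helper name grad_broadcastable_is_compatible is hypothetical, and it assumes x and gx already have the same ndim (as guaranteed by the dimshuffle above).

def grad_broadcastable_is_compatible(x_broadcastable, gx_broadcastable):
    """True unless x is statically length-1 on a dim where gx is not."""
    return not any(
        x_dim_b and not gx_dim_b
        for x_dim_b, gx_dim_b in zip(x_broadcastable, gx_broadcastable, strict=True)
    )

# Identical flags: allowed (the only case the old `==` check accepted).
assert grad_broadcastable_is_compatible((False, True), (False, True))
# gx statically known to be length-1 where x is unknown: now allowed.
assert grad_broadcastable_is_compatible((False, False), (True, False))
# x statically length-1 but gx unknown: still rejected.
assert not grad_broadcastable_is_compatible((True, False), (False, False))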


22 changes: 22 additions & 0 deletions tests/tensor/test_subtensor.py
@@ -12,7 +12,9 @@
 from pytensor import function
 from pytensor.compile import DeepCopyOp, shared
 from pytensor.compile.io import In
+from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
+from pytensor.gradient import grad
 from pytensor.graph.op import get_test_value
 from pytensor.graph.rewriting.utils import is_same_graph
 from pytensor.printing import pprint
@@ -22,6 +24,7 @@
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import exp, isinf
 from pytensor.tensor.math import sum as pt_sum
+from pytensor.tensor.shape import specify_shape
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -1660,6 +1663,25 @@ def just_numeric_args(a, b):
             ),
         )
 
+    def test_grad_broadcastable_specialization(self):
+        # Make sure gradient does not fail when gx has a more precise static_shape after indexing.
+        # This is a regression test for a bug reported in
+        # https://discourse.pymc.io/t/marginalized-mixture-wont-begin-sampling-throws-assertion-error/15969
+
+        x = vector("x")  # Unknown write time shape = (2,)
+        out = x.zeros_like()
+
+        # Update a subtensor of unknown write time shape = (1,)
+        out = out[1:].set(exp(x[1:]))
+        out = specify_shape(out, 2)
+        gx = grad(out.sum(), x)
+
+        mode = Mode(linker="py", optimizer=None)
+        np.testing.assert_allclose(
+            gx.eval({x: [1, 1]}, mode=mode),
+            [0, np.e],
+        )
+
 
 class TestIncSubtensor1:
     def setup_method(self):
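
For context, a small illustrative sketch (not from the commit) of how the test's graph produces the mismatch, assuming PyTensor's usual static shape inference through constant slices: once specify_shape pins the output to length 2, the slice out[1:] is statically known to have length 1, i.e. broadcastable (True,), while the same slice of the unknown-length x keeps broadcastable (False,).

import pytensor.tensor as pt

x = pt.vector("x")                         # static shape (None,) -> broadcastable (False,)
out = pt.specify_shape(x.zeros_like(), 2)  # static shape pinned to (2,)
print(out[1:].type.broadcastable)          # (True,): slice statically known to be length 1
print(x[1:].type.broadcastable)            # (False,): length of x is unknown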
