Skip to content

Commit

Permalink
Fix TrueDiv gradient for integer inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 8, 2024
1 parent aa616e6 commit cdae903
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2036,7 +2036,10 @@ def grad(self, inputs, gout):
# to the output; x/y is still a function of x
# and y; it's just a step function.
if all(a.dtype in discrete_dtypes for a in (x, y)):
return [x.zeros_like(), y.zeros_like()]
return [
x.zeros_like(dtype=config.floatX),
y.zeros_like(dtype=config.floatX),
]

first_part = gz / y

Expand Down

0 comments on commit cdae903

Please sign in to comment.