Skip to content

Commit

Permalink
fix primtorch handling for sub.scalar with alpha and float64 arg (#95…
Browse files Browse the repository at this point in the history
…421)

This fixes the primtorch issue stemming from pytorch/pytorch#95181

Pull Request resolved: pytorch/pytorch#95421
Approved by: https://github.com/ngimel, https://github.com/SherlockNoMad
  • Loading branch information
bdhirsh authored and cyyever committed Mar 5, 2023
1 parent 796678d commit 8082c2b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
8 changes: 8 additions & 0 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,14 @@ def _reformer(self, nopython):
self.assertTrue(same(opt_model(input), correct))
return cnt

@requires_cuda()
def test_sub_alpha_scalar_repro(self):
@torch.compile(backend="aot_eager")
def f(x):
return x.sub(1, alpha=2)

f(torch.ones(2, device="cuda", dtype=torch.float64))

def test_reformer_eval(self):
with torch.no_grad():
cnt = self._reformer(nopython=True)
Expand Down
8 changes: 7 additions & 1 deletion torch/_refs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,7 +1604,13 @@ def sub(
)
)
raise ValueError(msg)
b = prims.mul(b, alpha)
if isinstance(b, torch.Tensor):
b = prims.mul(b, alpha)
else:
# Carefully not to use prims.mul if b is a scalar / symint.
# prims.mul always returns a tensor,
# which will mess with type promotion.
b = b * alpha

return prims.sub(a, b)

Expand Down

0 comments on commit 8082c2b

Please sign in to comment.