diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 89a4999bd86..b0226025fb0 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -849,6 +849,14 @@ def _reformer(self, nopython):
         self.assertTrue(same(opt_model(input), correct))
         return cnt
 
+    @requires_cuda()
+    def test_sub_alpha_scalar_repro(self):
+        @torch.compile(backend="aot_eager")
+        def f(x):
+            return x.sub(1, alpha=2)
+
+        f(torch.ones(2, device="cuda", dtype=torch.float64))
+
     def test_reformer_eval(self):
         with torch.no_grad():
             cnt = self._reformer(nopython=True)
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index e1a89721f14..23cb3ddf989 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -1604,7 +1604,13 @@ def sub(
             )
         )
         raise ValueError(msg)
-    b = prims.mul(b, alpha)
+    if isinstance(b, torch.Tensor):
+        b = prims.mul(b, alpha)
+    else:
+        # Careful not to use prims.mul if b is a scalar / symint:
+        # prims.mul always returns a tensor, which would mess with
+        # type promotion.
+        b = b * alpha
     return prims.sub(a, b)
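
For context (not part of the patch): a minimal sketch of the type-promotion rule the fix relies on. Python scalars are "weak" operands in PyTorch type promotion, so the decomposition must not materialize b * alpha into a tensor when b is a scalar. The names x and b below are illustrative, not from the patch.

import torch

x = torch.ones(2, dtype=torch.float64)

# Eager semantics: a Python scalar operand does not change the result dtype.
assert (x - 1 * 2).dtype == torch.float64

# Plain Python arithmetic keeps b a scalar, preserving that weakness ...
b = 1 * 2
assert isinstance(b, int)

# ... whereas prims.mul would wrap it in a tensor. Prims perform no type
# promotion of their own, so the subsequent prims.sub over mismatched
# operands could fail under aot_eager, which is what the added
# test_sub_alpha_scalar_repro exercises.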