From c31a8fe7858a217a5fe22a9484f61a9b2b538224 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Mon, 27 Feb 2023 19:32:50 +0000 Subject: [PATCH] fix embedding_backward_dense decomp with broadcasting (#95499) Fixes https://github.com/pytorch/pytorch/issues/95182 Pull Request resolved: https://github.com/pytorch/pytorch/pull/95499 Approved by: https://github.com/ezyang, https://github.com/ngimel --- test/dynamo/test_repros.py | 19 +++++++++++++++++++ torch/_decomp/decompositions.py | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index b0226025fb0..4e34b2cd142 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -857,6 +857,25 @@ def f(x): f(torch.ones(2, device="cuda", dtype=torch.float64)) + def test_embedding_backward_broadcasting_decomp(self): + def f(grad_output, indices): + num_weights = 10 + padding_idx = 1 + scale_grad_by_freq = True + return torch.ops.aten.embedding_dense_backward( + grad_output, indices, num_weights, padding_idx, scale_grad_by_freq + ) + + f_compiled = torch.compile(f, backend="aot_eager") + + grad_output = torch.ones(2, 4, 3, dtype=torch.float16) + indices = torch.ones(2, 4, dtype=torch.int64) + + out_ref = f(grad_output, indices) + out_test = f_compiled(grad_output, indices) + + self.assertEqual(out_ref, out_test) + def test_reformer_eval(self): with torch.no_grad(): cnt = self._reformer(nopython=True) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index d2964f2bbd2..54266e1bd37 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1071,7 +1071,7 @@ def embedding_dense_backward( ones = torch.ones_like(indices) counts = counts.index_put([indices], ones, accumulate=True) grad_weights_scale = counts[indices] - grad_output = grad_output / grad_weights_scale.unsqueeze(1) + grad_output = grad_output / grad_weights_scale.unsqueeze(-1) mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim) grad = grad_output.masked_fill(mask, 0)