From 04aad7cea65d74d21cef3948137376d3d37d2e36 Mon Sep 17 00:00:00 2001 From: JYChen Date: Tue, 26 Sep 2023 16:24:31 +0800 Subject: [PATCH] temporarily fix advanced-setitem grad error (#57737) --- python/paddle/base/variable_index.py | 30 +++++++++++++++++----------- test/indexing/test_setitem.py | 17 ++++++++++++++++ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py index 6d034b80c8d9c4..0ff628ed48f4f9 100644 --- a/python/paddle/base/variable_index.py +++ b/python/paddle/base/variable_index.py @@ -956,7 +956,8 @@ def _setitem_static(x, indices, values): values = values.astype(transed_sub_tensor.dtype) if paddle.in_dynamic_mode(): - return transed_sub_tensor.index_put_( + # NOTE(zoooo0820): directly return result instead of another set_value, after backward bug fixed. + transed_sub_tensor = transed_sub_tensor.index_put_( adjusted_advanced_index, values ) else: @@ -964,9 +965,13 @@ def _setitem_static(x, indices, values): adjusted_advanced_index, values ) - transback_sub_tensor = transed_sub_tensor.transpose(transback_dim) - inputs["ValueTensor"] = transback_sub_tensor + transback_sub_tensor = transed_sub_tensor.transpose(transback_dim) + inputs["ValueTensor"] = transback_sub_tensor + if paddle.in_dynamic_mode(): + x._bump_inplace_version() + output = x + else: helper = paddle.base.layer_helper.LayerHelper( 'set_value', **locals() ) @@ -979,20 +984,21 @@ def _setitem_static(x, indices, values): output = helper.create_variable_for_type_inference( dtype=x.dtype ) - cur_block = default_main_program().current_block() - cur_block.append_op( - type="set_value", - inputs=inputs, - outputs={'Out': output}, - attrs=attrs, - inplace_map={"Input": "Out"}, - ) + cur_block = default_main_program().current_block() + cur_block.append_op( + type="set_value", + inputs=inputs, + outputs={'Out': output}, + attrs=attrs, + inplace_map={"Input": "Out"}, + ) + if not paddle.in_dynamic_mode(): # map var to the new output paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add( cur_block.program, x.desc.id(), output ) - return output + return output def get_tensor_with_basic_indexing( diff --git a/test/indexing/test_setitem.py b/test/indexing/test_setitem.py index b9c51bfe084036..3c3f8deb3955e5 100644 --- a/test/indexing/test_setitem.py +++ b/test/indexing/test_setitem.py @@ -123,6 +123,23 @@ def test_indexing_is_multi_dim_list(self): np.testing.assert_allclose(x.numpy(), np_data) + def test_inplace_with_stride(self): + v = paddle.randn((3, 1)) + v.stop_gradient = False + vv = v * 1 + + zero = paddle.randn((3, 3, 5)) + zero.stop_gradient = False + + zero1 = zero * 1 + zero1[paddle.to_tensor([0, 1])] = vv + + loss = zero1.sum() + loss.backward() + + expected_v_grad = np.ones((3, 1)) * 10.0 + np.testing.assert_equal(v.grad.numpy(), expected_v_grad) + class TestSetitemInStatic(unittest.TestCase): def setUp(self):