Skip to content

Commit

Permalink
temporarily fix advanced-setitem grad error (PaddlePaddle#57737)
Browse files Browse the repository at this point in the history
  • Loading branch information
zoooo0820 authored Sep 26, 2023
1 parent 9c02e16 commit 04aad7c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
30 changes: 18 additions & 12 deletions python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,17 +956,22 @@ def _setitem_static(x, indices, values):
values = values.astype(transed_sub_tensor.dtype)

if paddle.in_dynamic_mode():
return transed_sub_tensor.index_put_(
# NOTE(zoooo0820): directly return result instead of another set_value, after backward bug fixed.
transed_sub_tensor = transed_sub_tensor.index_put_(
adjusted_advanced_index, values
)
else:
transed_sub_tensor = transed_sub_tensor.index_put(
adjusted_advanced_index, values
)

transback_sub_tensor = transed_sub_tensor.transpose(transback_dim)
inputs["ValueTensor"] = transback_sub_tensor
transback_sub_tensor = transed_sub_tensor.transpose(transback_dim)
inputs["ValueTensor"] = transback_sub_tensor

if paddle.in_dynamic_mode():
x._bump_inplace_version()
output = x
else:
helper = paddle.base.layer_helper.LayerHelper(
'set_value', **locals()
)
Expand All @@ -979,20 +984,21 @@ def _setitem_static(x, indices, values):
output = helper.create_variable_for_type_inference(
dtype=x.dtype
)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)

if not paddle.in_dynamic_mode():
# map var to the new output
paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
cur_block.program, x.desc.id(), output
)
return output
return output


def get_tensor_with_basic_indexing(
Expand Down
17 changes: 17 additions & 0 deletions test/indexing/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ def test_indexing_is_multi_dim_list(self):

np.testing.assert_allclose(x.numpy(), np_data)

def test_inplace_with_stride(self):
v = paddle.randn((3, 1))
v.stop_gradient = False
vv = v * 1

zero = paddle.randn((3, 3, 5))
zero.stop_gradient = False

zero1 = zero * 1
zero1[paddle.to_tensor([0, 1])] = vv

loss = zero1.sum()
loss.backward()

expected_v_grad = np.ones((3, 1)) * 10.0
np.testing.assert_equal(v.grad.numpy(), expected_v_grad)


class TestSetitemInStatic(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 04aad7c

Please sign in to comment.