diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index 182daec037e70..bf9c14845be9f 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -55,6 +55,23 @@ def check_all_puts(block, inputs, outputs): ) +def append_full_like(float_value, value, state, backward_ops): + value_grad = paddle.full_like( + value, + float_value, + dtype=value.dtype, + ) + full_like_op = value_grad.get_defining_op() + full_op = full_like_op.operand_source(1).get_defining_op() + update_bwdop_structure( + backward_ops, + state.op_to_opgrad[value.get_defining_op()], + [full_like_op, full_op], + ) + state.value_to_valuegrad[value] = [[value_grad]] + return value_grad + + def get_real_op_inputs(op): if op.name() in ["pd_op.if", "pd_op.while"]: return get_used_external_value(op) @@ -101,19 +118,7 @@ def prepare_grad_outputs(grad_outputs, outputs, state): # fwd : op1 -> op2 -> op3 -> output # bwd : op1G <- op2G <- op3G <- outputG <- full_likeop/feedop if grad is None: - output_grad = paddle.full_like( - output, - 1.0, - dtype=output.dtype, - ) - full_likeop = output_grad.get_defining_op() - fullop = full_likeop.operand_source(1).get_defining_op() - update_bwdop_structure( - backward_ops, - state.op_to_opgrad[output.get_defining_op()], - [full_likeop, fullop], - ) - state.value_to_valuegrad[output] = [[output_grad]] + append_full_like(1.0, output, state, backward_ops) else: if output.shape != grad.shape: raise ValueError( @@ -150,21 +155,9 @@ def prepare_grad_outputs(grad_outputs, outputs, state): [paddle.pir.fake_op_result()] ] else: - grad_value = paddle.full_like( - opresult, - 0.0, - opresult.dtype, + grad_value = append_full_like( + 0.0, opresult, state, backward_ops ) - full_likeop = grad_value.get_defining_op() - fullop = full_likeop.operand_source(1).get_defining_op() - - update_bwdop_structure( - backward_ops, - state.op_to_opgrad[opresult.get_defining_op()], - [full_likeop, fullop], - ) - state.value_to_valuegrad[opresult] = [[grad_value]] - visited_output.add(opresult) complete_outputs.append(opresult) @@ -336,12 +329,15 @@ def inverse_sort_op(ops): def append_backward_ops( base_op, + base_inputs, + base_input_grads, fwd_block, bwd_block, effective_forward_ops, no_grad_set, backward_ops, state, + inside_value_to_outside_value_map, ): ''' add grad_op in order of topological inverse sort @@ -383,20 +379,22 @@ def append_backward_ops( else continue to next op. 
''' - def append_add_n(block, value): + def append_add_n(value): # one value is input of more than one fwd_op, # so more than one bwd_op create input_grad, # need add sum op to accumulate gradient - paddle.add_n([item[0] for item in state.value_to_valuegrad[value]]) - combineop = block.ops[len(block.ops) - 2] - sumop = block.ops[len(block.ops) - 1] + add_n_value = paddle.add_n( + [item[0] for item in state.value_to_valuegrad[value]] + ) + add_n_op = add_n_value.get_defining_op() + combine_op = add_n_op.operand_source(0).get_defining_op() update_bwdop_structure( - backward_ops, state.op_to_opgrad[op], [combineop, sumop] + backward_ops, state.op_to_opgrad[op], [combine_op, add_n_op] ) - state.value_to_valuegrad[value] = [[sumop.result(0)]] for tmp in state.value_to_valuegrad[value]: state.value_to_sumvaluegrad[value].append(tmp) + state.value_to_valuegrad[value] = [[add_n_value]] def make_output_with_output_grad(op): zero_flag = [False] * op.num_results() @@ -408,14 +406,14 @@ def make_output_with_output_grad(op): if value in control_flow_value_to_copyvalue_map else [value] ) - if value in inside_value_to_outside_value_map: + while value in inside_value_to_outside_value_map: value = inside_value_to_outside_value_map[value] if ( value in state.value_to_valuegrad and len(state.value_to_valuegrad[value]) > 1 ): - append_add_n(bwd_block, value) + append_add_n(value) if ( value not in state.value_to_valuegrad @@ -445,23 +443,10 @@ def make_output_with_output_grad(op): # second case: # last bwd_op return None because input in no_grad_set, # but this bwd_op need a input. - grad_value = paddle.full_like( - value, - 0.0, - dtype=value.dtype, - ) - full_likeop = grad_value.get_defining_op() - fullop = full_likeop.operand_source(1).get_defining_op() - update_bwdop_structure( - backward_ops, - state.op_to_opgrad[op], - [full_likeop, fullop], - ) + append_full_like(0.0, value, state, backward_ops) zero_flag[i] = True - state.value_to_valuegrad[value] = [[grad_value]] - outputs.append(new_value) output_grads.append(state.value_to_valuegrad[value][0]) @@ -564,15 +549,25 @@ def update_input_grad_map(op, input_grads, origin_inputs): state.value_to_valuegrad[input].append([input_grad]) i += 1 - def append_yield(block, inputs): + def append_yield(block, base_inputs, base_inputs_grad): with block: inputs_grad = [] - for value in inputs: + for value, value_grad in zip(base_inputs, base_inputs_grad): + if value_grad is None: + continue + + while value in inside_value_to_outside_value_map: + value = inside_value_to_outside_value_map[value] + if value in state.value_to_valuegrad: if len(state.value_to_valuegrad[value]) > 1: - append_add_n(block, value) + append_add_n(value) inputs_grad.append(state.value_to_valuegrad[value][0][0]) - + else: + value_grad = append_full_like( + 0.0, value, state, backward_ops + ) + inputs_grad.append(value_grad) paddle.base.libpaddle.pir.cf_yield(inputs_grad) # there are four patterns: @@ -586,8 +581,7 @@ def append_yield(block, inputs): control_flow_value_to_copyvalue_map = {} # tuple_push value to pop value control_flow_copyvalue_to_value_map = {} - # sub_block op output to parent_block op output - inside_value_to_outside_value_map = {} + if ( len(effective_forward_ops) > 1 and effective_forward_ops[-1].name() == "cf.yield" @@ -606,7 +600,6 @@ def append_yield(block, inputs): for op in inverse_effective_forward_ops: if op.name() != "builtin.combine" and op.name() != "builtin.split": clear_effective_forward_ops.append(op) - with bwd_block: for op in clear_effective_forward_ops: if 
paddle.framework.core.has_vjp(op): @@ -663,15 +656,21 @@ def append_yield(block, inputs): op.blocks(), grad_op.blocks() ): sub_state = state.copy(sub_fwd_block) + sub_inside_value_to_outside_value_map = ( + inside_value_to_outside_value_map.copy() + ) sub_backward_ops = [] append_backward_ops( op, + [input[0] for input in inputs], + [input_grad[0] for input_grad in input_grads], sub_fwd_block, sub_bwd_block, sub_fwd_block.ops, no_grad_set, sub_backward_ops, sub_state, + sub_inside_value_to_outside_value_map, ) # update input_grad map update_input_grad_map(op, input_grads, origin_inputs) @@ -706,7 +705,7 @@ def append_yield(block, inputs): if op.num_operands() == 0 and op.num_results() != 0: for value in op.results(): if len(state.value_to_valuegrad[value]) > 1: - append_add_n(bwd_block, value) + append_add_n(value) else: state.op_to_opgrad[op] = [] else: @@ -714,7 +713,7 @@ def append_yield(block, inputs): state.op_to_opgrad[op] = [] if fwd_block != bwd_block: - append_yield(bwd_block, get_used_external_value(base_op)) + append_yield(bwd_block, base_inputs, base_input_grads) def prepare_backward_prune_set(inputs, outputs): @@ -810,7 +809,12 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): inputs, complete_outputs ) + # sub_block op output to parent_block op output + inside_value_to_outside_value_map = {} + append_backward_ops( + None, + None, None, block, block, @@ -818,8 +822,8 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): no_grad_set, backward_ops, state, + inside_value_to_outside_value_map, ) - # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue) outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set( outputs_fwd_set, inputs_fwd_set, no_grad_set, state @@ -980,3 +984,57 @@ def grad( input_grad = calc_gradient(outputs, inputs, grad_outputs, no_grad_set) return input_grad + + +# only for test +def append_backward(loss, parameter_list=None, no_grad_set=None): + ''' + Parameters: + loss (Value): The loss Tensor of the network + parameter_list (Value|list(Value)|tuple(Value)): List/Tuple of Parameters + that need to be updated by optimizers. + If it is None, all parameters + will be updated. + Default: None. + no_grad_vars (Value|list(Value)|tuple(Value)|set(Value), optional): + the Values whose gradients are not needed to compute. Default None. + + Returns: + list of tuple (Value): Pairs of parameter and its corresponding gradients. + The key is the parameter and the value is gradient Tensor. 
+ ''' + + check_type( + loss, + 'loss', + (paddle.pir.Value, paddle.pir.OpResult), + 'paddle.autograd.ir_backward.append_backward', + ) + + if parameter_list is not None: + check_type( + parameter_list, + 'parameter_list', + (list, tuple, set), + 'paddle.autograd.ir_backwardappend_backward', + ) + for i, param in enumerate(parameter_list): + check_type( + param, + 'parameter_list[%s]' % i, + (paddle.pir.Value, paddle.pir.OpResult), + 'base.backward.append_backward', + ) + + else: + parameter_list = ( + loss.get_defining_op().get_parent_block().all_parameters() + ) + + inputs_grad = paddle.autograd.ir_backward.grad(loss, parameter_list) + + input_inputs_grad = [] + for input, input_grad in zip(parameter_list, inputs_grad): + input_inputs_grad.append((input, input_grad)) + + return input_inputs_grad diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py index 512f4755c6046..290199bed2dcb 100755 --- a/python/paddle/base/backward.py +++ b/python/paddle/base/backward.py @@ -2044,6 +2044,11 @@ def append_backward( >>> p_g_list6 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights, no_grad_set=set(all_weights)) """ + if framework.in_pir_mode(): + return paddle.autograd.ir_backward.append_backward( + loss, parameter_list, no_grad_set + ) + grad_op_id_to_fwd_op = ( {} ) # for cuda graph usage, recording the mapping between grad op original id to fwd op diff --git a/test/legacy_test/test_cond.py b/test/legacy_test/test_cond.py index 9eb3b575c3408..7ad0c8561eabc 100644 --- a/test/legacy_test/test_cond.py +++ b/test/legacy_test/test_cond.py @@ -22,7 +22,7 @@ from paddle import base from paddle.base import core, framework from paddle.base.backward import append_backward -from paddle.base.framework import Program, program_guard +from paddle.pir_utils import test_with_pir_api np.random.seed(123) @@ -51,9 +51,9 @@ def false_func(): shape=[3, 2], dtype='int32', value=-1 ) - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): x = paddle.tensor.fill_constant( shape=[1], dtype='float32', value=0.1 ) @@ -94,9 +94,9 @@ def true_func(): def false_func(): return paddle.full(shape=[], dtype='int32', fill_value=-1) - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): x = paddle.full(shape=[1], dtype='float32', fill_value=0.1) y = paddle.full(shape=[1], dtype='float32', fill_value=0.23) pred = paddle.greater_equal(y, x) @@ -132,9 +132,9 @@ def true_func(): def false_func(): return paddle.full(shape=[3, 3], dtype='int32', fill_value=-1) - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): x = paddle.full(shape=[], dtype='float32', fill_value=0.1) y = paddle.full(shape=[], dtype='float32', fill_value=0.23) pred = paddle.greater_equal(y, x) @@ -152,6 +152,7 @@ def false_func(): np.asarray(ret), np.full((3, 3), 2, np.int32), rtol=1e-05 ) + @test_with_pir_api def test_0d_tensor_backward(self): """ pseudocode: @@ -165,13 +166,14 @@ def test_0d_tensor_backward(self): 
paddle.enable_static() - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): a = paddle.full(shape=[], dtype='float32', fill_value=-2.0) a.stop_gradient = False + a.persistable = True out = paddle.static.nn.cond(a >= 0, lambda: a, lambda: -a) - append_backward(out) + grad_list = append_backward(out) place = ( base.CUDAPlace(0) @@ -180,7 +182,13 @@ def test_0d_tensor_backward(self): ) exe = base.Executor(place) - ret = exe.run(main_program, fetch_list=[out.name, a.grad_name]) + if paddle.framework.in_pir_mode(): + for p, g in grad_list: + if p == a: + da = g + ret = exe.run(main_program, fetch_list=[out, da]) + else: + ret = exe.run(main_program, fetch_list=[out.name, a.grad_name]) np.testing.assert_allclose( np.asarray(ret[0]), np.array(2.0), rtol=1e-05 ) @@ -239,9 +247,9 @@ def false_func(): shape=[3, 4], dtype='float32', value=3 ), paddle.tensor.fill_constant(shape=[4, 5], dtype='int64', value=2) - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): pred = paddle.tensor.fill_constant( shape=[1], dtype='bool', value=True ) @@ -284,9 +292,9 @@ def false_func(a, i): a = a - (i - 1) return a - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): a = paddle.tensor.fill_constant( shape=[3, 2, 1], dtype='int32', value=7 ) @@ -332,9 +340,9 @@ def true_func(): def false_func(): return None - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): i = paddle.static.data(name="i", shape=[1], dtype='int32') pred = (i % 2) == 0 out1 = paddle.static.nn.cond(pred, true_func, false_func) @@ -374,9 +382,9 @@ def func_return_two_tensors(): shape=[3, 1], dtype='int32', value=7 ), paddle.tensor.fill_constant(shape=[3, 1], dtype='int32', value=8) - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): i = paddle.static.data(name="i", shape=[1], dtype='int32') pred = (i % 2) == 0 with self.assertRaises(TypeError): @@ -414,6 +422,7 @@ def func_return_two_tensors(): in str(e.exception) ) + @test_with_pir_api def test_extremely_simple_net_with_op_in_condition(self): paddle.enable_static() main_program = base.Program() @@ -423,12 +432,14 @@ def test_extremely_simple_net_with_op_in_condition(self): shape=[1], dtype='float32', value=1.23 ) a.stop_gradient = False + a.persistable = True b = paddle.tensor.fill_constant( shape=[1], dtype='float32', value=1.25 ) b.stop_gradient = False + b.persistable = True out = paddle.static.nn.cond(a - b < -1.0, lambda: a, lambda: b) - append_backward(out) + grad_list = append_backward(out) place = ( base.CUDAPlace(0) @@ -436,9 +447,17 @@ def 
test_extremely_simple_net_with_op_in_condition(self): else base.CPUPlace() ) exe = base.Executor(place) - ret = exe.run( - main_program, fetch_list=[out, b, a.grad_name, b.grad_name] - ) + if paddle.framework.in_pir_mode(): + for p, g in grad_list: + if p == a: + da = g + if p == b: + db = g + ret = exe.run(main_program, fetch_list=[out, b, da, db]) + else: + ret = exe.run( + main_program, fetch_list=[out, b, a.grad_name, b.grad_name] + ) # Note: fill_constant has loss of precision, you have to assertEqual # with values doens't lose precision in float-point number. self.assertEqual(ret[0][0], ret[1][0]) @@ -447,6 +466,7 @@ def test_extremely_simple_net_with_op_in_condition(self): class TestCondNestedControlFlow(unittest.TestCase): + # @test_with_pir_api def test_cond_inside_cond(self): """ pseudocode: @@ -463,7 +483,6 @@ def test_cond_inside_cond(self): else: return a / a """ - paddle.enable_static() def less_than_branch(i, a): @@ -480,19 +499,20 @@ def greater_equal_branch(i, a): lambda: paddle.divide(a, a), ) - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): i = paddle.static.data(name="i", shape=[1], dtype='float32') i.stop_gradient = False a = 2.0 * i + a.persistable = True out = paddle.static.nn.cond( i < 5.0, lambda: less_than_branch(i, a), lambda: greater_equal_branch(i, a), ) mean = paddle.mean(out) - append_backward(mean) + grad_list = append_backward(mean) place = ( base.CUDAPlace(0) @@ -508,14 +528,25 @@ def greater_equal_branch(i, a): else: expected_ret = expected_a * expected_a if feed_i < 8 else 1.0 expected_a_grad = 2.0 * expected_a if feed_i < 8 else 0.0 - ret = exe.run( - main_program, - feed={'i': np.full((1), feed_i, np.float32)}, - fetch_list=[out.name, a.grad_name], - ) + if paddle.framework.in_pir_mode(): + for p, g in grad_list: + if p == a: + da = g + ret = exe.run( + main_program, + feed={'i': np.full((1), feed_i)}, + fetch_list=[out, da], + ) + else: + ret = exe.run( + main_program, + feed={'i': np.full((1), feed_i, np.float32)}, + fetch_list=[out.name, a.grad_name], + ) self.assertEqual(ret[0][0], expected_ret) self.assertEqual(ret[1][0], expected_a_grad) + # @test_with_pir_api def test_cond_inside_cond_0d_tensor(self): """ pseudocode: @@ -549,19 +580,21 @@ def greater_equal_branch(i, a): lambda: a / 2, ) - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): i = paddle.full(fill_value=3.0, shape=[], dtype='float32') i.stop_gradient = False + i.persistable = True a = 2.0 * i + a.persistable = True out = paddle.static.nn.cond( i < 5.0, lambda: less_than_branch(i, a), lambda: greater_equal_branch(i, a), ) mean = paddle.mean(out) - append_backward(out) + grad_list = append_backward(mean) place = ( base.CUDAPlace(0) @@ -569,10 +602,16 @@ def greater_equal_branch(i, a): else base.CPUPlace() ) exe = base.Executor(place) - ret = exe.run( - main_program, - fetch_list=[out.name, i.grad_name], - ) + if paddle.framework.in_pir_mode(): + for p, g in grad_list: + if p == i: + di = g + ret = exe.run(main_program, fetch_list=[out, di]) + else: + ret = exe.run( + main_program, + fetch_list=[out.name, i.grad_name], + ) np.testing.assert_allclose( np.asarray(ret[0]), 
np.array(7.0), rtol=1e-05 ) @@ -582,20 +621,23 @@ def greater_equal_branch(i, a): ) self.assertEqual(ret[1].shape, ()) + # @test_with_pir_api def test_cond_op_in_condition(self): paddle.enable_static() - main_program = base.Program() - startup_program = base.Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() - with base.program_guard(main_program, startup_program): + with paddle.static.program_guard(main_program, startup_program): a = paddle.tensor.fill_constant( shape=[1], dtype='float32', value=1.23 ) a.stop_gradient = False + a.persistable = True b = paddle.tensor.fill_constant( shape=[1], dtype='float32', value=1.24 ) b.stop_gradient = False + b.persistable = True out = paddle.static.nn.cond( a < b, lambda: paddle.static.nn.cond( @@ -604,12 +646,12 @@ def test_cond_op_in_condition(self): lambda: paddle.multiply(a, b), ), lambda: paddle.static.nn.cond( - a == b, + paddle.equal(a, b), lambda: paddle.subtract(a, b), lambda: paddle.pow(a, b), ), ) - append_backward(out) + grad_list = append_backward(out) place = ( base.CUDAPlace(0) @@ -617,7 +659,17 @@ def test_cond_op_in_condition(self): else base.CPUPlace() ) exe = base.Executor(place) - ret = exe.run(main_program, fetch_list=[out, a.grad_name, b.grad_name]) + if paddle.framework.in_pir_mode(): + for p, g in grad_list: + if p == a: + da = g + if p == b: + db = g + ret = exe.run(main_program, fetch_list=[out, da, db]) + else: + ret = exe.run( + main_program, fetch_list=[out, a.grad_name, b.grad_name] + ) # Note: fill_constant has loss of precision, so we assertAlmostEqual. self.assertAlmostEqual(ret[0][0], 1.5252) self.assertAlmostEqual(ret[1][0], 1.24) @@ -630,11 +682,11 @@ def backward_value_helper(self, cond_func, use_cuda): Helper function that compares calculated backward value is close to dy/dx """ paddle.enable_static() - main_program = Program() + main_program = paddle.static.Program() main_program.random_seed = 123 - startup_program = Program() + startup_program = paddle.static.Program() startup_program.random_seed = 123 - with program_guard(main_program, startup_program): + with paddle.static.program_guard(main_program, startup_program): img = paddle.static.data( name='image', shape=[-1, 9], dtype='float32' ) @@ -691,9 +743,9 @@ def add_optimizer_helper(self, cond_func, use_cuda): """ Test that program is runnable when add optimizer """ - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): img = paddle.static.data( name='image', shape=[-1, 784], dtype='float32' )
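
Reviewer note, not part of the patch: the new `append_full_like` helper consolidates the three places that previously built a `full_like`/`full` op pair by hand. It is what seeds the backward pass when no `grad_outputs` are supplied (`prepare_grad_outputs` falls back to `full_like(output, 1.0)`) and what stands in for missing input grads with `full_like(value, 0.0)`. A minimal dynamic-graph analogy of the default seeding, with made-up tensor names:

```python
import paddle

# Toy dynamic-graph analogy (not the patch's code path): when no grad_outputs
# are given, the backward pass is seeded with ones, which is what
# prepare_grad_outputs builds via full_like(output, 1.0) in the PIR backward.
x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
y = 3.0 * x

(dy_dx_default,) = paddle.grad([y], [x], retain_graph=True)
(dy_dx_explicit,) = paddle.grad([y], [x], grad_outputs=[paddle.ones_like(y)])

# Both are [3., 3., 3.]: the default seed equals an explicit all-ones seed.
print(dy_dx_default, dy_dx_explicit)
```

Passing an explicit all-ones `grad_outputs` is equivalent to the default, which is the `full_like(..., 1.0)` path the helper now encapsulates.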
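
In the same spirit, `append_add_n` encodes the rule that a value consumed by more than one forward op accumulates one partial gradient per consumer; the patch now locates the combine and `add_n` ops through `get_defining_op()` instead of indexing `block.ops` from the end. A toy dynamic-graph illustration of the accumulation semantics (tensor names are illustrative):

```python
import paddle

# Toy illustration of the accumulation rule append_add_n implements: a value
# consumed by two forward ops receives one partial gradient per consumer, and
# the final gradient is their sum (a combine + add_n in the backward block).
x = paddle.to_tensor(3.0, stop_gradient=False)
y = 2.0 * x   # first consumer:  dy/dx = 2
z = x * x     # second consumer: dz/dx = 2 * x = 6
loss = y + z
loss.backward()

print(x.grad)  # 8.0 == 2 + 6, the accumulated gradient
```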
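
Finally, a self-contained sketch of how the test-only `paddle.autograd.ir_backward.append_backward` entry point and the PIR branch added to `paddle.base.backward.append_backward` are meant to be consumed, mirroring the fetch pattern of the updated `test_cond.py` cases. The program, the parameter `w`, and the loss below are assumptions for illustration, not code from the patch:

```python
import numpy as np
import paddle
from paddle.base.backward import append_backward

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    # A trainable parameter and a scalar loss (names are illustrative only).
    w = paddle.static.create_parameter(shape=[1], dtype='float32')
    loss = paddle.mean(3.0 * w)

    # Under PIR this dispatches to paddle.autograd.ir_backward.append_backward
    # and returns (parameter, gradient) Value pairs; the legacy branch keeps
    # the old name-based workflow.
    grad_list = append_backward(loss)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_program)

if paddle.framework.in_pir_mode():
    # PIR: look up the gradient Value paired with `w` and fetch it directly.
    dw = next(g for p, g in grad_list if p == w)
    ret = exe.run(main_program, fetch_list=[loss, dw])
else:
    # Legacy static graph: fetch the gradient by its variable name.
    ret = exe.run(main_program, fetch_list=[loss.name, w.grad_name])

print(np.asarray(ret[0]), np.asarray(ret[1]))  # d(loss)/dw == 3.0
```

The difference the tests exercise is that under PIR the gradient is located by matching the parameter Value in the returned `(parameter, gradient)` pairs and fetched directly, while the legacy path still fetches by `grad_name` string.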