[PIR] Modify 5/6 test cases of test_cond.py with append_backward (#59732)
* initial modification

* clean up the changes

* update if_grad2

* add append_full_like helper

* add a new test

* update add_n handling
xiaoguoguo626807 authored Dec 7, 2023
1 parent ee6d976 commit 99e84f0
Showing 3 changed files with 235 additions and 120 deletions.
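For context, here is a minimal sketch of the kind of test_cond.py case being migrated (an illustration inferred from the commit title, not code taken from this commit): a static program containing paddle.static.nn.cond whose output is differentiated with paddle.static.append_backward.

import paddle

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    i = paddle.full(shape=[1], dtype='float32', fill_value=2.0)
    img = paddle.static.data('image', [16, 784], 'float32')
    hidden = paddle.static.nn.fc(img, size=10)  # introduces trainable parameters
    # cond picks a branch at runtime, so its backward needs per-branch grad blocks
    out = paddle.static.nn.cond(
        i < 3.0,
        lambda: paddle.mean(hidden),
        lambda: paddle.mean(hidden * 2.0),
    )
    # with the PIR flag enabled, this call reaches the new code in this commit
    params_and_grads = paddle.static.append_backward(out)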
176 changes: 117 additions & 59 deletions python/paddle/autograd/ir_backward.py
@@ -55,6 +55,23 @@ def check_all_puts(block, inputs, outputs):
)


def append_full_like(float_value, value, state, backward_ops):
value_grad = paddle.full_like(
value,
float_value,
dtype=value.dtype,
)
full_like_op = value_grad.get_defining_op()
full_op = full_like_op.operand_source(1).get_defining_op()
update_bwdop_structure(
backward_ops,
state.op_to_opgrad[value.get_defining_op()],
[full_like_op, full_op],
)
state.value_to_valuegrad[value] = [[value_grad]]
return value_grad


def get_real_op_inputs(op):
if op.name() in ["pd_op.if", "pd_op.while"]:
return get_used_external_value(op)
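The new append_full_like helper above factors out a pattern repeated below: in PIR, paddle.full_like lowers to a full_like op whose fill-value operand is produced by a separate full op, and both ops have to be recorded as backward ops. A hedged sketch of that relationship (assumes the PIR flag is enabled externally; the program is illustrative):

import paddle
from paddle.base import framework

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data('x', [2, 3], 'float32')
    g = paddle.full_like(x, 0.0, dtype=x.dtype)
    if framework.in_pir_mode():  # op introspection below only exists in PIR
        full_like_op = g.get_defining_op()                           # the full_like op
        full_op = full_like_op.operand_source(1).get_defining_op()   # produces the fill value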
@@ -101,19 +118,7 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
# fwd : op1 -> op2 -> op3 -> output
# bwd : op1G <- op2G <- op3G <- outputG <- full_likeop/feedop
if grad is None:
output_grad = paddle.full_like(
output,
1.0,
dtype=output.dtype,
)
full_likeop = output_grad.get_defining_op()
fullop = full_likeop.operand_source(1).get_defining_op()
update_bwdop_structure(
backward_ops,
state.op_to_opgrad[output.get_defining_op()],
[full_likeop, fullop],
)
state.value_to_valuegrad[output] = [[output_grad]]
append_full_like(1.0, output, state, backward_ops)
else:
if output.shape != grad.shape:
raise ValueError(
@@ -150,21 +155,9 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
[paddle.pir.fake_op_result()]
]
else:
grad_value = paddle.full_like(
opresult,
0.0,
opresult.dtype,
grad_value = append_full_like(
0.0, opresult, state, backward_ops
)
full_likeop = grad_value.get_defining_op()
fullop = full_likeop.operand_source(1).get_defining_op()

update_bwdop_structure(
backward_ops,
state.op_to_opgrad[opresult.get_defining_op()],
[full_likeop, fullop],
)
state.value_to_valuegrad[opresult] = [[grad_value]]

visited_output.add(opresult)

complete_outputs.append(opresult)
@@ -336,12 +329,15 @@ def inverse_sort_op(ops):

def append_backward_ops(
base_op,
base_inputs,
base_input_grads,
fwd_block,
bwd_block,
effective_forward_ops,
no_grad_set,
backward_ops,
state,
inside_value_to_outside_value_map,
):
'''
add grad_op in order of topological inverse sort
@@ -383,20 +379,22 @@ def append_backward_ops(
else continue to next op.
'''

def append_add_n(block, value):
def append_add_n(value):
        # when one value is the input of more than one fwd_op,
        # more than one bwd_op creates a grad for it,
        # so an add_n op is needed to accumulate the gradients
paddle.add_n([item[0] for item in state.value_to_valuegrad[value]])
combineop = block.ops[len(block.ops) - 2]
sumop = block.ops[len(block.ops) - 1]
add_n_value = paddle.add_n(
[item[0] for item in state.value_to_valuegrad[value]]
)
add_n_op = add_n_value.get_defining_op()
combine_op = add_n_op.operand_source(0).get_defining_op()
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], [combineop, sumop]
backward_ops, state.op_to_opgrad[op], [combine_op, add_n_op]
)

state.value_to_valuegrad[value] = [[sumop.result(0)]]
for tmp in state.value_to_valuegrad[value]:
state.value_to_sumvaluegrad[value].append(tmp)
state.value_to_valuegrad[value] = [[add_n_value]]

def make_output_with_output_grad(op):
zero_flag = [False] * op.num_results()
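The rewritten append_add_n above no longer fishes the combine and add_n ops out of block.ops by position; it reaches them through the defining op of the summed value, which stays correct if other ops are inserted nearby. A hedged sketch of the accumulation it performs (illustrative values, PIR flag assumed enabled externally):

import paddle
from paddle.base import framework

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data('x', [2, 3], 'float32')
    # pretend x fed two forward ops, leaving two partial gradients to merge
    partial_grads = [paddle.full_like(x, 1.0), paddle.full_like(x, 2.0)]
    total_grad = paddle.add_n(partial_grads)
    if framework.in_pir_mode():  # PIR-only introspection
        add_n_op = total_grad.get_defining_op()
        combine_op = add_n_op.operand_source(0).get_defining_op()  # builtin.combine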
@@ -408,14 +406,14 @@ def make_output_with_output_grad(op):
if value in control_flow_value_to_copyvalue_map
else [value]
)
if value in inside_value_to_outside_value_map:
while value in inside_value_to_outside_value_map:
value = inside_value_to_outside_value_map[value]

if (
value in state.value_to_valuegrad
and len(state.value_to_valuegrad[value]) > 1
):
append_add_n(bwd_block, value)
append_add_n(value)

if (
value not in state.value_to_valuegrad
@@ -445,23 +443,10 @@ def make_output_with_output_grad(op):
# second case:
# last bwd_op return None because input in no_grad_set,
# but this bwd_op need a input.
grad_value = paddle.full_like(
value,
0.0,
dtype=value.dtype,
)
full_likeop = grad_value.get_defining_op()
fullop = full_likeop.operand_source(1).get_defining_op()

update_bwdop_structure(
backward_ops,
state.op_to_opgrad[op],
[full_likeop, fullop],
)
append_full_like(0.0, value, state, backward_ops)
zero_flag[i] = True

state.value_to_valuegrad[value] = [[grad_value]]

outputs.append(new_value)
output_grads.append(state.value_to_valuegrad[value][0])

@@ -564,15 +549,25 @@ def update_input_grad_map(op, input_grads, origin_inputs):
state.value_to_valuegrad[input].append([input_grad])
i += 1

def append_yield(block, inputs):
def append_yield(block, base_inputs, base_inputs_grad):
with block:
inputs_grad = []
for value in inputs:
for value, value_grad in zip(base_inputs, base_inputs_grad):
if value_grad is None:
continue

while value in inside_value_to_outside_value_map:
value = inside_value_to_outside_value_map[value]

if value in state.value_to_valuegrad:
if len(state.value_to_valuegrad[value]) > 1:
append_add_n(block, value)
append_add_n(value)
inputs_grad.append(state.value_to_valuegrad[value][0][0])

else:
value_grad = append_full_like(
0.0, value, state, backward_ops
)
inputs_grad.append(value_grad)
paddle.base.libpaddle.pir.cf_yield(inputs_grad)

# there are four patterns:
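Both make_output_with_output_grad and append_yield above now walk inside_value_to_outside_value_map with a while loop rather than a single lookup, so a value defined in a nested sub-block is resolved through every enclosing block before its gradient is fetched. A toy illustration in plain Python (hypothetical names, not part of the commit):

# a result of an inner if-block maps to the enclosing if-block's result,
# which maps to the parent block's result, so one dict lookup is not
# enough -- the map must be followed transitively
inside_value_to_outside_value_map = {
    "inner_if_out": "outer_if_out",
    "outer_if_out": "parent_out",
}

value = "inner_if_out"
while value in inside_value_to_outside_value_map:
    value = inside_value_to_outside_value_map[value]

assert value == "parent_out"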
@@ -586,8 +581,7 @@ def append_yield(block, inputs):
control_flow_value_to_copyvalue_map = {}
# tuple_push value to pop value
control_flow_copyvalue_to_value_map = {}
# sub_block op output to parent_block op output
inside_value_to_outside_value_map = {}

if (
len(effective_forward_ops) > 1
and effective_forward_ops[-1].name() == "cf.yield"
@@ -606,7 +600,6 @@ def append_yield(block, inputs):
for op in inverse_effective_forward_ops:
if op.name() != "builtin.combine" and op.name() != "builtin.split":
clear_effective_forward_ops.append(op)

with bwd_block:
for op in clear_effective_forward_ops:
if paddle.framework.core.has_vjp(op):
@@ -663,15 +656,21 @@ def append_yield(block, inputs):
op.blocks(), grad_op.blocks()
):
sub_state = state.copy(sub_fwd_block)
sub_inside_value_to_outside_value_map = (
inside_value_to_outside_value_map.copy()
)
sub_backward_ops = []
append_backward_ops(
op,
[input[0] for input in inputs],
[input_grad[0] for input_grad in input_grads],
sub_fwd_block,
sub_bwd_block,
sub_fwd_block.ops,
no_grad_set,
sub_backward_ops,
sub_state,
sub_inside_value_to_outside_value_map,
)
# update input_grad map
update_input_grad_map(op, input_grads, origin_inputs)
@@ -706,15 +705,15 @@ def append_yield(block, inputs):
if op.num_operands() == 0 and op.num_results() != 0:
for value in op.results():
if len(state.value_to_valuegrad[value]) > 1:
append_add_n(bwd_block, value)
append_add_n(value)
else:
state.op_to_opgrad[op] = []
else:
logging.warning("%s op has no grad op", op.name())
state.op_to_opgrad[op] = []

if fwd_block != bwd_block:
append_yield(bwd_block, get_used_external_value(base_op))
append_yield(bwd_block, base_inputs, base_input_grads)


def prepare_backward_prune_set(inputs, outputs):
@@ -810,16 +809,21 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
inputs, complete_outputs
)

# sub_block op output to parent_block op output
inside_value_to_outside_value_map = {}

append_backward_ops(
None,
None,
None,
block,
block,
effective_forward_ops,
no_grad_set,
backward_ops,
state,
inside_value_to_outside_value_map,
)

    # now value_to_valuegrad should be value <-> value_grad (an add_n op is added for values with multiple grad values)
outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set(
outputs_fwd_set, inputs_fwd_set, no_grad_set, state
@@ -980,3 +984,57 @@ def grad(
input_grad = calc_gradient(outputs, inputs, grad_outputs, no_grad_set)

return input_grad


# only for test
def append_backward(loss, parameter_list=None, no_grad_set=None):
'''
Parameters:
loss (Value): The loss Tensor of the network
parameter_list (Value|list(Value)|tuple(Value)): List/Tuple of Parameters
that need to be updated by optimizers.
If it is None, all parameters
will be updated.
Default: None.
        no_grad_set (Value|list(Value)|tuple(Value)|set(Value), optional):
            the Values whose gradients do not need to be computed. Default: None.
Returns:
        list of tuple(Value, Value): pairs of parameters and their corresponding gradients.
        The first element of each pair is the parameter and the second is its gradient Tensor.
'''

check_type(
loss,
'loss',
(paddle.pir.Value, paddle.pir.OpResult),
'paddle.autograd.ir_backward.append_backward',
)

if parameter_list is not None:
check_type(
parameter_list,
'parameter_list',
(list, tuple, set),
            'paddle.autograd.ir_backward.append_backward',
)
for i, param in enumerate(parameter_list):
check_type(
param,
'parameter_list[%s]' % i,
(paddle.pir.Value, paddle.pir.OpResult),
'base.backward.append_backward',
)

else:
parameter_list = (
loss.get_defining_op().get_parent_block().all_parameters()
)

inputs_grad = paddle.autograd.ir_backward.grad(loss, parameter_list)

input_inputs_grad = []
for input, input_grad in zip(parameter_list, inputs_grad):
input_inputs_grad.append((input, input_grad))

return input_inputs_grad
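A hedged usage sketch of the test-only append_backward defined above (assumes a PIR program is being built, e.g. with the PIR flag enabled by the test harness; the network itself is illustrative):

import paddle

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', [4, 8], 'float32')
    loss = paddle.mean(paddle.static.nn.fc(x, size=2))
    # only meaningful in PIR mode; returns [(parameter, parameter_grad), ...]
    params_and_grads = paddle.autograd.ir_backward.append_backward(loss)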
5 changes: 5 additions & 0 deletions python/paddle/base/backward.py
@@ -2044,6 +2044,11 @@ def append_backward(
>>> p_g_list6 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights, no_grad_set=set(all_weights))
"""
if framework.in_pir_mode():
return paddle.autograd.ir_backward.append_backward(
loss, parameter_list, no_grad_set
)

grad_op_id_to_fwd_op = (
{}
) # for cuda graph usage, recording the mapping between grad op original id to fwd op
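With this dispatch, existing static-graph tests keep calling paddle.static.append_backward unchanged and are routed to the PIR implementation when the flag is on. A hedged sketch of that call path (the flag itself is set by the test environment, not by this commit):

import paddle
from paddle.base import framework

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data('x', [4, 8], 'float32')
    loss = paddle.mean(paddle.static.nn.fc(x, size=2))
    params_and_grads = paddle.static.append_backward(loss)
    # same (parameter, gradient) pair structure on either path, so existing
    # tests keep their assertions; only the branch taken differs
    if framework.in_pir_mode():
        print('routed to paddle.autograd.ir_backward.append_backward')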
(diff for the third changed file is not shown here)
