Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2Stat]Enhance nonlocal machanism while returning single var #43957

Merged
merged 6 commits into from
Jul 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,13 @@ def convert_while_loop(cond, body, getter, setter):

def _run_paddle_while(cond, body, getter, setter):
# NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
def to_list(x):
if isinstance(x, (tuple, list)): return x
return [x]

# UndefinedVar will become data layer not check.
loop_vars = [to_static_variable(var) for var in to_list(getter())]
setter(loop_vars if len(loop_vars) > 1 else
loop_vars[0]) # change the non-local var to variable
loop_vars = [to_static_variable(var) for var in getter()]
setter(loop_vars) # change the non-local var to variable
# variable maybe modified to inner var. change it into
loop_vars = control_flow.while_loop(cond, body, loop_vars)
setter(loop_vars if len(loop_vars) > 1 else
loop_vars[0]) # change the non-local var to variable
setter(loop_vars) # change the non-local var to variable
return loop_vars


Expand Down Expand Up @@ -318,11 +313,11 @@ def _recover_args_state(outs, get_args, set_args, return_name_ids):
init_args = get_args()
# recover args state
num_outs = len(return_name_ids)
num_args = 1 if not isinstance(init_args, tuple) else len(init_args)
num_args = len(init_args)
assert num_outs <= num_args

if num_args == 1:
final_outs = outs
final_outs = (outs, )
else:
outs = (outs, ) if num_outs == 1 else outs
final_outs = outs + init_args[num_outs:]
Expand Down
12 changes: 5 additions & 7 deletions python/paddle/fluid/dygraph/dygraph_to_static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
PADDLE_MODULE_PREFIX = 'paddle.'
DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph'
DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.fluid.dygraph.dygraph_to_static'
GET_ARGS_FUNC_PREFIX = 'get_args'
SET_ARGS_FUNC_PREFIX = 'set_args'
ARGS_NAME = '__args'


class BaseNodeVisitor(gast.NodeVisitor):
Expand Down Expand Up @@ -1619,7 +1622,7 @@ def {func_name}():
template = """
def {func_name}():
nonlocal {nonlocal_vars}
return {vars}
return {vars},
"""
func_def = template.format(
func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX),
Expand All @@ -1628,11 +1631,6 @@ def {func_name}():
return gast.parse(textwrap.dedent(func_def)).body[0]


GET_ARGS_FUNC_PREFIX = 'get_args'
SET_ARGS_FUNC_PREFIX = 'set_args'
ARGS_NAME = '__args'


def create_set_args_node(names):
"""
Create set_args function as follows:
Expand Down Expand Up @@ -1661,7 +1659,7 @@ def {func_name}({args}):
template = """
def {func_name}({args}):
nonlocal {nonlocal_vars}
{vars} = {args}
{vars}, = {args}
"""
func_def = template.format(
func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def dyfunc_with_if_else(x_v, label=None):

def get_args_0():
nonlocal x_v
return x_v
return x_v,

def set_args_0(__args):
nonlocal x_v
x_v = __args
x_v, = __args

def true_fn_0():
nonlocal x_v
Expand All @@ -96,11 +96,11 @@ def false_fn_0():

def get_args_1():
nonlocal __return_value_0, label, x_v
return __return_value_0, label, x_v
return __return_value_0, label, x_v,

def set_args_1(__args):
nonlocal __return_value_0, label, x_v
__return_value_0, label, x_v = __args
__return_value_0, label, x_v, = __args

def true_fn_1():
nonlocal __return_value_0, label, x_v
Expand Down Expand Up @@ -131,11 +131,11 @@ def dyfunc_with_if_else(x_v, label=None):

def get_args_2():
nonlocal x_v
return x_v
return x_v,

def set_args_2(__args):
nonlocal x_v
x_v = __args
x_v, = __args

def true_fn_2():
nonlocal x_v
Expand All @@ -153,11 +153,11 @@ def false_fn_2():

def get_args_3():
nonlocal __return_value_1, label, x_v
return __return_value_1, label, x_v
return __return_value_1, label, x_v,

def set_args_3(__args):
nonlocal __return_value_1, label, x_v
__return_value_1, label, x_v = __args
__return_value_1, label, x_v, = __args

def true_fn_3():
nonlocal __return_value_1, label, x_v
Expand Down