Skip to content

Commit

Permalink
[Dy2St][PIR] Create fake value for each UndefinedVar in for-loop (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Jan 29, 2024
1 parent b97ce0b commit b88fd78
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 12 deletions.
2 changes: 1 addition & 1 deletion python/paddle/jit/dy2static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def __init__(self, var, start=0):
self.var = var
self.len = convert_len(var)
if isinstance(self.len, (Variable, Value)):
self.rag = paddle.arange(start, start + self.len, 1, paddle.int64)
self.rag = paddle.arange(start, start + self.len, 1, "int64")
else:
self.rag = range(start, start + self.len)

Expand Down
28 changes: 26 additions & 2 deletions python/paddle/static/nn/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,13 +750,31 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
)

if in_pir_mode():
while_op = build_while_op(pre_cond, flatten(loop_vars))
from paddle.jit.dy2static.utils import UndefinedVar

def create_fake_value_for_undefined_var():
# Create a fake value for create WhileOp, it's type will be reset after body is executed.
return paddle.full(shape=[], fill_value=0)

flattened_loop_vars = flatten(loop_vars)

undefined_var_mapping = {
idx: create_fake_value_for_undefined_var()
for idx, var in enumerate(flattened_loop_vars)
if isinstance(var, UndefinedVar)
}
unified_loop_vars = [
undefined_var_mapping[idx] if isinstance(var, UndefinedVar) else var
for idx, var in enumerate(flattened_loop_vars)
]
while_op = build_while_op(pre_cond, unified_loop_vars)
with while_op.body() as cur_block:
args = pack_sequence_as(loop_vars, cur_block.args())
next_vars = body(*args)

try:
assert_same_structure(
flatten(next_vars), flatten(loop_vars), check_types=False
flatten(next_vars), unified_loop_vars, check_types=False
)
except ValueError as e:
raise ValueError(
Expand All @@ -768,6 +786,12 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
next_cond = cond(*next_vars)
next_cond.stop_gradient = True
cf_yield([next_cond, *flatten(next_vars)])
# Reset type of UndefinedVar from next_vars
for idx, value in undefined_var_mapping.items():
value_new_type = flatten(next_vars)[idx].type()
value.set_type(value_new_type)
cur_block.args()[idx].set_type(value_new_type)
while_op.as_operation().results()[idx].set_type(value_new_type)
return pack_sequence_as(loop_vars, while_op.optimize_update())

if in_dygraph_mode():
Expand Down
6 changes: 1 addition & 5 deletions test/dygraph_to_static/test_for_enumerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def set_test_func(self):


class TestForInRange(TestForInRangeConfig):
@test_legacy_and_pt_and_pir
def test_transformed_result_compare(self):
self.set_test_func()
self.transformed_result_compare()
Expand Down Expand Up @@ -484,11 +485,6 @@ class TestForEnumerateVarWithNestedRange(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_enumerate_var_with_nested_range

# Remove this if we support control flow
def test_transformed_result_compare(self):
self.set_test_func()
self.transformed_result_compare()


class TestForIterVarList(TestForInRangeConfig):
def set_test_func(self):
Expand Down
4 changes: 0 additions & 4 deletions test/dygraph_to_static/test_ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,6 @@ def _run_dygraph(self, to_static=False):
ret = self.dyfunc(x_v)
return ret.numpy()

# Why add test_legacy_only? : PIR not support if true and false branch output with different dtype
@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
self.assertTrue((self._run_dygraph() == self._run_static()).all())
Expand Down Expand Up @@ -338,7 +337,6 @@ def _run(self, to_static=False):
ret = net(x_v)
return ret.numpy()

# Why add test_legacy_only? : PIR not support if true and false branch output with different rank
@test_legacy_only
def test_ast_to_func(self):
self.assertTrue((self._run_dygraph() == self._run_static()).all())
Expand Down Expand Up @@ -510,7 +508,6 @@ def get_dy2stat_out(self):
out = static_func(self.x)
return out

# Why add test_legacy_only? : PIR not support if true and false branch output with different rank
@test_ast_only
@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
Expand All @@ -525,7 +522,6 @@ def setUp(self):
self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int3)
self.out = self.get_dy2stat_out()

# Why add test_legacy_only? : PIR not support if true and false branch output with different rank
@test_ast_only
@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
Expand Down

0 comments on commit b88fd78

Please sign in to comment.