Skip to content

Commit

Permalink
[PIR] skip bad case in test_cache_program.py and revert some changes (#…
Browse files Browse the repository at this point in the history
…59697)

* fix

* fix
  • Loading branch information
kangguangli authored Dec 7, 2023
1 parent 99e84f0 commit 0835df8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
4 changes: 4 additions & 0 deletions test/dygraph_to_static/test_cache_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
IrMode,
ToStaticMode,
disable_test_case,
enable_to_static_guard,
test_ast_only,
test_legacy_and_pt_and_pir,
Expand Down Expand Up @@ -172,6 +175,7 @@ def sum_under_while(limit):
return ret_sum


@disable_test_case((ToStaticMode.AST, IrMode.PT))
class TestToOutputWithCache(Dy2StTestBase):
def test_output(self):
ret = paddle.jit.to_static(sum_even_until_limit)(80, 10)
Expand Down
15 changes: 11 additions & 4 deletions test/dygraph_to_static/test_write_python_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@

import unittest

from dygraph_to_static_utils import (
Dy2StTestBase,
test_sot_only,
)
from dygraph_to_static_utils import Dy2StTestBase, test_ast_only, test_sot_only

import paddle

Expand Down Expand Up @@ -123,6 +120,16 @@ def test_write_container_sot(self):
out_dygraph = self.get_raw_value(self.func(input), self.getitem_path)
self.assertEqual(out_static, out_dygraph)

@test_ast_only
def test_write_container(self):
func_static = paddle.jit.to_static(self.func)
input = paddle.to_tensor([1, 2, 3])
out_static = self.get_raw_value(
func_static(input), self.getitem_path
).item()
out_dygraph = self.get_raw_value(self.func(input), self.getitem_path)
self.assertEqual(out_static, out_dygraph)


class TestLoopWriteContainerList(TestWriteContainer):
def set_func(self):
Expand Down

0 comments on commit 0835df8

Please sign in to comment.