From 5d7e3d0176e0dbcf144c64b7d14d996c55e36c50 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 8 Jun 2024 20:50:14 -0700 Subject: [PATCH] [mis][ci/test] fix flaky test in test_sharded_state_loader.py (#5361) [mis][ci/test] fix flaky test in tests/test_sharded_state_loader.py (#5361) --- tests/test_sharded_state_loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index 022fb36b346f4..de79c3b945d4d 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -39,7 +39,8 @@ def test_filter_subtensors(): filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") for key, tensor in filtered_state_dict.items(): - assert tensor.equal(state_dict[key]) + # NOTE: don't use `euqal` here, as the tensor might contain NaNs + assert tensor is state_dict[key] @pytest.fixture(scope="module")