diff --git a/src/accelerate/utils.py b/src/accelerate/utils.py index e25b80c4976..28a38539adb 100644 --- a/src/accelerate/utils.py +++ b/src/accelerate/utils.py @@ -201,7 +201,7 @@ def _send_to_device(t, device): def _has_to_method(t): return hasattr(t, "to") - return recursively_apply(_send_to_device, tensor, device, test_type=_has_to_method, error_on_other_type=True) + return recursively_apply(_send_to_device, tensor, device, test_type=_has_to_method) def get_data_structure(data): diff --git a/tests/test_utils.py b/tests/test_utils.py index ca617634dd9..9b16aba7bed 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,7 +20,7 @@ from accelerate.utils import send_to_device -TestNamedTuple = namedtuple("TestNamedTuple", "a b") +TestNamedTuple = namedtuple("TestNamedTuple", "a b c") class UtilsTester(unittest.TestCase): @@ -31,23 +31,26 @@ def test_send_to_device(self): result1 = send_to_device(tensor, device) self.assertTrue(torch.equal(result1.cpu(), tensor)) - result2 = send_to_device((tensor, [tensor, tensor]), device) + result2 = send_to_device((tensor, [tensor, tensor], 1), device) self.assertIsInstance(result2, tuple) self.assertTrue(torch.equal(result2[0].cpu(), tensor)) self.assertIsInstance(result2[1], list) self.assertTrue(torch.equal(result2[1][0].cpu(), tensor)) self.assertTrue(torch.equal(result2[1][1].cpu(), tensor)) + self.assertEqual(result2[2], 1) - result2 = send_to_device({"a": tensor, "b": [tensor, tensor]}, device) + result2 = send_to_device({"a": tensor, "b": [tensor, tensor], "c": 1}, device) self.assertIsInstance(result2, dict) self.assertTrue(torch.equal(result2["a"].cpu(), tensor)) self.assertIsInstance(result2["b"], list) self.assertTrue(torch.equal(result2["b"][0].cpu(), tensor)) self.assertTrue(torch.equal(result2["b"][1].cpu(), tensor)) + self.assertEqual(result2["c"], 1) - result3 = send_to_device(TestNamedTuple(a=tensor, b=[tensor, tensor]), device) + result3 = send_to_device(TestNamedTuple(a=tensor, b=[tensor, tensor], c=1), device) self.assertIsInstance(result3, TestNamedTuple) self.assertTrue(torch.equal(result3.a.cpu(), tensor)) self.assertIsInstance(result3.b, list) self.assertTrue(torch.equal(result3.b[0].cpu(), tensor)) self.assertTrue(torch.equal(result3.b[1].cpu(), tensor)) + self.assertEqual(result3.c, 1)