From c7c81800380467a0bbcc5dfde9766b86113323a0 Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:18:49 -0800 Subject: [PATCH] Fix `PyDatasetAdapterTest::test_class_weight` test with Torch on GPU. The test was failing because arrays on device and on cpu were compared. --- keras/src/trainers/data_adapters/py_dataset_adapter_test.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py index 65ef5b8de60..d451bf5409e 100644 --- a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py +++ b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py @@ -240,12 +240,11 @@ def test_class_weight(self): for index, batch in enumerate(gen): # Batch is a tuple of (x, y, class_weight) self.assertLen(batch, 3) + batch = [backend.convert_to_numpy(x) for x in batch] # Let's verify the data and class weights match for each element # of the batch (2 elements in each batch) for sub_elem in range(2): - self.assertTrue( - np.array_equal(batch[0][sub_elem], x[index * 2 + sub_elem]) - ) + self.assertAllEqual(batch[0][sub_elem], x[index * 2 + sub_elem]) self.assertEqual(batch[1][sub_elem], y[index * 2 + sub_elem]) class_key = np.int32(batch[1][sub_elem]) self.assertEqual(batch[2][sub_elem], class_w[class_key])