Skip to content

Commit

Permalink
Fix PyDatasetAdapterTest::test_class_weight test with Torch on GPU. (
Browse files Browse the repository at this point in the history
…keras-team#20665)

The test was failing because arrays on device and on cpu were compared.
  • Loading branch information
hertschuh authored and shashaka committed Dec 20, 2024
1 parent b78fd9a commit aa72cb8
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions keras/src/trainers/data_adapters/py_dataset_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,11 @@ def test_class_weight(self):
for index, batch in enumerate(gen):
# Batch is a tuple of (x, y, class_weight)
self.assertLen(batch, 3)
batch = [backend.convert_to_numpy(x) for x in batch]
# Let's verify the data and class weights match for each element
# of the batch (2 elements in each batch)
for sub_elem in range(2):
self.assertTrue(
np.array_equal(batch[0][sub_elem], x[index * 2 + sub_elem])
)
self.assertAllEqual(batch[0][sub_elem], x[index * 2 + sub_elem])
self.assertEqual(batch[1][sub_elem], y[index * 2 + sub_elem])
class_key = np.int32(batch[1][sub_elem])
self.assertEqual(batch[2][sub_elem], class_w[class_key])
Expand Down

0 comments on commit aa72cb8

Please sign in to comment.