Skip to content

Commit

Permalink
test(DataGenerator): added iteration & repetition testing as well as …
Browse files Browse the repository at this point in the history
…batch access
  • Loading branch information
muellerdo committed Feb 26, 2023
1 parent 2e8f2a9 commit 5568cfa
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 23 deletions.
38 changes: 27 additions & 11 deletions tests/test_datagenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_RUN_2D_GRAYSCALE_noLabel(self):
data_gen = DataGenerator(self.sampleList_gray_2D, self.tmp_data.name,
grayscale=True, batch_size=5)
for i in range(0, 10):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(len(batch), 1)
self.assertTrue(np.array_equal(batch[0].shape, (5, 224, 224, 1)))

Expand All @@ -113,7 +113,7 @@ def test_RUN_2D_RGB_noLabel(self):
data_gen = DataGenerator(self.sampleList_rgb_2D, self.tmp_data.name,
grayscale=False, batch_size=5)
for i in range(0, 10):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(len(batch), 1)
self.assertTrue(np.array_equal(batch[0].shape, (5, 224, 224, 3)))

Expand All @@ -123,7 +123,7 @@ def test_RUN_2D_withLabel(self):
labels=self.labels_ohe,
grayscale=False, batch_size=5)
for i in range(0, 10):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(len(batch), 2)
self.assertTrue(np.array_equal(batch[1].shape, (5, 4)))

Expand All @@ -137,7 +137,7 @@ def test_RUN_3D_GRAYSCALE_noLabel(self):
loader=numpy_loader, resize=None,
standardize_mode=None)
for i in range(0, 10):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(len(batch), 1)
self.assertTrue(np.array_equal(batch[0].shape, (5, 16, 16, 16, 1)))

Expand All @@ -148,7 +148,7 @@ def test_RUN_3D_RGB_noLabel(self):
loader=numpy_loader, resize=None,
standardize_mode=None)
for i in range(0, 10):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(len(batch), 1)
self.assertTrue(np.array_equal(batch[0].shape, (5, 16, 16, 16, 3)))

Expand All @@ -160,7 +160,7 @@ def test_RUN_3D_withLabel(self):
loader=numpy_loader, resize=None,
standardize_mode=None)
for i in range(0, 10):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(len(batch), 2)
self.assertTrue(np.array_equal(batch[1].shape, (5, 4)))

Expand All @@ -173,7 +173,7 @@ def test_RUN_Metadata_noLabel(self):
metadata=self.metadata, grayscale=False,
batch_size=5)
for i in range(0, 10):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(len(batch), 1)
self.assertTrue(len(batch[0]) == 2)
self.assertTrue(np.array_equal(batch[0][0].shape, (5, 224, 224, 3)))
Expand All @@ -185,7 +185,7 @@ def test_RUN_Metadata_withLabel(self):
labels=self.labels_ohe, metadata=self.metadata,
grayscale=False, batch_size=5)
for i in range(0, 10):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(len(batch), 2)
self.assertTrue(np.array_equal(batch[1].shape, (5, 4)))
self.assertTrue(len(batch[0]) == 2)
Expand All @@ -200,7 +200,7 @@ def test_MP(self):
labels=self.labels_ohe,
grayscale=False, batch_size=5, workers=5)
for i in range(0, 10):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(len(batch), 2)
self.assertTrue(np.array_equal(batch[1].shape, (5, 4)))

Expand All @@ -214,7 +214,7 @@ def test_PrepareImages(self):
precprocessed_images = os.listdir(data_gen.prepare_dir)
self.assertTrue(len(precprocessed_images), len(self.sampleList_rgb_2D))
for i in range(0, 10):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(len(batch), 2)
self.assertTrue(np.array_equal(batch[1].shape, (5, 4)))
shutil.rmtree(data_gen.prepare_dir)
Expand All @@ -226,7 +226,23 @@ def test_PrepareImages_MP(self):
precprocessed_images = os.listdir(data_gen.prepare_dir)
self.assertTrue(len(precprocessed_images), len(self.sampleList_rgb_2D))
for i in range(0, 10):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(len(batch), 2)
self.assertTrue(np.array_equal(batch[1].shape, (5, 4)))
shutil.rmtree(data_gen.prepare_dir)

#-------------------------------------------------#
# Utilization #
#-------------------------------------------------#
# Class Creation
def test_utils_iter(self):
data_gen = DataGenerator(self.sampleList_rgb_2D, self.tmp_data.name,
batch_size=8)
counter = 0
for batch in data_gen:
if counter < 3:
self.assertTrue(np.array_equal(batch[0].shape, (8,224,224,3)))
else:
self.assertTrue(np.array_equal(batch[0].shape, (1,224,224,3)))
counter += 1
self.assertTrue(counter == 4)
10 changes: 5 additions & 5 deletions tests/test_ioloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_image_loader_DataGenerator(self):
data_gen = DataGenerator(sample_list, tmp_data.name, resize=None,
grayscale=False, batch_size=2)
for i in range(0, 3):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(np.array_equal(batch[0].shape, (2, 16, 16, 3)))

# Test for grayscale images
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_numpy_loader_DataGenerator(self):
resize=None, two_dim=False, standardize_mode=None,
grayscale=True, batch_size=2)
for i in range(0, 3):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(np.array_equal(batch[0].shape, (2, 16, 16, 16, 1)))

# Test for grayscale 2D images
Expand Down Expand Up @@ -214,7 +214,7 @@ def test_sitk_loader_DataGenerator(self):
resize=None, standardize_mode=None,
grayscale=True, batch_size=1)
for i in range(0, 6):
batch = next(data_gen)
batch = data_gen[i]
if i < 3:
self.assertTrue(np.array_equal(batch[0].shape, (1, 32, 24, 8, 1)))
else:
Expand Down Expand Up @@ -278,7 +278,7 @@ def test_sitk_loader_Resampling(self):
resize=None, standardize_mode=None,
grayscale=True, batch_size=1)
for i in range(0, 6):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(np.array_equal(batch[0].shape, (1, 18, 10, 10, 1)))

#-------------------------------------------------#
Expand All @@ -301,7 +301,7 @@ def test_cache_loader_DataGenerator(self):
resize=None, two_dim=False, standardize_mode=None,
grayscale=True, batch_size=2, cache=cache)
for i in range(0, 3):
batch = next(data_gen)
batch = data_gen[i]
self.assertTrue(np.array_equal(batch[0].shape, (2, 16, 16, 16, 1)))

# Test for grayscale 2D images
Expand Down
19 changes: 13 additions & 6 deletions tests/test_neuralnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def setUpClass(self):

# Create RGB data
self.sampleList_rgb = []
for i in range(0, 1):
for i in range(0, 10):
img_rgb = np.random.rand(32, 32, 3) * 255
imgRGB_pillow = Image.fromarray(img_rgb.astype(np.uint8))
index = "image.sample_" + str(i) + ".RGB.png"
Expand All @@ -51,8 +51,8 @@ def setUpClass(self):
self.sampleList_rgb.append(index)

# Create classification labels
self.labels_ohe = np.zeros((1, 4), dtype=np.uint8)
for i in range(0, 1):
self.labels_ohe = np.zeros((10, 4), dtype=np.uint8)
for i in range(0, 10):
class_index = np.random.randint(0, 4)
self.labels_ohe[i][class_index] = 1

Expand All @@ -61,7 +61,8 @@ def setUpClass(self):
self.tmp_data.name,
labels=self.labels_ohe,
resize=(32, 32),
grayscale=False, batch_size=1)
shuffle=True,
grayscale=False, batch_size=3)

#-------------------------------------------------#
# Model Training #
Expand All @@ -79,6 +80,11 @@ def test_training_iterations(self):
self.assertTrue("loss" in hist)
self.assertTrue(len(hist["loss"]) == 5)

hist = model.train(training_generator=self.datagen,
epochs=3, iterations=2)
self.assertTrue("loss" in hist)
self.assertTrue(len(hist["loss"]) == 3)

def test_training_validation(self):
model = NeuralNetwork(n_labels=4, channels=3, batch_queue_size=1)
hist = model.train(training_generator=self.datagen,
Expand All @@ -101,5 +107,6 @@ def test_training_transferlearning(self):
def test_predict(self):
model = NeuralNetwork(n_labels=4, channels=3, batch_queue_size=1)
preds = model.predict(self.datagen)
self.assertTrue(preds.shape == (1, 4))
self.assertTrue(np.sum(preds) >= 0.99 and np.sum(preds) <= 1.01)
self.assertTrue(preds.shape == (10, 4))
for i in range(0, 10):
self.assertTrue(np.sum(preds[i]) >= 0.99 and np.sum(preds[i]) <= 1.01)
2 changes: 1 addition & 1 deletion tests/test_xai.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def setUpClass(self):
# Compute predictions
self.preds = self.model.predict(self.datagen)
# Initialize testing image
self.image = next(self.datagen)[0][[0]]
self.image = self.datagen[0][0][[0]]

#-------------------------------------------------#
# XAI Functions: Decoder #
Expand Down

0 comments on commit 5568cfa

Please sign in to comment.