diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 83cbc6694fae5..2f3cfbf8a7639 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -99,16 +99,16 @@ def training_step(self, batch, batch_idx): print("pig memory: ", memory_start) lower = 101 * self.num_params * 4 upper = 201 * self.num_params * 4 - assert lower < torch.cuda.memory_allocated(self.device) - memory_start - assert torch.cuda.memory_allocated(self.device) - memory_start < upper + assert lower < torch.cuda.memory_allocated(0) - memory_start + assert torch.cuda.memory_allocated(0) - memory_start < upper return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): # there is a batch and the boring model, but not two batches on gpu, assume 32 bit = 4 bytes lower = 101 * self.num_params * 4 upper = 201 * self.num_params * 4 - assert lower < torch.cuda.memory_allocated(self.device) - memory_start - assert torch.cuda.memory_allocated(self.device) - memory_start < upper + assert lower < torch.cuda.memory_allocated(0) - memory_start + assert torch.cuda.memory_allocated(0) - memory_start < upper return super().validation_step(batch, batch_idx) torch.cuda.empty_cache()