Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitgr7 committed Nov 9, 2021
1 parent 308f9c0 commit 04314b8
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,3 +747,38 @@ def validation_epoch_end(self, *_) -> None:
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


def test_no_batch_size_extraction_with_specifying_explictly(tmpdir):
batch_size = BoringModel().train_dataloader().batch_size + 10
fast_dev_run = 2
log_val = 7.0

class CustomBoringModel(BoringModel):
def on_before_batch_transfer(self, batch, *args, **kwargs):
# This is an ambiguous batch which have multiple potential batch sizes
if self.trainer.training:
batch = {"batch1": torch.randn(batch_size + 10, 10), "batch2": batch}
return batch

def training_step(self, batch, batch_idx):
self.log("step_log_val", log_val, on_epoch=False)
self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True)
self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum")
return super().training_step(batch["batch2"], batch_idx)

def training_epoch_end(self, *args, **kwargs):
results = self.trainer._results
assert results["training_step.step_log_val"].value == log_val
assert results["training_step.step_log_val"].cumulated_batch_size == 0
assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run
assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size * fast_dev_run
assert results["training_step.epoch_sum_log_val"].value == log_val * fast_dev_run

model = CustomBoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run)

with pytest.warns(None) as record:
trainer.fit(model)

assert not any("Trying to infer the `batch_size`" in warn.message.args[0] for warn in record.list)

0 comments on commit 04314b8

Please sign in to comment.