diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index c42c5a0f06bab..34c94e3ba63ea 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -747,3 +747,38 @@ def validation_epoch_end(self, *_) -> None: train_data = DataLoader(RandomDataset(32, 64), batch_size=2) val_data = DataLoader(RandomDataset(32, 64), batch_size=2) trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) + + +def test_no_batch_size_extraction_with_specifying_explictly(tmpdir): + batch_size = BoringModel().train_dataloader().batch_size + 10 + fast_dev_run = 2 + log_val = 7.0 + + class CustomBoringModel(BoringModel): + def on_before_batch_transfer(self, batch, *args, **kwargs): + # This is an ambiguous batch which have multiple potential batch sizes + if self.trainer.training: + batch = {"batch1": torch.randn(batch_size + 10, 10), "batch2": batch} + return batch + + def training_step(self, batch, batch_idx): + self.log("step_log_val", log_val, on_epoch=False) + self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True) + self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum") + return super().training_step(batch["batch2"], batch_idx) + + def training_epoch_end(self, *args, **kwargs): + results = self.trainer._results + assert results["training_step.step_log_val"].value == log_val + assert results["training_step.step_log_val"].cumulated_batch_size == 0 + assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run + assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size * fast_dev_run + assert results["training_step.epoch_sum_log_val"].value == log_val * fast_dev_run + + model = CustomBoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run) + + with pytest.warns(None) as record: + trainer.fit(model) + + assert not any("Trying to infer the `batch_size`" in warn.message.args[0] for warn in record.list)