diff --git a/mmengine/runner/log_processor.py b/mmengine/runner/log_processor.py index 30efac34bc..7d9e11a18a 100644 --- a/mmengine/runner/log_processor.py +++ b/mmengine/runner/log_processor.py @@ -129,7 +129,6 @@ def get_log_after_iter(self, runner, batch_idx: int, recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`. """ assert mode in ['train', 'test', 'val'] - current_loop = self._get_cur_loop(runner, mode) cur_iter = self._get_iter(runner, batch_idx=batch_idx) # Overwrite ``window_size`` defined in ``custom_cfg`` to int value. parsed_cfg = self._parse_windows_size(runner, batch_idx, @@ -166,7 +165,7 @@ def get_log_after_iter(self, runner, batch_idx: int, # Epoch(train) [ 9][010/270] # ... ||| ||| # Epoch(train) [ 10][100/270] - dataloader_len = len(current_loop.dataloader) + dataloader_len = self._get_dataloader_size(runner, mode) cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len))) if mode in ['train', 'val']: @@ -192,11 +191,10 @@ def get_log_after_iter(self, runner, batch_idx: int, log_str = (f'Iter({mode}) ' f'[{cur_iter_str}/{runner.max_iters}] ') else: - dataloader_len = len(current_loop.dataloader) + dataloader_len = self._get_dataloader_size(runner, mode) cur_iter_str = str(batch_idx + 1).rjust( len(str(dataloader_len))) - log_str = (f'Iter({mode}) [{cur_iter_str}' - f'/{len(current_loop.dataloader)}] ') + log_str = (f'Iter({mode}) [{cur_iter_str}/{dataloader_len}] ') # Concatenate lr, momentum string with log header. log_str += f'{lr_str} ' # If IterTimerHook used in runner, eta, time, and data_time should be @@ -251,8 +249,7 @@ def get_log_after_epoch(self, 'test', 'val' ], ('`_get_metric_log_str` only accept val or test mode, but got ' f'{mode}') - cur_loop = self._get_cur_loop(runner, mode) - dataloader_len = len(cur_loop.dataloader) + dataloader_len = self._get_dataloader_size(runner, mode) # By epoch: # Epoch(val) [10][1000/1000] ... @@ -547,3 +544,15 @@ def _get_cur_loop(self, runner, mode: str): return runner.val_loop else: return runner.test_loop + + def _get_dataloader_size(self, runner, mode) -> int: + """Get dataloader size of current loop. + + Args: + runner (Runner): The runner of the training/validation/testing + mode (str): Current mode of runner. + + Returns: + int: The dataloader size of current loop. + """ + return len(self._get_cur_loop(runner=runner, mode=mode).dataloader)