Skip to content

Commit

Permalink
Count the dataloader length in _get_dataloader_size
Browse files Browse the repository at this point in the history
  • Loading branch information
HAOCHENYE committed Mar 13, 2023
1 parent 54f9714 commit 8defb2e
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions mmengine/runner/log_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def get_log_after_iter(self, runner, batch_idx: int,
recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
"""
assert mode in ['train', 'test', 'val']
current_loop = self._get_cur_loop(runner, mode)
cur_iter = self._get_iter(runner, batch_idx=batch_idx)
# Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
parsed_cfg = self._parse_windows_size(runner, batch_idx,
Expand Down Expand Up @@ -166,7 +165,7 @@ def get_log_after_iter(self, runner, batch_idx: int,
# Epoch(train) [ 9][010/270]
# ... ||| |||
# Epoch(train) [ 10][100/270]
dataloader_len = len(current_loop.dataloader)
dataloader_len = self._get_dataloader_size(runner, mode)
cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len)))

if mode in ['train', 'val']:
Expand All @@ -192,11 +191,10 @@ def get_log_after_iter(self, runner, batch_idx: int,
log_str = (f'Iter({mode}) '
f'[{cur_iter_str}/{runner.max_iters}] ')
else:
dataloader_len = len(current_loop.dataloader)
dataloader_len = self._get_dataloader_size(runner, mode)
cur_iter_str = str(batch_idx + 1).rjust(
len(str(dataloader_len)))
log_str = (f'Iter({mode}) [{cur_iter_str}'
f'/{len(current_loop.dataloader)}] ')
log_str = (f'Iter({mode}) [{cur_iter_str}/{dataloader_len}] ')
# Concatenate lr, momentum string with log header.
log_str += f'{lr_str} '
# If IterTimerHook used in runner, eta, time, and data_time should be
Expand Down Expand Up @@ -251,8 +249,7 @@ def get_log_after_epoch(self,
'test', 'val'
], ('`_get_metric_log_str` only accept val or test mode, but got '
f'{mode}')
cur_loop = self._get_cur_loop(runner, mode)
dataloader_len = len(cur_loop.dataloader)
dataloader_len = self._get_dataloader_size(runner, mode)

# By epoch:
# Epoch(val) [10][1000/1000] ...
Expand Down Expand Up @@ -547,3 +544,15 @@ def _get_cur_loop(self, runner, mode: str):
return runner.val_loop
else:
return runner.test_loop

def _get_dataloader_size(self, runner, mode) -> int:
"""Get dataloader size of current loop.
Args:
runner (Runner): The runner of the training/validation/testing
mode (str): Current mode of runner.
Returns:
int: The dataloader size of current loop.
"""
return len(self._get_cur_loop(runner=runner, mode=mode).dataloader)

0 comments on commit 8defb2e

Please sign in to comment.