[Enhance] Add extra dataloader settings in configs. (#752)
* Use `train_dataloader`, `val_dataloader` and `test_dataloader` settings in the `data` field to specify different arguments for each dataloader.

* Fix bug

* Fix bug
mzr1996 authored Apr 1, 2022
1 parent f0ee5dc commit 02c8f82
Showing 2 changed files with 53 additions and 29 deletions.
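For context, a config using the new fields could look like the sketch below (dataset types, paths, and batch sizes are illustrative placeholders, not taken from this commit). Top-level keys in `data` apply to every dataloader, and each `*_dataloader` dict overrides them for its own split:

# Illustrative config snippet; all values here are hypothetical.
data = dict(
    samples_per_gpu=32,    # shared default for all splits
    workers_per_gpu=2,     # shared default for all splits
    train=dict(type='ImageNet', data_prefix='data/imagenet/train'),
    val=dict(type='ImageNet', data_prefix='data/imagenet/val'),
    test=dict(type='ImageNet', data_prefix='data/imagenet/val'),
    train_dataloader=dict(samples_per_gpu=64),   # train-only override
    val_dataloader=dict(samples_per_gpu=128),    # val-only override
    test_dataloader=dict(samples_per_gpu=128, workers_per_gpu=4),
)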
mmcls/apis/train.py (29 additions, 21 deletions)
@@ -80,20 +80,27 @@ def train_model(model,
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
 
-    sampler_cfg = cfg.data.get('sampler', None)
-
-    data_loaders = [
-        build_dataloader(
-            ds,
-            cfg.data.samples_per_gpu,
-            cfg.data.workers_per_gpu,
-            # cfg.gpus will be ignored if distributed
-            num_gpus=len(cfg.gpu_ids),
-            dist=distributed,
-            round_up=True,
-            seed=cfg.seed,
-            sampler_cfg=sampler_cfg) for ds in dataset
-    ]
+    # The default loader config
+    loader_cfg = dict(
+        # cfg.gpus will be ignored if distributed
+        num_gpus=len(cfg.gpu_ids),
+        dist=distributed,
+        round_up=True,
+        seed=cfg.get('seed'),
+        sampler_cfg=cfg.get('sampler', None),
+    )
+    # The overall dataloader settings
+    loader_cfg.update({
+        k: v
+        for k, v in cfg.data.items() if k not in [
+            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
+            'test_dataloader'
+        ]
+    })
+    # The specific dataloader settings
+    train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
+
+    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
 
     # put model on gpus
     if distributed:
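The precedence in `train_loader_cfg` comes from Python's dict unpacking: in `{**a, **b}`, keys from `b` overwrite keys from `a`. A minimal standalone sketch of the same three-layer merge, with all names and values illustrative:

# Standalone sketch of the merge above: built-in defaults < top-level
# `data` keys < per-split `train_dataloader` keys. Illustrative values only.
defaults = dict(dist=False, round_up=True, samples_per_gpu=16)
cfg_data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train_dataloader=dict(samples_per_gpu=64),
)
overall = {
    k: v
    for k, v in cfg_data.items() if k not in [
        'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
        'test_dataloader'
    ]
}
train_loader_cfg = {**defaults, **overall, **cfg_data.get('train_dataloader', {})}
assert train_loader_cfg['samples_per_gpu'] == 64  # per-split override wins
assert train_loader_cfg['workers_per_gpu'] == 2   # top-level setting kept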
@@ -169,13 +176,14 @@ def train_model(model,
     # register eval hooks
     if validate:
         val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
-        val_dataloader = build_dataloader(
-            val_dataset,
-            samples_per_gpu=cfg.data.samples_per_gpu,
-            workers_per_gpu=cfg.data.workers_per_gpu,
-            dist=distributed,
-            shuffle=False,
-            round_up=True)
+        # The specific dataloader settings
+        val_loader_cfg = {
+            **loader_cfg,
+            'shuffle': False,  # Not shuffle by default
+            'sampler_cfg': None,  # Not use sampler by default
+            **cfg.data.get('val_dataloader', {}),
+        }
+        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
         eval_cfg = cfg.get('evaluation', {})
         eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
         eval_hook = DistEvalHook if distributed else EvalHook
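Note the ordering inside `val_loader_cfg`: `'shuffle'` and `'sampler_cfg'` sit after `**loader_cfg` but before `**cfg.data.get('val_dataloader', {})`, so they act as evaluation-specific defaults that a user's `val_dataloader` dict can still override. A toy demonstration (values hypothetical):

# Keys written between the two unpackings are defaults for this split;
# anything the user puts in `val_dataloader` still wins.
loader_cfg = dict(shuffle=True, samples_per_gpu=32)
user_cfg = dict(shuffle=True)  # hypothetical user override from the config

val_loader_cfg = {**loader_cfg, 'shuffle': False, **user_cfg}
assert val_loader_cfg['shuffle'] is True  # user override beats the default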
tools/test.py (24 additions, 8 deletions)
@@ -135,16 +135,32 @@ def main():
         distributed = True
         init_dist(args.launcher, **cfg.dist_params)
 
-    # build the dataloader
     dataset = build_dataset(cfg.data.test, default_args=dict(test_mode=True))
-    # the extra round_up data will be removed during gpu/cpu collect
-    data_loader = build_dataloader(
-        dataset,
-        samples_per_gpu=cfg.data.samples_per_gpu,
-        workers_per_gpu=cfg.data.workers_per_gpu,
+
+    # build the dataloader
+    # The default loader config
+    loader_cfg = dict(
+        # cfg.gpus will be ignored if distributed
+        num_gpus=len(cfg.gpu_ids),
         dist=distributed,
-        shuffle=False,
-        round_up=True)
+        round_up=True,
+    )
+    # The overall dataloader settings
+    loader_cfg.update({
+        k: v
+        for k, v in cfg.data.items() if k not in [
+            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
+            'test_dataloader'
+        ]
+    })
+    test_loader_cfg = {
+        **loader_cfg,
+        'shuffle': False,  # Not shuffle by default
+        'sampler_cfg': None,  # Not use sampler by default
+        **cfg.data.get('test_dataloader', {}),
+    }
+    # the extra round_up data will be removed during gpu/cpu collect
+    data_loader = build_dataloader(dataset, **test_loader_cfg)
 
     # build the model and load checkpoint
     model = build_classifier(cfg.model)
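On the retained `round_up` comment: in distributed testing the sampler pads the index list so every replica draws the same number of samples, and the padded extras are dropped when results are collected. A rough sketch of the arithmetic, assuming the sampler pads to a multiple of the world size (the exact behavior lives in mmcls's sampler, not shown here):

# Rough sketch (assumption: the sampler pads the index list so its length
# is divisible by world_size; exact behavior lives in the sampler itself).
import math

dataset_len, world_size = 1002, 8
per_replica = math.ceil(dataset_len / world_size)   # 126 samples per GPU
padded_len = per_replica * world_size               # 1008 indices in total
extra = padded_len - dataset_len                    # 6 duplicated samples
print(extra)  # these 6 results are removed during gpu/cpu collect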
