diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py index 6b4b0096e5f..70dc704168f 100644 --- a/mmdet/apis/inference.py +++ b/mmdet/apis/inference.py @@ -39,8 +39,7 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None): config.model.train_cfg = None model = build_detector(config.model, test_cfg=config.get('test_cfg')) if checkpoint is not None: - map_loc = 'cpu' if device == 'cpu' else None - checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc) + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') if 'CLASSES' in checkpoint.get('meta', {}): model.CLASSES = checkpoint['meta']['CLASSES'] else: