diff --git a/demo.py b/demo.py index a5eb382..12fefab 100644 --- a/demo.py +++ b/demo.py @@ -57,7 +57,13 @@ def write_video_with_audio(audio_path, output_path, prefix='pred_'): ############################### I/O Settings ############################## # load config files opt = parser.parse_args() - device = torch.device(opt.device) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if torch.cuda.is_available(): + map_location=lambda storage, loc: storage.cuda() + else: + map_location='cpu' + with open(join('./config/', opt.id + '.yaml')) as f: config = yaml.load(f) data_root = join('./data/', opt.id) @@ -147,7 +153,7 @@ def write_video_with_audio(audio_path, output_path, prefix='pred_'): config['model_params']['APC']['hidden_size'], config['model_params']['APC']['num_layers'], config['model_params']['APC']['residual']) - APC_model.load_state_dict(torch.load(config['model_params']['APC']['ckp_path']), strict=False) + APC_model.load_state_dict(torch.load(config['model_params']['APC']['ckp_path'],map_location=map_location), strict=False) if opt.device == 'cuda': APC_model.cuda() APC_model.eval()