diff --git a/miprometheus/problems/seq_to_seq/vqa/cog/cog.py b/miprometheus/problems/seq_to_seq/vqa/cog/cog.py index 5785a00f..7423fc6a 100644 --- a/miprometheus/problems/seq_to_seq/vqa/cog/cog.py +++ b/miprometheus/problems/seq_to_seq/vqa/cog/cog.py @@ -37,6 +37,7 @@ import gzip import json import os +import tarfile import numpy as np from miprometheus.problems.seq_to_seq.vqa.vqa_problem import VQAProblem @@ -285,14 +286,14 @@ def source_dataset(self, params): elif self.dataset_type == 'hard': self.download = self.CheckAndDownload(self.data_folder_child, 'https://storage.googleapis.com/cog-datasets/data_8_7_10.tar') - if self.download: - print('\nDownload complete. Extracting...') - tar = tarfile.open(os.path.expanduser('~/data/downloaded')) - tar.extractall(path=self.data_folder_main) - tar.close() - print('\nDone! Cleaning up.') - os.remove(os.path.expanduser('~/data/downloaded')) - print('\nClean-up complete! Dataset ready.') + if self.download: + print('\nDownload complete. Extracting...') + tar = tarfile.open(os.path.expanduser('~/data/downloaded')) + tar.extractall(path=self.data_folder_main) + tar.close() + print('\nDone! Cleaning up.') + os.remove(os.path.expanduser('~/data/downloaded')) + print('\nClean-up complete! Dataset ready.') else: self.download = self.CheckAndDownload(self.data_folder_child)