-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
201 lines (166 loc) · 6.6 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import hashlib
import inspect
import os
import zipfile
from typing import Callable

import gdown
from sultan.api import Sultan
from tqdm import tqdm
# Each dataset below is described by three constants: a Google Drive share
# link (*_URL), a 32-character hex digest (*_MD5; presumably the expected MD5
# of the downloaded file -- verify against callers) and the local file name
# under which the download is stored (*_FILE).

# EMNIST digits (MATLAB .mat file).
EMNIST_DIGITS_URL = "https://drive.google.com/file/d/1fl9fRPPxTUxnC56ACzZ8JiLiew0SMFwt/view?usp=share_link"
EMNIST_DIGITS_MD5 = "5a18b33e88e3884e79f8b2d6274564d7"
EMNIST_DIGITS_FILE = "emnist-digits.mat"
# EMNIST letters (MATLAB .mat file).
EMNIST_LETTERS_URL = "https://drive.google.com/file/d/1KpwKUfx5L8zN0gPyrEuCEDFnC0OlEOoM/view?usp=share_link"
EMNIST_LETTERS_MD5 = "b9eddc3e325dee05b65fb21ee45da52f"
EMNIST_LETTERS_FILE = "emnist-letters.mat"
# Reddit (zip archive).
REDDIT_URL = (
    "https://drive.google.com/file/d/1S5R11eDR6g9Lkb02Ht6zmP2W-ZYqtS5U/view?usp=sharing"
)
REDDIT_MD5 = "308519a1945cf9268707d0aeba2b1c96"
REDDIT_FILE = "reddit_leaf.zip"
# Shakespeare (zip archive).
SHAKESPEARE_URL = "https://drive.google.com/file/d/1YkFqXqvN1s1JE1lEQ5YHb-aBPgmQ2IIk/view?usp=share_link"
SHAKESPEARE_MD5 = "a75e79f9ecebe5118314227a9cd19bc9"
SHAKESPEARE_FILE = "shakespeare_leaf.zip"
def check_integrity(filename: str, md5_hash: str) -> bool:
"""Function that computes the md5 hash of a file and compares it
with a given one, ensuring that the file corresponds to the given md5 hash.
Args:
----
filename (str): path to file which will be used to compute a md5 hash
md5_hash (str): md5 hash to compare with the one compute using filename
Returns:
-------
bool: whether the given file has the same hash as the one provided
"""
with Sultan.load() as s:
try:
result = s.md5("-q", filename).run()
except Exception:
result = s.md5sum(filename).pipe().cut("-f", "1", "-d", '" "').run()
computed_md5 = result.stdout[0]
return computed_md5 == md5_hash
def check_file_exists(filename: str) -> bool:
    """Tell whether a path points to an existing regular file.

    Args:
    ----
        filename (str): path to test.

    Returns:
    -------
        bool: True when ``filename`` exists and is a regular file, False
            otherwise (including when it is a directory).
    """
    is_regular_file = os.path.isfile(filename)
    return is_regular_file
def check_dir_exists(filename: str) -> bool:
    """Check whether a given directory exists.

    Args:
    ----
        filename (str): path of the directory to check.

    Returns:
    -------
        bool: True if ``filename`` exists and is a directory, False otherwise
            (including when the path exists but is a regular file).
    """
    # os.path.exists would also accept regular files; isdir matches the contract.
    return os.path.isdir(filename)
def extract_zip(filename: str, output: bool = True):
    """Extract a zip archive next to itself, skipping members already on disk.

    Members are extracted into the directory that contains ``filename`` (the
    current working directory when ``filename`` has no directory component).
    A member whose target path already exists is not extracted again.

    Args:
    ----
        filename (str): path to the zip archive.
        output (bool): whether to return the paths of the archive members.

    Returns:
    -------
        list[str] | None: when ``output`` is True, one path per archive member
            (directory of the archive joined with the member name), in archive
            order; otherwise None.
    """
    # os.path.dirname copes with bare file names and platform separators,
    # unlike splitting the path on "/" by hand.
    base_dir = os.path.dirname(filename)
    with zipfile.ZipFile(filename, "r", allowZip64=True) as zip_file:
        members = zip_file.namelist()
        for member in members:
            if not os.path.exists(os.path.join(base_dir, member)):
                # An empty base_dir means "current directory".
                zip_file.extract(member, base_dir or os.curdir)
    if output:
        return [os.path.join(base_dir, member) for member in members]
    return None
def download_file(url: str, filename: str, out_dir: str = "."):
    """Download ``url`` and store it as ``filename`` inside ``out_dir``.

    URLs containing ``drive.google`` are fetched with the gdown library. Any
    other URL is fetched with ``curl`` through Sultan, falling back to ``wget``
    if launching ``curl`` raises; the lines the command writes to stderr are
    shown as the description of a tqdm progress bar.

    Args:
    ----
        url (str): url to download the file from.
        filename (str): name used to store the downloaded file.
        out_dir (str, optional): directory where the downloaded file will be
            stored. Defaults to ".".
    """
    out_path = os.path.join(out_dir, filename)

    if "drive.google" in url:
        # Google drive download using gdown library
        gdown.download(url, quiet=False, fuzzy=True, output=out_path)
        return

    curl_args = ("-#", "-L", "--output", out_path)
    with Sultan.load() as shell:
        try:
            process = shell.curl(url, *curl_args).run(streaming=True)
        except Exception:
            process = shell.wget(url, "-O", out_path).run(streaming=True)

        def progress_lines():
            # Keep draining stderr until the streaming command reports completion.
            while not process.is_complete:
                yield from process.stderr

        progress = tqdm(progress_lines())
        for line in progress:
            progress.set_description(line)
# Example: download_dataset(REDDIT_URL, REDDIT_FILE, REDDIT_MD5, extract=True)
def download_dataset(
    url: str,
    filename: str,
    md5_hash: str,
    out_dir: str = ".",
    extract: bool = False,
    max_trials: int = 3,
    output: bool = True,
):
    """Download a dataset from a URL, verify its MD5 hash and optionally unzip it.

    If a file with the expected hash is already present in ``out_dir`` nothing
    is downloaded. Otherwise the download is attempted up to ``max_trials``
    times, and the file is verified after every attempt.

    Args:
    ----
        url (str): url to download the file from.
        filename (str): name used to store the downloaded file.
        md5_hash (str): hash used to ensure the integrity of the downloaded file.
        out_dir (str, optional): directory where the downloaded file will be
            stored; it is created if it does not exist. Defaults to ".".
        extract (bool, optional): whether to extract the downloaded file as a
            zip archive. Defaults to False.
        max_trials (int, optional): maximum number of download attempts.
            Defaults to 3.
        output (bool, optional): whether to return the path(s) of the
            downloaded/extracted files. Defaults to True.

    Returns:
    -------
        list[str] | str | None: when ``output`` is True, the list of paths
            returned by ``extract_zip`` if ``extract`` is True, else the path
            of the downloaded file. None when ``output`` is False.

    Raises:
    ------
        ValueError: if no file matching ``md5_hash`` is available after
            ``max_trials`` download attempts.
    """
    full_path = os.path.join(out_dir, filename)
    # The download tools do not create missing directories themselves.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    attempts = 0
    while not (check_file_exists(full_path) and check_integrity(full_path, md5_hash)):
        # Give up only after the last allowed attempt has been verified above.
        if attempts >= max_trials:
            raise ValueError(
                "Either we are unable the download the file or the provided md5 hash is not correct."
            )
        download_file(url, filename, out_dir)
        attempts += 1

    extracted_files = None
    if extract:
        extracted_files = extract_zip(full_path, output=output)

    if not output:
        return None
    # extract_zip already returns paths located next to the archive (i.e. in
    # out_dir); joining them with out_dir again would duplicate the directory.
    return extracted_files if extract else full_path
def check_min_arguments(func: Callable, min_args: int = 1):
    """Tell whether a callable declares at least ``min_args`` parameters.

    Every parameter in the callable's signature is counted, whatever its kind
    and whether or not it has a default value.

    Args:
    ----
        func (Callable): function to inspect.
        min_args (int, optional): minimum number of parameters that ``func``
            must declare. Defaults to 1.

    Returns:
    -------
        bool: True if ``func`` declares ``min_args`` parameters or more,
            False otherwise.
    """
    declared_parameters = inspect.signature(func).parameters
    parameter_count = len(declared_parameters)
    return parameter_count >= min_args