-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
211 lines (164 loc) · 7.85 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import argparse
import json
import logging
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
import pathvalidate
from kaggle.api.kaggle_api_extended import KaggleApi
from kaggle.models.kaggle_models_extended import Kernel
from sortedcontainers import SortedSet
# Default number of kernels requested per listing page; used as the default for
# get_kernels(page_size=...) and main(max_page_size=...).
# NOTE(review): the name suggests this is the Kaggle API's upper limit - confirm.
MAX_PAGE_SIZE = 100
class DummyLen:
    """Placeholder object whose ``len()`` reports a caller-supplied value.

    Args:
        length: The value returned by ``len()``; it is stored unchanged in the
            public ``length`` attribute and may be reassigned later.
    """

    def __init__(self, length):
        # Kept verbatim: ``len()`` itself rejects anything that is not an int.
        self.length = length

    def __len__(self):
        reported = self.length
        return reported
def validate_positive_int(value):
    """Argparse ``type=`` converter accepting only integers greater than zero.

    Args:
        value: The raw command-line value (normally a ``str``).

    Returns:
        int: The parsed, strictly positive integer.

    Raises:
        argparse.ArgumentTypeError: If *value* cannot be converted with
            ``int()`` or is not greater than zero.
    """
    message = "Value must be a positive integer."
    try:
        val = int(value)
    except (TypeError, ValueError):
        # ``from None`` hides the low-level int() traceback from the user.
        raise argparse.ArgumentTypeError(message) from None
    if val <= 0:
        raise argparse.ArgumentTypeError(message)
    return val
def validate_filename(value):
    """Argparse ``type=`` validator for the output archive path.

    The parent directory of *value* must already exist and be a directory.
    The final path component must be a valid file name according to
    ``pathvalidate.validate_filename(..., platform="auto")``.

    Args:
        value: The path given on the command line.

    Returns:
        The original *value*, unchanged.

    Raises:
        argparse.ArgumentTypeError: If either check fails.
    """
    p = Path(value)
    try:
        parent = p.parent.resolve(strict=True)
        pathvalidate.validate_filename(p.name, platform="auto")
    except (OSError, RuntimeError, pathvalidate.ValidationError) as e:
        # RuntimeError: Path.resolve reports symlink loops this way before Python 3.13.
        raise argparse.ArgumentTypeError(f"Invalid filename: {e}") from e
    # resolve(strict=True) only proves the parent exists. The archive is later
    # written inside it, so it must also be a directory.
    if not parent.is_dir():
        raise argparse.ArgumentTypeError(f"Invalid filename: {p.parent} is not a directory")
    return value
def get_kernels(api, user, page=1, include_private=True, page_size=MAX_PAGE_SIZE):
    """Yield one page of a user's kernels from the Kaggle API.

    The API call is made lazily, on the first iteration of the returned
    generator.

    Args:
        api: An authenticated ``KaggleApi`` instance.
        user: Username whose kernels are listed. When falsy, the username from
            the Kaggle API configuration (``CONFIG_NAME_USER``) is used.
        page: Page number passed to ``kernels_list``.
        include_private: When False, only kernels whose visibility flag marks
            them as public are yielded; kernels with no usable flag are skipped.
        page_size: Number of kernels requested for this page.

    Yields:
        The objects returned by ``api.kernels_list`` (requested with
        ``sort_by="dateRun"``), optionally filtered as described above.
    """
    kernels = api.kernels_list(page=page,
                               user=user or api.get_config_value(api.CONFIG_NAME_USER),
                               sort_by="dateRun", page_size=page_size, mine=True)
    if include_private:
        yield from kernels
        return

    def is_private(kernel):
        # Attribute names presumably differ between kaggle client versions; try
        # each one lazily so a missing name never raises AttributeError.
        for attr in ('isPrivate', 'isPrivateNullable', 'is_private'):
            flag = getattr(kernel, attr, None)
            if flag is not None:
                return bool(flag)
        # Visibility unknown: treat as private so it is never exported by accident.
        return True

    yield from (k for k in kernels if not is_private(k))
def kernel_identity(kernel):
    """Return an ``(id, name)`` tuple that identifies *kernel*.

    ``name`` is ``kernel.ref`` when that attribute exists, otherwise
    ``kernel.title``. The fallback is looked up only when ``ref`` is missing,
    so an object without ``title`` no longer raises.

    Raises:
        AttributeError: If ``kernel`` has no ``id``, or has neither ``ref``
            nor ``title``.
    """
    kernel_id = kernel.id
    try:
        name = kernel.ref
    except AttributeError:
        name = kernel.title
    return kernel_id, name
def kernel_to_path(kernel):
    """Build the relative folder name ``<ref>#<id>`` for *kernel*.

    Characters that ``pathvalidate.sanitize_filename`` rejects are replaced
    with ``_``.
    """
    folder_name = f"{kernel.ref}#{kernel.id}"
    safe_name = pathvalidate.sanitize_filename(folder_name, replacement_text="_")
    return Path(safe_name)
def fix_kernel_folder(path: Path, remove_private: bool = True) -> Optional[Path]:
    """
    Fixes the name and location of a downloaded kernel folder.

    Steps:
    1. Looks for "kernel-metadata.json" inside ``path``.
       - If it is missing and ``remove_private`` is True, the folder is removed,
         because the kernel's visibility cannot be checked.
       - If it is missing and ``remove_private`` is False, ``path`` is returned
         unchanged.
    2. Loads the metadata. If ``remove_private`` is True and the metadata marks
       the kernel private ("is_private", else "isPrivate", defaulting to True
       when both are absent), the folder is removed.
    3. Otherwise renames the folder to the sanitized "<id>#<id_no>" name,
       unless that name already exists or the metadata lacks those keys.

    Args:
        path (Path): The path to the kernel folder.
        remove_private (bool, optional): Whether to remove private kernels. Defaults to True.

    Returns:
        Optional[Path]: The path to the (possibly renamed) kernel folder, or None
        if the folder was removed.
    """
    meta_path = Path(path, "kernel-metadata.json")
    if not meta_path.exists():
        logging.warning("Kernel metadata not found: %s", path)
        if not remove_private:
            return path
        # Visibility unknown: drop the folder rather than risk archiving a
        # private kernel (the None return promises the folder is gone).
        if path.is_dir():
            shutil.rmtree(path)
        return None
    with meta_path.open("rb") as f:
        metadata = json.load(f)
    if remove_private and metadata.get("is_private", metadata.get("isPrivate", True)):
        logging.debug("Removing private kernel: %s", path)
        shutil.rmtree(path)
        return None
    ref, number = metadata.get("id"), metadata.get("id_no")
    if ref is None or number is None:
        # Without both keys there is nothing to rename to; keep the folder as-is.
        logging.warning("Kernel metadata lacks 'id'/'id_no'; keeping folder name: %s", path)
        return path
    new_path = Path(path.parent, pathvalidate.sanitize_filename(
        f"{ref}#{number}", replacement_text="_"))
    if new_path != path and not new_path.exists():
        logging.debug("Renaming kernel: %s -> %s", path, new_path)
        shutil.move(path, new_path)
        return new_path
    return path
def _add_github_mask(value):
if value in (None, True, False):
return False
if value in range(0, 100):
return False
if isinstance(value, Kernel):
_add_github_mask(str(value))
_add_github_mask(getattr(value, 'ref'))
_add_github_mask(getattr(value, 'title'))
return value
if isinstance(value, Path):
_add_github_mask(str(value))
_add_github_mask(value.name)
_add_github_mask(value.stem)
return value
value = str(value)
if not value.strip():
return False
print(f'::add-mask::{value}')
return value
def _pull_kernel(api, tmpdir, kernel, include_private, add_github_mask):
    """Download *kernel* with its metadata into *tmpdir* and normalise its folder.

    Returns whatever ``fix_kernel_folder`` returns: the folder path, or None
    when the folder was removed. Exceptions from masking, the filesystem or
    the Kaggle API propagate to the caller.
    """
    add_github_mask(kernel)
    path = Path(tmpdir, kernel_to_path(kernel))
    add_github_mask(path)
    path.mkdir(parents=True, exist_ok=True)
    api.kernels_pull(kernel.ref, path=path, metadata=True)
    path = fix_kernel_folder(path, remove_private=not include_private)
    add_github_mask(path)
    return path


def main(include_private=False, max_page_size=MAX_PAGE_SIZE, user=None, output_name="kernels.zip",
         tmp_dir_prefix="kaggle_", tmp_dir=None, add_mask=False):
    """Download every kernel of a Kaggle user and pack them into one zip archive.

    The keyword arguments only supply defaults for the command-line options
    parsed below. ``tmp_dir_prefix`` is passed straight to ``TemporaryDirectory``.
    Kernels that fail to download are retried once after all pages have been
    listed. Second failures are logged and skipped.
    """
    parser = argparse.ArgumentParser(description="Download All Kaggle Kernels")
    parser.add_argument("-o", "--output", type=validate_filename, default=output_name,
                        help=f"Name of the output zip file (default: {output_name})")
    parser.add_argument("-p", "--include-private", action="store_true", default=include_private,
                        help=f"Include private kernels in the download (default: {include_private})")
    parser.add_argument("-u", "--user", type=str, default=user,
                        help="Username of the Kaggle user to search kernels for (default: current user)")
    parser.add_argument("-s", "--max-page-size", type=validate_positive_int, default=max_page_size,
                        help=f"Maximum number of kernels to download per page (default: {max_page_size})")
    parser.add_argument("-t", "--tmp-dir", type=str, default=tmp_dir,
                        help=f"Path to the temporary directory (default: {tmp_dir})")
    parser.add_argument("--add-mask", action="store_true", default=add_mask,
                        help=argparse.SUPPRESS)
    args = parser.parse_args()
    include_private = bool(args.include_private)
    add_mask = bool(args.add_mask)
    add_github_mask = _add_github_mask if add_mask else lambda value: None
    api = KaggleApi()
    api.authenticate()
    with TemporaryDirectory(prefix=tmp_dir_prefix, dir=args.tmp_dir) as tmpdir:
        # Track listed kernels by kernel_identity() tuples. Every page fetch builds
        # new kernel objects, which presumably compare by object identity, so a
        # set of the objects themselves would never recognise a repeat.
        seen = set()
        retry_later = []
        page = 1
        while True:
            # Paginate on the unfiltered listing. Filtering out private kernels
            # can shrink a full page below max_page_size and end the loop early.
            listed = list(get_kernels(api, args.user, page, True, args.max_page_size))
            fresh = {kernel_identity(k) for k in listed} - seen
            if not fresh:
                break  # empty page, or the API returned only kernels already seen
            seen |= fresh
            # When private kernels are excluded, a second, filtered listing of
            # the same page decides which of the fresh kernels to download.
            if include_private:
                candidates = listed
            else:
                candidates = get_kernels(api, args.user, page, False, args.max_page_size)
            wanted = {kernel_identity(k): k for k in candidates if kernel_identity(k) in fresh}
            for ident in sorted(wanted):
                kernel = wanted[ident]
                try:
                    _pull_kernel(api, tmpdir, kernel, include_private, add_github_mask)
                except Exception as e:  # pylint: disable=broad-except
                    logging.warning(e, exc_info=True)
                    retry_later.append(kernel)
            if len(listed) < args.max_page_size:
                break  # a short page is the last one
            page += 1
        for kernel in sorted(retry_later, key=kernel_identity):
            try:
                _pull_kernel(api, tmpdir, kernel, include_private, add_github_mask)
            except Exception:  # pylint: disable=broad-except
                logging.warning("Failed to download %r",
                                kernel.ref if not add_mask else 'hidden kernel name',
                                exc_info=True)
        output = Path(args.output)
        shutil.make_archive(str(output.parent / output.stem), 'zip', tmpdir)
if __name__ == '__main__':
    def _env_flag(name):
        # Case-insensitive: any of these spellings switches the option on;
        # an unset variable counts as off.
        return os.getenv(name, '').lower() in ('true', '1', 'y', 'yes', 'ok')

    main(
        include_private=_env_flag('KAGGLE_KERNELS_PRIVATE'),
        add_mask=_env_flag('KAGGLE_KERNELS_MASK'),
    )