Skip to content

Commit

Permalink
minor fix on reqs download
Browse files Browse the repository at this point in the history
  • Loading branch information
blaisewf committed Aug 17, 2024
1 parent c0f1968 commit b3a16cb
Showing 1 changed file with 47 additions and 40 deletions.
87 changes: 47 additions & 40 deletions rvc/lib/tools/prerequisites_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,19 @@
}


def get_total_size(file_list):
def get_file_size_if_missing(file_list):
"""
Calculate the total size of files to be downloaded by sending HEAD requests to each file's URL.
Calculate the total size of files to be downloaded only if they do not exist locally.
"""
total_size = 0
for remote_folder, files in file_list:
local_folder = folder_mapping_list.get(remote_folder, "")
for file in files:
url = f"{url_base}/{remote_folder}{file}"
response = requests.head(url)
total_size += int(response.headers.get("content-length", 0))
destination_path = os.path.join(local_folder, file)
if not os.path.exists(destination_path):
url = f"{url_base}/{remote_folder}{file}"
response = requests.head(url)
total_size += int(response.headers.get("content-length", 0))
return total_size


Expand All @@ -78,14 +81,13 @@ def download_file(url, destination_path, global_bar):
Download a file from the given URL to the specified destination path,
updating the global progress bar as data is downloaded.
"""
if not os.path.exists(destination_path):
os.makedirs(os.path.dirname(destination_path) or ".", exist_ok=True)
response = requests.get(url, stream=True)
block_size = 1024
with open(destination_path, "wb") as file:
for data in response.iter_content(block_size):
file.write(data)
global_bar.update(len(data))
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
response = requests.get(url, stream=True)
block_size = 1024
with open(destination_path, "wb") as file:
for data in response.iter_content(block_size):
file.write(data)
global_bar.update(len(data))


def download_mapping_files(file_mapping_list, global_bar):
Expand All @@ -99,52 +101,57 @@ def download_mapping_files(file_mapping_list, global_bar):
local_folder = folder_mapping_list.get(remote_folder, "")
for file in file_list:
destination_path = os.path.join(local_folder, file)
url = f"{url_base}/{remote_folder}{file}"
futures.append(
executor.submit(download_file, url, destination_path, global_bar)
)
if not os.path.exists(destination_path):
url = f"{url_base}/{remote_folder}{file}"
futures.append(
executor.submit(
download_file, url, destination_path, global_bar
)
)
for future in futures:
future.result()


def calculate_total_size(pretraineds_v1, pretraineds_v2, models, exe):
"""
Calculate the total size of all files to be downloaded based on selected categories (pretraineds, models, executables).
Calculate the total size of all files to be downloaded based on selected categories.
"""
total_size = 0
if models:
total_size += get_total_size(models_list)
total_size += get_total_size(embedders_list)
total_size += get_file_size_if_missing(models_list)
total_size += get_file_size_if_missing(embedders_list)
if exe:
total_size += get_total_size(
total_size += get_file_size_if_missing(
executables_list if os.name == "nt" else linux_executables_list
)
if pretraineds_v1:
total_size += get_total_size(pretraineds_v1_list)
total_size += get_file_size_if_missing(pretraineds_v1_list)
if pretraineds_v2:
total_size += get_total_size(pretraineds_v2_list)
total_size += get_file_size_if_missing(pretraineds_v2_list)
return total_size


def prequisites_download_pipeline(pretraineds_v1, pretraineds_v2, models, exe):
"""
Manage the download pipeline for different categories of files (pretrained models, executables, etc.).
A single global progress bar tracks the cumulative progress of all downloads.
Manage the download pipeline for different categories of files.
"""
total_size = calculate_total_size(pretraineds_v1, pretraineds_v2, models, exe)

with tqdm(
total=total_size, unit="iB", unit_scale=True, desc="Downloading all files"
) as global_bar:
if models:
download_mapping_files(models_list, global_bar)
download_mapping_files(embedders_list, global_bar)
if exe:
download_mapping_files(
executables_list if os.name == "nt" else linux_executables_list,
global_bar,
)
if pretraineds_v1:
download_mapping_files(pretraineds_v1_list, global_bar)
if pretraineds_v2:
download_mapping_files(pretraineds_v2_list, global_bar)
if total_size > 0:
with tqdm(
total=total_size, unit="iB", unit_scale=True, desc="Downloading all files"
) as global_bar:
if models:
download_mapping_files(models_list, global_bar)
download_mapping_files(embedders_list, global_bar)
if exe:
download_mapping_files(
executables_list if os.name == "nt" else linux_executables_list,
global_bar,
)
if pretraineds_v1:
download_mapping_files(pretraineds_v1_list, global_bar)
if pretraineds_v2:
download_mapping_files(pretraineds_v2_list, global_bar)
else:
pass

0 comments on commit b3a16cb

Please sign in to comment.