diff --git a/rvc/lib/tools/prerequisites_download.py b/rvc/lib/tools/prerequisites_download.py index 8fe7034c..3afef866 100644 --- a/rvc/lib/tools/prerequisites_download.py +++ b/rvc/lib/tools/prerequisites_download.py @@ -60,16 +60,19 @@ } -def get_total_size(file_list): +def get_file_size_if_missing(file_list): """ - Calculate the total size of files to be downloaded by sending HEAD requests to each file's URL. + Calculate the total size of files to be downloaded only if they do not exist locally. """ total_size = 0 for remote_folder, files in file_list: + local_folder = folder_mapping_list.get(remote_folder, "") for file in files: - url = f"{url_base}/{remote_folder}{file}" - response = requests.head(url) - total_size += int(response.headers.get("content-length", 0)) + destination_path = os.path.join(local_folder, file) + if not os.path.exists(destination_path): + url = f"{url_base}/{remote_folder}{file}" + response = requests.head(url) + total_size += int(response.headers.get("content-length", 0)) return total_size @@ -78,14 +81,13 @@ def download_file(url, destination_path, global_bar): Download a file from the given URL to the specified destination path, updating the global progress bar as data is downloaded. """ - if not os.path.exists(destination_path): - os.makedirs(os.path.dirname(destination_path) or ".", exist_ok=True) - response = requests.get(url, stream=True) - block_size = 1024 - with open(destination_path, "wb") as file: - for data in response.iter_content(block_size): - file.write(data) - global_bar.update(len(data)) + os.makedirs(os.path.dirname(destination_path), exist_ok=True) + response = requests.get(url, stream=True) + block_size = 1024 + with open(destination_path, "wb") as file: + for data in response.iter_content(block_size): + file.write(data) + global_bar.update(len(data)) def download_mapping_files(file_mapping_list, global_bar): @@ -99,52 +101,57 @@ def download_mapping_files(file_mapping_list, global_bar): local_folder = folder_mapping_list.get(remote_folder, "") for file in file_list: destination_path = os.path.join(local_folder, file) - url = f"{url_base}/{remote_folder}{file}" - futures.append( - executor.submit(download_file, url, destination_path, global_bar) - ) + if not os.path.exists(destination_path): + url = f"{url_base}/{remote_folder}{file}" + futures.append( + executor.submit( + download_file, url, destination_path, global_bar + ) + ) for future in futures: future.result() def calculate_total_size(pretraineds_v1, pretraineds_v2, models, exe): """ - Calculate the total size of all files to be downloaded based on selected categories (pretraineds, models, executables). + Calculate the total size of all files to be downloaded based on selected categories. """ total_size = 0 if models: - total_size += get_total_size(models_list) - total_size += get_total_size(embedders_list) + total_size += get_file_size_if_missing(models_list) + total_size += get_file_size_if_missing(embedders_list) if exe: - total_size += get_total_size( + total_size += get_file_size_if_missing( executables_list if os.name == "nt" else linux_executables_list ) if pretraineds_v1: - total_size += get_total_size(pretraineds_v1_list) + total_size += get_file_size_if_missing(pretraineds_v1_list) if pretraineds_v2: - total_size += get_total_size(pretraineds_v2_list) + total_size += get_file_size_if_missing(pretraineds_v2_list) return total_size def prequisites_download_pipeline(pretraineds_v1, pretraineds_v2, models, exe): """ - Manage the download pipeline for different categories of files (pretrained models, executables, etc.). - A single global progress bar tracks the cumulative progress of all downloads. + Manage the download pipeline for different categories of files. """ total_size = calculate_total_size(pretraineds_v1, pretraineds_v2, models, exe) - with tqdm( - total=total_size, unit="iB", unit_scale=True, desc="Downloading all files" - ) as global_bar: - if models: - download_mapping_files(models_list, global_bar) - download_mapping_files(embedders_list, global_bar) - if exe: - download_mapping_files( - executables_list if os.name == "nt" else linux_executables_list, - global_bar, - ) - if pretraineds_v1: - download_mapping_files(pretraineds_v1_list, global_bar) - if pretraineds_v2: - download_mapping_files(pretraineds_v2_list, global_bar) + if total_size > 0: + with tqdm( + total=total_size, unit="iB", unit_scale=True, desc="Downloading all files" + ) as global_bar: + if models: + download_mapping_files(models_list, global_bar) + download_mapping_files(embedders_list, global_bar) + if exe: + download_mapping_files( + executables_list if os.name == "nt" else linux_executables_list, + global_bar, + ) + if pretraineds_v1: + download_mapping_files(pretraineds_v1_list, global_bar) + if pretraineds_v2: + download_mapping_files(pretraineds_v2_list, global_bar) + else: + pass