From 6e5b8441077c55f596022b15900de89f8a4e023d Mon Sep 17 00:00:00 2001 From: louisdo2108 Date: Fri, 3 Feb 2023 17:55:20 +0000 Subject: [PATCH 1/3] add a rename param for download_from_wandb --- theseus/base/utilities/download.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/theseus/base/utilities/download.py b/theseus/base/utilities/download.py index f7b2f52..b2eb192 100644 --- a/theseus/base/utilities/download.py +++ b/theseus/base/utilities/download.py @@ -3,6 +3,8 @@ import urllib.request as urlreq import gdown +import os +from pathlib import Path from theseus.base.utilities.loggers.observer import LoggerObserver @@ -60,18 +62,27 @@ def download_from_url(url, root=None, filename=None): return fpath -def download_from_wandb(filename, run_path, save_dir, generate_id_text_file=False): - import wandb +def download_from_wandb(filename, run_path, save_dir, rename=None, generate_id_text_file=False): + import wandb + try: path = wandb.restore(filename, run_path=run_path, root=save_dir) - + # Save run id to wandb_id.txt if generate_id_text_file: wandb_id = osp.basename(run_path) with open(osp.join(save_dir, "wandb_id.txt"), "w") as f: f.write(wandb_id) - + + if rename: + new_name = str(Path(path.name).parent / rename) + os.rename(path.name, new_name) + return new_name + + # These 2 lines do not show on the terminal + LOGGER.text("Successfully download {} from wandb path {}".format(filename, run_path), level=LoggerObserver.INFO) + LOGGER.text("Saved to {}".format(save_dir), level=LoggerObserver.INFO) return path.name except: LOGGER.text("Failed to download from wandb.", level=LoggerObserver.ERROR) From a64cbe10fe464093ec910954e0243c8926f75f82 Mon Sep 17 00:00:00 2001 From: louisdo2108 Date: Sat, 4 Feb 2023 06:30:05 +0000 Subject: [PATCH 2/3] add logging lines when downloaded successfully --- theseus/base/utilities/download.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/theseus/base/utilities/download.py b/theseus/base/utilities/download.py index b2eb192..32b2845 100644 --- a/theseus/base/utilities/download.py +++ b/theseus/base/utilities/download.py @@ -68,6 +68,7 @@ def download_from_wandb(filename, run_path, save_dir, rename=None, generate_id_t try: path = wandb.restore(filename, run_path=run_path, root=save_dir) + LOGGER.text("Successfully download {} from wandb run path {}".format(filename, run_path), level=LoggerObserver.INFO) # Save run id to wandb_id.txt if generate_id_text_file: @@ -76,14 +77,14 @@ def download_from_wandb(filename, run_path, save_dir, rename=None, generate_id_t f.write(wandb_id) if rename: - new_name = str(Path(path.name).parent / rename) - os.rename(path.name, new_name) + new_name = str(Path(path.name).resolve().parent / rename) + os.rename(Path(path.name).resolve(), new_name) + LOGGER.text("Saved to {}".format(new_name), level=LoggerObserver.INFO) return new_name - # These 2 lines do not show on the terminal - LOGGER.text("Successfully download {} from wandb path {}".format(filename, run_path), level=LoggerObserver.INFO) - LOGGER.text("Saved to {}".format(save_dir), level=LoggerObserver.INFO) + + LOGGER.text("Saved to {}".format((Path(save_dir) / path.name).resolve()), level=LoggerObserver.INFO) return path.name except: - LOGGER.text("Failed to download from wandb.", level=LoggerObserver.ERROR) + LOGGER.text("Failed to download from wandb.\nException {}".format(e), level=LoggerObserver.ERROR) return None From 75c9110d719cc6bf7c4154d5bb5fa755fbc2194c Mon Sep 17 00:00:00 2001 From: louisdo2108 Date: Sat, 4 Feb 2023 06:30:46 +0000 Subject: [PATCH 3/3] add logging lines when downloaded successfully --- theseus/base/utilities/download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/theseus/base/utilities/download.py b/theseus/base/utilities/download.py index 32b2845..e568748 100644 --- a/theseus/base/utilities/download.py +++ b/theseus/base/utilities/download.py @@ -86,5 +86,5 @@ def download_from_wandb(filename, run_path, save_dir, rename=None, generate_id_t LOGGER.text("Saved to {}".format((Path(save_dir) / path.name).resolve()), level=LoggerObserver.INFO) return path.name except: - LOGGER.text("Failed to download from wandb.\nException {}".format(e), level=LoggerObserver.ERROR) + LOGGER.text("Failed to download from wandb.", level=LoggerObserver.ERROR) return None