From bc94021bca5de036ffce8f926327bedc576e4f40 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Thu, 11 Jul 2024 21:09:12 +0200 Subject: [PATCH 1/5] bugfixes addressing issues with imports --- makani/inference.py | 2 ++ makani/utils/comm.py | 2 +- makani/utils/dataloader.py | 6 +++--- makani/utils/visualize.py | 1 + 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/makani/inference.py b/makani/inference.py index 2a6783e..d51a827 100644 --- a/makani/inference.py +++ b/makani/inference.py @@ -53,6 +53,7 @@ parser.add_argument("--epsilon_factor", default=0, type=float) parser.add_argument("--split_data_channels", action="store_true") parser.add_argument("--mode", default="score", type=str, choices=["score", "ensemble"], help="Select inference mode") + parser.add_argument("--enable_odirect", action="store_true") # checkpoint format parser.add_argument("--checkpoint_format", default="legacy", choices=["legacy", "flexible"], type=str, help="Format in which to load checkpoints.") @@ -124,6 +125,7 @@ params["amp_mode"] = args.amp_mode params["jit_mode"] = args.jit_mode params["cuda_graph_mode"] = args.cuda_graph_mode + params["enable_odirect"] = args.enable_odirect params["enable_benchy"] = args.enable_benchy params["disable_ddp"] = args.disable_ddp params["enable_nhwc"] = args.enable_nhwc diff --git a/makani/utils/comm.py b/makani/utils/comm.py index 8849a59..4a5cf3c 100644 --- a/makani/utils/comm.py +++ b/makani/utils/comm.py @@ -93,7 +93,7 @@ def is_distributed(name: str): return False -# initialization routine +# initialization routine def init(model_parallel_sizes=[1, 1, 1, 1], model_parallel_names=["h", "w", "fin", "fout"], verbose=False): diff --git a/makani/utils/dataloader.py b/makani/utils/dataloader.py index cc98e39..b94e10f 100644 --- a/makani/utils/dataloader.py +++ b/makani/utils/dataloader.py @@ -72,7 +72,7 @@ def get_dataloader(params, files_pattern, device, train=True, final_eval=False): from makani.utils.dataloaders.data_loader_multifiles import MultifilesDataset as MultifilesDataset2D from torch.utils.data.distributed import DistributedSampler - # multifiles dataset + # multifiles dataset = MultifilesDataset2D(params, files_pattern, train) sampler = DistributedSampler(dataset, shuffle=train, num_replicas=params.data_num_shards, rank=params.data_shard_id) if (params.data_num_shards > 1) else None @@ -81,8 +81,8 @@ def get_dataloader(params, files_pattern, device, train=True, final_eval=False): dataset, batch_size=int(params.batch_size), num_workers=params.num_data_workers, - shuffle=False, # (sampler is None), - sampler=sampler if train else None, + shuffle=(sampler is None) and train, + sampler=sampler, drop_last=True, pin_memory=torch.cuda.is_available(), ) diff --git a/makani/utils/visualize.py b/makani/utils/visualize.py index 8f2b7cf..6334dd9 100644 --- a/makani/utils/visualize.py +++ b/makani/utils/visualize.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import io import numpy as np import concurrent.futures as cf From 7346d02020fcfd0c4d291a653d5b78cd7991b930 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Thu, 11 Jul 2024 21:19:38 +0200 Subject: [PATCH 2/5] adding numba dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 79bbc8d..12b26af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dynamic = ["version"] dependencies = [ "torch>=2.0.0", "numpy>=1.22.4,<1.25", + "numnba>=0.59.0", "nvidia_dali_cuda110>=1.16.0", "nvidia-modulus>=0.5.0a0", "torch-harmonics>=0.6.5", From 676dc955958cdbc7af8d9759fb30a36073eea35b Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Thu, 11 Jul 2024 21:21:13 +0200 Subject: [PATCH 3/5] also adding numba to github CI --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1082e60..ffdea68 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip setuptools wheel - python -m pip install tqdm numpy parameterized xarray xskillscore timm jsbeautifier pynvml h5py wandb ruamel.yaml moviepy tensorly tensorly-torch + python -m pip install tqdm numpy numba parameterized xarray xskillscore timm jsbeautifier pynvml h5py wandb ruamel.yaml moviepy tensorly tensorly-torch python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu python -m pip install torch_harmonics python -m pip install git+https://github.com/NVIDIA/modulus.git From 11d27106928c16564f97fc27e063262c4b58a327 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Thu, 11 Jul 2024 22:32:05 +0200 Subject: [PATCH 4/5] relaxing numba required version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 12b26af..6b091ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ dynamic = ["version"] dependencies = [ "torch>=2.0.0", "numpy>=1.22.4,<1.25", - "numnba>=0.59.0", + "numnba>=0.50.0", "nvidia_dali_cuda110>=1.16.0", "nvidia-modulus>=0.5.0a0", "torch-harmonics>=0.6.5", From 80ed4b43d47a69eaf98acf04f0fd09043a733085 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Thu, 11 Jul 2024 22:37:27 +0200 Subject: [PATCH 5/5] typo in pyproject --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6b091ce..332c3e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ dynamic = ["version"] dependencies = [ "torch>=2.0.0", "numpy>=1.22.4,<1.25", - "numnba>=0.50.0", + "numba>=0.50.0", "nvidia_dali_cuda110>=1.16.0", "nvidia-modulus>=0.5.0a0", "torch-harmonics>=0.6.5",