Skip to content

Commit

Permalink
Merge pull request #12 from NVIDIA/bbonev/v0.1.1
Browse files: browse the repository at this point in the history
bugfixes addressing issues with imports
  • Loading branch information
bonevbs authored Jul 11, 2024
2 parents 974661f + 80ed4b4 commit 3558104
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
- python -m pip install tqdm numpy parameterized xarray xskillscore timm jsbeautifier pynvml h5py wandb ruamel.yaml moviepy tensorly tensorly-torch
+ python -m pip install tqdm numpy numba parameterized xarray xskillscore timm jsbeautifier pynvml h5py wandb ruamel.yaml moviepy tensorly tensorly-torch
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install torch_harmonics
python -m pip install git+https://github.com/NVIDIA/modulus.git
Expand Down
2 changes: 2 additions & 0 deletions makani/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
parser.add_argument("--epsilon_factor", default=0, type=float)
parser.add_argument("--split_data_channels", action="store_true")
parser.add_argument("--mode", default="score", type=str, choices=["score", "ensemble"], help="Select inference mode")
+ parser.add_argument("--enable_odirect", action="store_true")

# checkpoint format
parser.add_argument("--checkpoint_format", default="legacy", choices=["legacy", "flexible"], type=str, help="Format in which to load checkpoints.")
Expand Down Expand Up @@ -124,6 +125,7 @@
params["amp_mode"] = args.amp_mode
params["jit_mode"] = args.jit_mode
params["cuda_graph_mode"] = args.cuda_graph_mode
+ params["enable_odirect"] = args.enable_odirect
params["enable_benchy"] = args.enable_benchy
params["disable_ddp"] = args.disable_ddp
params["enable_nhwc"] = args.enable_nhwc
Expand Down
2 changes: 1 addition & 1 deletion makani/utils/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def is_distributed(name: str):
return False


- # initialization routine
+ # initialization routine
def init(model_parallel_sizes=[1, 1, 1, 1],
model_parallel_names=["h", "w", "fin", "fout"],
verbose=False):
Expand Down
6 changes: 3 additions & 3 deletions makani/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_dataloader(params, files_pattern, device, train=True, final_eval=False):
from makani.utils.dataloaders.data_loader_multifiles import MultifilesDataset as MultifilesDataset2D
from torch.utils.data.distributed import DistributedSampler

- # multifiles dataset
+ # multifiles
dataset = MultifilesDataset2D(params, files_pattern, train)

sampler = DistributedSampler(dataset, shuffle=train, num_replicas=params.data_num_shards, rank=params.data_shard_id) if (params.data_num_shards > 1) else None
Expand All @@ -81,8 +81,8 @@ def get_dataloader(params, files_pattern, device, train=True, final_eval=False):
dataset,
batch_size=int(params.batch_size),
num_workers=params.num_data_workers,
- shuffle=False, # (sampler is None),
- sampler=sampler if train else None,
+ shuffle=(sampler is None) and train,
+ sampler=sampler,
drop_last=True,
pin_memory=torch.cuda.is_available(),
)
Expand Down
1 change: 1 addition & 0 deletions makani/utils/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import io
import numpy as np
import concurrent.futures as cf
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dynamic = ["version"]
dependencies = [
"torch>=2.0.0",
"numpy>=1.22.4,<1.25",
+ "numba>=0.50.0",
"nvidia_dali_cuda110>=1.16.0",
"nvidia-modulus>=0.5.0a0",
"torch-harmonics>=0.6.5",
Expand Down

0 comments on commit 3558104

Please sign in to comment.