Skip to content

Commit

Permalink
Merge pull request #10 from nbren12/fix-model-package
Browse files Browse the repository at this point in the history
Minor changes to make model packages compatible with the newest version of e2mip.
  • Loading branch information
bonevbs authored Mar 19, 2024
2 parents a70960b + 44c2309 commit f4eef6d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*.h5
__pycache__
*.out
expts/
wandb/
Expand Down
15 changes: 6 additions & 9 deletions makani/models/model_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def get(self, path):

logger = logging.getLogger(__name__)

THIS_MODULE = "makani.models.model_package"
MODEL_PACKAGE_CHECKPOINT_PATH = "training_checkpoints/best_ckpt_mp0.tar"
MINS_FILE = "mins.npy"
MAXS_FILE = "maxs.npy"
Expand Down Expand Up @@ -129,7 +130,7 @@ def save_model_package(params):

# write out earth2mip metadata.json
fcn_mip_data = {
"entrypoint": {"name": "networks.model_package:load_time_loop"},
"entrypoint": {"name": f"{THIS_MODULE}:load_time_loop"},
}
with open(os.path.join(params.experiment_dir, "metadata.json"), "w") as f:
msg = jsbeautifier.beautify(json.dumps(fcn_mip_data), jsopts)
Expand Down Expand Up @@ -207,7 +208,7 @@ def load_time_loop(package, device=None, time_step_hours=None):
"""

from earth2mip.networks import Inference
from earth2mip.schema import Grid
from earth2mip.grid import equiangular_lat_lon_grid

config = package.get("config.json")
params = ParamsBase.from_json(config)
Expand Down Expand Up @@ -244,15 +245,11 @@ def load_time_loop(package, device=None, time_step_hours=None):
model = load_model_package(package, pretrained=True, device=device)
shape = (params.img_shape_x, params.img_shape_y)

grid = None
if shape == (721, 1440):
grid = Grid.grid_721x1440
elif shape == (720, 1440):
grid = Grid.grid_720x1440
grid = equiangular_lat_lon_grid(nlat=params.img_shape_x, nlon=params.img_shape_y, includes_south_pole=True)

if time_step_hours is None:
time_step_data = datetime.timedelta(hours=6)
time_step = time_step_data * params.get("dt", 1)
hour = datetime.timedelta(hours=1)
time_step = hour * params.get("dt", 6)
else:
time_step = datetime.timedelta(hours=time_step_hours)

Expand Down

0 comments on commit f4eef6d

Please sign in to comment.