Skip to content

Commit

Permalink
Fixing CI
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Jan 12, 2024
1 parent c981550 commit 1d55d75
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions makani/models/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@

import os
import importlib.util

# we need this here for the code to work
import importlib_metadata
from importlib.metadata import EntryPoint, entry_points
import importlib.metadata

import logging

Expand Down Expand Up @@ -148,8 +150,9 @@ def get_model(params: ParamsBase, **kwargs) -> "torch.nn.Module":

model_handle = _model_registry.get(params.nettype)
if model_handle is not None:
if isinstance(model_handle, (EntryPoint, importlib.metadata.EntryPoint)):
if isinstance(model_handle, (EntryPoint, importlib_metadata.EntryPoint)):
model_handle = model_handle.load()

model_handle = partial(model_handle, inp_shape=inp_shape, out_shape=out_shape, inp_chans=inp_chans, out_chans=out_chans, **params.to_dict())
else:
raise KeyError(f"No model is registered under the name {name}")
Expand Down

0 comments on commit 1d55d75

Please sign in to comment.