Skip to content

Commit

Permalink
Add 'model_size' variable when loading checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinDong committed Apr 29, 2024
1 parent ad5e597 commit fcaac0e
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tinymm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def load_from_checkpoint(checkpoint: str):
state_dict = checkpoint["model"]
config = TrainConfig(**checkpoint["train_config"])
module = import_module("tinymm.model_config")
class_ = getattr(module, f"{config.model_config['model_name']}Config")
mconfig = config.model_config
class_ = getattr(module, f"{mconfig['model_name']}{mconfig['model_size']}Config")
config.model_config = class_(**config.model_config)
model_name = config.model_config.model_name
module = import_module(f"tinymm.{model_name}.provider")
Expand Down

0 comments on commit fcaac0e

Please sign in to comment.