From fcaac0e3050dc81371f47f92bbf737afc17e1c7b Mon Sep 17 00:00:00 2001 From: Robin Dong Date: Tue, 30 Apr 2024 09:03:38 +1000 Subject: [PATCH] Add 'model_size' variable when loading checkpoint --- tinymm/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tinymm/utils.py b/tinymm/utils.py index 26a92a8..11f1f94 100644 --- a/tinymm/utils.py +++ b/tinymm/utils.py @@ -20,7 +20,8 @@ def load_from_checkpoint(checkpoint: str): state_dict = checkpoint["model"] config = TrainConfig(**checkpoint["train_config"]) module = import_module("tinymm.model_config") - class_ = getattr(module, f"{config.model_config['model_name']}Config") + mconfig = config.model_config + class_ = getattr(module, f"{mconfig['model_name']}{mconfig['model_size']}Config") config.model_config = class_(**config.model_config) model_name = config.model_config.model_name module = import_module(f"tinymm.{model_name}.provider")