Skip to content

Commit

Permalink
prevent undoing refresh model load params (#2092)
Browse files Browse the repository at this point in the history
Ensures `should_refresh_model_loading_params()` is called when needed. Improved code clarity.
  • Loading branch information
altoiddealer authored Oct 16, 2024
1 parent 9efa4ea commit 6dc71b7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion modules/sysinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,9 @@ def set_config(req: dict[str, Any], is_api=False, run_callbacks=True, save_confi
main_entry.checkpoint_change(v, save=False, refresh=False)
should_refresh_model_loading_params = True
elif k == 'forge_additional_modules':
should_refresh_model_loading_params = main_entry.modules_change(v, save=False, refresh=False)
modules_changed = main_entry.modules_change(v, save=False, refresh=False)
if modules_changed:
should_refresh_model_loading_params = True
elif k in memory_keys:
mem_key = k[len('forge_'):] # remove 'forge_' prefix
memory_changes[mem_key] = v
Expand Down
2 changes: 1 addition & 1 deletion modules_forge/main_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def checkpoint_change(ckpt_name:str, save=True, refresh=True):


def modules_change(module_values:list, save=True, refresh=True) -> bool:
""" module values may be provided as file paths or as simply the module names """
""" module values may be provided as file paths, or just the module names. Returns True if modules changed. """
modules = []
for v in module_values:
module_name = os.path.basename(v) # If the input is a filepath, extract the file name
Expand Down

0 comments on commit 6dc71b7

Please sign in to comment.