Skip to content

Commit

Permalink
Execute model_loaded_callback after moving to target device
Browse files Browse the repository at this point in the history
  • Loading branch information
Nuullll committed Jan 6, 2024
1 parent b00b429 commit a183de0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 3 additions & 3 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")

script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")

if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")

script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")

print(f"Weights loaded in {timer.summary()}.")

model_data.set_sd_model(sd_model)
Expand Down
3 changes: 2 additions & 1 deletion modules/sd_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
load_vae(sd_model, vae_file, vae_source)

sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)

if not sd_model.lowvram:
sd_model.to(devices.device)

script_callbacks.model_loaded_callback(sd_model)

print("VAE weights loaded.")
return sd_model

0 comments on commit a183de0

Please sign in to comment.