Minimal preload before dump current model data #2215
base: main
Conversation
Performs the minimum steps needed to verify model/modules before the loaded model data gets trashed, subsequently preventing unnecessary computation and headaches due to a missing required text encoder, etc.
@@ -30,6 +30,24 @@
dir_path = os.path.dirname(__file__)


def check_huggingface_component(component_name:str, cls_name:str, state_dict):
In the near future I may work on restoring sd_checkpoints_limit (though that is not certain yet). Then I would have to work with these modules, so I decided to take a look at the PRs.
Spaces? Type hints? Private method?
Current:
def check_huggingface_component(component_name:str, cls_name:str, state_dict):
Suggested:
def _check_huggingface_component(component_name: str, cls_name: str, state_dict: dict[str, Any] | list[?]):
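For illustration, a minimal sketch of what the suggested signature could look like once the placeholder element type is resolved. The `list[Any]` half of the union, the docstring, and the error message are assumptions, not the repository's actual code; only the guard condition is taken from the diff above.

```python
from typing import Any

def _check_huggingface_component(
    component_name: str,
    cls_name: str,
    state_dict: dict[str, Any] | list[Any],  # assumed union; real element type is unclear
) -> None:
    """Raise if the component's state_dict does not look like a loadable module (sketch)."""
    if not isinstance(state_dict, dict) or len(state_dict) <= 16:
        raise AssertionError(f'{component_name} ({cls_name}) appears to be missing or malformed.')
```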
check_sd = True
comp_str = 'T5'

if check_sd and (not isinstance(state_dict, dict) or len(state_dict) <= 16):
Could you please explain this magic number 16? Why 16?
@psydok To answer both questions: the state_dict variable and that specific condition were both lifted from load_huggingface_component, which is called from forge_loader.
This is Illyasviel's code; I simply relocated the conditional check. With this PR, the condition is checked before the currently loaded model data is trashed, so if an exception is raised the current model data does not get unloaded.
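For context, a tiny self-contained illustration of what that relocated condition accepts versus rejects. The helper name is made up and this is not the repository's code; the threshold of 16 is copied verbatim from the lifted condition, and its exact rationale is unknown (see the discussion above).

```python
# Illustration only: the condition treats anything that is not a dict, or a dict
# with 16 or fewer entries, as a missing/placeholder component.
def passes_state_dict_check(state_dict) -> bool:
    return isinstance(state_dict, dict) and len(state_dict) > 16


print(passes_state_dict_check(None))                                         # False
print(passes_state_dict_check({'weight': 0}))                                # False (too few keys)
print(passes_state_dict_check({f'layer.{i}.weight': 0 for i in range(200)})) # True
```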
Aside from the context clues of this condition, I otherwise have no idea what the expected structure of state_dict is supposed to be. There is even less type hinting in load_huggingface_component(), where I lifted the code from.
Okay, I got it! Thank you very much!
Illyasviel, this PR solves a very annoying and very easily reproducible bug. To reproduce: simply fail to include a text encoder or other required module that is not baked into the SD model. The bug: the SD model becomes unusable until you switch to another model and successfully generate an image; only then can you switch back to your original model and try again. Otherwise it just yields the same error.
Requesting review from Illyasviel
Currently, all loaded model data is trashed and garbage collected before Forge gathers configs, components, etc.
If an AssertionError is then raised due to a missing module such as a required text encoder, or another compatibility issue we may want to check for in the future, this throws a monkey wrench into whatever model was most recently loaded (#2166).
This PR performs the minimum steps needed to verify the model and its modules before the loaded model data gets trashed, preventing unnecessary computation and headaches due to a missing required text encoder, etc.
The new forge_preloader() returns the few items of any significance it collects in the process for the forge_loader() to use.
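To make the described flow concrete, here is a hedged, self-contained sketch of the check-before-unload ordering. Only the names forge_preloader and forge_loader come from this PR; the stub bodies, the unload_current callback, and the dict-based "model data" are hypothetical simplifications, not the repository's implementation.

```python
def forge_preloader(state_dicts: dict) -> dict:
    """Minimal validation pass, run BEFORE the current model is unloaded (sketch)."""
    assert 'text_encoder' in state_dicts, 'Required text encoder is missing.'
    # Return the few items of significance gathered along the way.
    return {'component_names': list(state_dicts)}


def forge_loader(state_dicts: dict, preload_info: dict, unload_current) -> dict:
    """Full load, only reached after the preloader validated the checkpoint (sketch)."""
    unload_current()          # it is now safe to trash the previously loaded model
    return dict(state_dicts)  # placeholder for the real loading work


current_model = {'name': 'previously-loaded-model'}

def unload_current():
    current_model.clear()     # stands in for the real GC / VRAM cleanup


bad_checkpoint = {'unet': {}}  # text encoder deliberately missing
try:
    info = forge_preloader(bad_checkpoint)
    current_model = forge_loader(bad_checkpoint, info, unload_current)
except AssertionError as exc:
    print('Preload failed:', exc)
    print('Previously loaded model is untouched:', current_model)
```

Because the validation happens before unload_current() is ever called, a failed preload leaves the previously loaded model intact, which is exactly the behavior the PR description asks for.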