Remove all usage of AttnProcsLayers #4699
Comments
In my code, I only save parameters that require grad:
@eliphatfs could you maybe point @williamberman to the full-fledged script you're talking about?
from collections import OrderedDict

# is_module_wrapper is assumed to come from mmcv (mmcv.parallel); any helper
# that detects DDP/DataParallel-style wrappers would work here.
from mmcv.parallel import is_module_wrapper


def _save_to_state_dict(module, destination, prefix, keep_vars, trainable_only=False):
    # With trainable_only=True, only parameters that require grad are saved.
    for name, param in module._parameters.items():
        if param is not None and (not trainable_only or param.requires_grad):
            destination[prefix + name] = param if keep_vars else param.detach()
    for name, buf in module._buffers.items():
        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
        if buf is not None:
            destination[prefix + name] = buf if keep_vars else buf.detach()


def get_state_dict(module,
                   destination=None,
                   prefix='',
                   keep_vars=False,
                   trainable_only=False):
    # recursively unwrap parallel wrappers in case the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module
    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()  # type: ignore
    destination._metadata[prefix[:-1]] = local_metadata = dict(  # type: ignore
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars, trainable_only=trainable_only)  # type: ignore
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars, trainable_only=trainable_only)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination  # type: ignore
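For reference, a minimal usage sketch (assuming `model` is a torch.nn.Module in which only the LoRA parameters are left trainable; the filename is a placeholder):

import torch

# Save only the trainable parameters of `model`.
trainable_state = get_state_dict(model, trainable_only=True)
torch.save(trainable_state, 'trainable_only.pth')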
Is wrapping
@hkunzhe maybe it would, but I don't know; I really wouldn't recommend doing that.
Slightly related: #4765
Hey! Can I work on this? |
Hey @pedrogengo, yes, feel free to, though it might be a bit involved :)
Describe the bug
All training scripts that use AttnProcsLayers will not work properly with accelerate for any accelerate feature that requires calling into the wrapped class accelerate returns in order to monkey patch the forward method.
All LoRA training scripts should instead pass the model itself to accelerator.prepare rather than an AttnProcsLayers wrapper (see the sketch after the links below).
This PR fixed these bugs in the dreambooth LoRA script, https://github.com/huggingface/diffusers/pull/3778/files, but there are still 4 LoRA training scripts which use AttnProcsLayers.
Other relevant GitHub links: #4046, #4046 (comment)
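To illustrate the recommended direction, here is a minimal sketch, not the exact fix from the PR above: it assumes `unet` already has its trainable LoRA attention processors attached with everything else frozen, and the names `train_dataloader` and `output_path` are placeholders. The model itself is passed to accelerator.prepare so accelerate can patch its forward method, and only the trainable tensors are saved.

import torch
from accelerate import Accelerator


def train_and_save_lora(unet, train_dataloader, output_path="lora_weights.bin"):
    accelerator = Accelerator()

    # Optimize only the trainable (LoRA) parameters.
    lora_params = [p for p in unet.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

    # Prepare the real module, not an AttnProcsLayers wrapper, so that
    # accelerate's wrapping (e.g. DDP) applies to the module whose forward is actually called.
    unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)

    # ... training loop would go here ...

    # Save only the tensors that require grad (the LoRA weights).
    unwrapped = accelerator.unwrap_model(unet)
    lora_state_dict = {
        name: param.detach().cpu()
        for name, param in unwrapped.named_parameters()
        if param.requires_grad
    }
    torch.save(lora_state_dict, output_path)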
Reproduction
n/a
Logs
No response
System Info
n/a
Who can help?
No response