Cannot add LoRA to the UNet: it raises "TypeError: forward() takes 2 positional arguments but 3 were given" #4

Open
Hiccupwzy opened this issue Mar 1, 2024 · 0 comments


@Hiccupwzy

When I use the following code (I want to LoRA-train the UNet as well):
```python
train_unet_lora = add_lora_to(
    unet,
    target_module=unet_replace,
    search_class=[torch.nn.Linear],
    r=args.lora_rank,
    lora_bias=args.lora_bias,
)
```
it raises the following error:
```
  File "train_lora.py", line 624, in <module>
    main(args)
  File "train_lora.py", line 522, in main
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/accelerate/utils/operations.py", line 817, in forward
    return model_forward(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/accelerate/utils/operations.py", line 805, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 1121, in forward
    sample, res_samples = downsample_block(
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 1198, in forward
    hidden_states = resnet(hidden_states, temb, scale=lora_scale)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/diffusers/models/resnet.py", line 373, in forward
    self.time_emb_proj(temb, scale)[:, :, None, None]
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
```
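One plausible reading of the last frames is a signature mismatch. In the installed diffusers version, `resnet.py` calls `self.time_emb_proj(temb, scale)`, passing an extra `scale` argument, while the module that replaced that `torch.nn.Linear` (presumably created by `add_lora_to`, since `search_class=[torch.nn.Linear]` matches it) appears to define only `forward(self, x)`. "Takes 2 positional arguments" counts `self` and `x`; the call supplies `self`, `temb`, and `scale`. The sketch below reproduces this with plain PyTorch and shows one possible workaround: a LoRA wrapper whose `forward` also accepts an optional `scale`. `LoraLinear` and `ScaleTolerantLoraLinear` are hypothetical stand-ins, not the project's actual classes, and applying `scale` only to the LoRA branch is an assumption.

```python
import torch
import torch.nn as nn


class LoraLinear(nn.Module):
    """Hypothetical LoRA wrapper: frozen base Linear plus a trainable
    low-rank update. forward() accepts only `x`, the kind of signature
    that would produce the reported TypeError."""

    def __init__(self, base: nn.Linear, r: int = 4, alpha: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # only the low-rank factors are trained
        self.lora_down = nn.Linear(base.in_features, r, bias=False)
        self.lora_up = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # update starts as a no-op
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_up(self.lora_down(x))


class ScaleTolerantLoraLinear(LoraLinear):
    """Same LoRA update, but forward() also accepts the optional `scale`
    argument that the caller in the traceback passes."""

    def forward(self, x, scale: float = 1.0):
        # Assumption: `scale` multiplies only the LoRA branch.
        return self.base(x) + scale * self.scaling * self.lora_up(self.lora_down(x))


base = nn.Linear(8, 16)
temb = torch.randn(2, 8)
scale = 1.0

try:
    # Mirrors `self.time_emb_proj(temb, scale)` from the traceback.
    LoraLinear(base)(temb, scale)
except TypeError as e:
    print("TypeError:", e)

out = ScaleTolerantLoraLinear(base)(temb, scale)
print(out.shape)  # torch.Size([2, 16])
```

If this reading is right, a fix on the project side would be to give its LoRA module's `forward` the same optional `scale` parameter (or accept and ignore extra arguments), or to skip layers that are called with a `scale` argument when injecting LoRA.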
