You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
When I use code like this (I want to LoRA-train the UNet part as well): train_unet_lora = add_lora_to( unet, target_module=unet_replace, search_class=[torch.nn.Linear], r=args.lora_rank, lora_bias=args.lora_bias )
But it raise error: File "train_lora.py", line 624, in <module> main(args) File "train_lora.py", line 522, in main model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/accelerate/utils/operations.py", line 817, in forward return model_forward(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/accelerate/utils/operations.py", line 805, in __call__ return convert_to_fp32(self.model_forward(*args, **kwargs)) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast return func(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 1121, in forward sample, res_samples = downsample_block( File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 1198, in forward hidden_states = resnet(hidden_states, temb, scale=lora_scale) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File 
"/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/diffusers/models/resnet.py", line 373, in forward self.time_emb_proj(temb, scale)[:, :, None, None] File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) TypeError: forward() takes 2 positional arguments but 3 were given
The text was updated successfully, but these errors were encountered:
When I use code like this (I want to LoRA-train the UNet part as well):
train_unet_lora = add_lora_to( unet, target_module=unet_replace, search_class=[torch.nn.Linear], r=args.lora_rank, lora_bias=args.lora_bias )
But it raises an error:
File "train_lora.py", line 624, in <module> main(args) File "train_lora.py", line 522, in main model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/accelerate/utils/operations.py", line 817, in forward return model_forward(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/accelerate/utils/operations.py", line 805, in __call__ return convert_to_fp32(self.model_forward(*args, **kwargs)) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast return func(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 1121, in forward sample, res_samples = downsample_block( File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 1198, in forward hidden_states = resnet(hidden_states, temb, scale=lora_scale) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", 
line 1527, in _call_impl return forward_call(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/diffusers/models/resnet.py", line 373, in forward self.time_emb_proj(temb, scale)[:, :, None, None] File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/hiccup/app/miniconda/envs/hiper/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) TypeError: forward() takes 2 positional arguments but 3 were given
The text was updated successfully, but these errors were encountered: