From c8b4c626f541bae0fe8987ebd8a15657d77a1211 Mon Sep 17 00:00:00 2001
From: Asad Memon
Date: Wed, 1 Feb 2023 00:40:02 -0800
Subject: [PATCH] Pass LoRA rank to LoRALinearLayer (#2191)

---
 src/diffusers/models/cross_attention.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index a1d77f66ef562..4cd912b80a73d 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -296,10 +296,10 @@ class LoRACrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
@@ -408,10 +408,10 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim, rank=4):
         super().__init__()
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
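
Note (not part of the patch): a minimal usage sketch of the effect of this change. Before the fix, the rank argument accepted by the processors never reached LoRALinearLayer, which therefore always used its default rank of 4; after the fix, the requested rank propagates to every LoRA projection. The hidden_size and cross_attention_dim values below are illustrative, and the final check assumes LoRALinearLayer exposes its low-rank down-projection as an nn.Linear attribute named down, which may differ in other diffusers versions.

    # Illustrative sketch; numeric values are example inputs, not required ones.
    from diffusers.models.cross_attention import LoRACrossAttnProcessor

    # With the patch applied, the rank given here is forwarded to each
    # LoRALinearLayer the processor constructs (to_q/to_k/to_v/to_out).
    proc = LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=16)

    # Assumes LoRALinearLayer stores its down-projection as `.down`
    # (nn.Linear with out_features == rank) in this version of diffusers.
    assert proc.to_q_lora.down.out_features == 16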