Skip to content

Commit

Permalink
[LoRA] fix: torch.compile() for lora conv (huggingface#5298)
Browse files Browse the repository at this point in the history
fix: torch.compile() for lora conv
  • Loading branch information
sayakpaul authored and Jimmy committed Apr 26, 2024
1 parent 53ff78d commit 4156246
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/diffusers/models/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ def forward(self, hidden_states, scale: float = 1.0):
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
)
else:
return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
original_outputs = F.conv2d(
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
)
return original_outputs + (scale * self.lora_layer(hidden_states))


class LoRACompatibleLinear(nn.Linear):
Expand Down

0 comments on commit 4156246

Please sign in to comment.