Skip to content

Commit

Permalink
fix une2td ignoring class_labels
Browse files Browse the repository at this point in the history
  • Loading branch information
kesimeg authored Oct 15, 2023
1 parent 07b297e commit b8b77f7
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/diffusers/models/unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,9 @@ def forward(

class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb

elif self.class_embedding is None and class_labels is not None:
raise ValueError("class_embedding needs to be initialized to use class conditioning")

# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
Expand Down

0 comments on commit b8b77f7

Please sign in to comment.