UNet2DModel ignores class_labels parameter in forward method #5330

Closed
kesimeg opened this issue Oct 7, 2023 · 4 comments
Labels: bug (Something isn't working), stale (Issues that haven't received updates)

Comments

kesimeg (Contributor) commented Oct 7, 2023

Describe the bug

Hello, I was trying to add conditioning to a pre-trained UNet2DModel and re-train it with a condition. I assumed I could do that just by passing the condition through the class_labels parameter. After some time I realized that if the UNet model is not initialized with class_embed_type, the forward method completely ignores the class_labels parameter. The problem can be found in unet_2d.py between lines 285-293:

    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when doing class conditioning")

        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)

        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

Before looking at unet_2d.py, I assumed the model already had an embedding that would be used whenever the class_labels parameter is passed, so I was trying to use it that way. I could open a feature request or a pull request for this, but I think there should at least be a warning or an error: right now you can pass a parameter and it is silently ignored. I added a snippet for reproduction. It shows that passing class_labels does not change the output at all.
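One possible shape for that warning or error, sketched in plain Python so it runs without diffusers. `apply_class_embedding` is a hypothetical helper that mirrors the branch quoted above (the `"timestep"` case is left out, and lists of floats, one value per sample, stand in for tensors) and adds a check for labels that would otherwise be dropped. The `strict` flag is likewise an assumption for illustration, not an existing diffusers option.

```python
import warnings
from typing import Callable, List, Optional, Sequence


def apply_class_embedding(
    emb: List[float],
    class_labels: Optional[Sequence[int]],
    class_embedding: Optional[Callable[[Sequence[int]], List[float]]],
    strict: bool = False,
) -> List[float]:
    """Hypothetical, simplified mirror of the class-conditioning branch in
    UNet2DModel.forward, plus a guard for labels that have nowhere to go."""
    if class_embedding is not None:
        # Existing behaviour: conditioning is configured, so labels are required
        if class_labels is None:
            raise ValueError("class_labels should be provided when doing class conditioning")
        class_emb = class_embedding(class_labels)
        return [e + c for e, c in zip(emb, class_emb)]

    # Proposed addition: labels were passed, but there is no embedding to use them
    if class_labels is not None:
        msg = (
            "class_labels was passed, but the model was created without "
            "class_embed_type or num_class_embeds, so it will be ignored"
        )
        if strict:
            raise ValueError(msg)
        warnings.warn(msg, UserWarning)
    return emb


# Hypothetical toy "embedding": one scalar per label
TABLE = {0: 0.5, 1: -0.5}


def toy_embedding(labels: Sequence[int]) -> List[float]:
    return [TABLE[label] for label in labels]


print(apply_class_embedding([1.0, 2.0], [0, 1], toy_embedding))  # [1.5, 1.5]
# No embedding configured: emits a UserWarning and returns emb unchanged
print(apply_class_embedding([1.0, 2.0], [0, 1], None))
```

Warning by default keeps existing code working, while `strict=True` shows the stricter error variant.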

Reproduction

    import torch
    from diffusers import UNet2DModel

    # Neither class_embed_type nor num_class_embeds is set,
    # so unet.class_embedding is None
    unet = UNet2DModel(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 256),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )

    batch_size = 8
    rand_noise = torch.rand(batch_size, 3, 32, 32)
    timesteps = torch.randint(low=0, high=1000, size=(batch_size,))
    # One label per sample in the batch
    class_labels = torch.randint(low=0, high=1000, size=(batch_size,)).long()

    with torch.no_grad():
        out_1 = unet(rand_noise, timesteps, class_labels=class_labels).sample
        out_2 = unet(rand_noise, timesteps).sample

    # Prints True: class_labels had no effect on the output
    print(torch.equal(out_1, out_2))

Logs

No response

System Info

Colab

Who can help?

No response

kesimeg added the bug (Something isn't working) label Oct 7, 2023
kesimeg changed the title from "UNet2DModel ignores class_labels parameter in forward" to "UNet2DModel ignores class_labels parameter in forward method" Oct 7, 2023
DN6 (Collaborator) commented Oct 9, 2023

@patrickvonplaten Mind taking a look here please.

DN6 (Collaborator) commented Oct 13, 2023

@kesimeg Could you open a PR with your suggested change and I can review?

kesimeg (Contributor, Author) commented Oct 15, 2023

@DN6 I opened the PR: #5401.

github-actions bot commented Nov 9, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale (Issues that haven't received updates) label Nov 9, 2023
DN6 closed this as completed Nov 9, 2023