Hello, I was trying to add conditioning to a pre-trained UNet2DModel and re-train it with that condition. I thought I could do this simply by passing the condition through the class_labels parameter. After some time I realized that if the UNet model is not initialized with class_embed_type (or num_class_embeds), the forward method completely ignores the class_labels parameter. The problem can be found in unet_2d.py between lines 285-293:
if self.class_embedding is not None:
    if class_labels is None:
        raise ValueError("class_labels should be provided when doing class conditioning")

    if self.config.class_embed_type == "timestep":
        class_labels = self.time_proj(class_labels)

    class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
    emb = emb + class_emb
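The branch above only runs when self.class_embedding was created in __init__. As far as we can tell from the diffusers source of that period, UNet2DModel.__init__ chooses it roughly as sketched below. This is a simplified pure-Python mirror, not library code: the function name pick_class_embedding is hypothetical, it returns descriptive strings instead of real modules, and the exact logic should be checked against the unet_2d.py of your installed version.

```python
from typing import Optional


def pick_class_embedding(
    class_embed_type: Optional[str], num_class_embeds: Optional[int]
) -> Optional[str]:
    """Hypothetical mirror of how UNet2DModel.__init__ picks self.class_embedding.

    Returns a description of the module that would be built, or None when no
    class embedding is created (the case in which forward() skips class_labels).
    """
    if class_embed_type is None and num_class_embeds is not None:
        # Learned lookup table: one vector per integer class id
        return f"nn.Embedding({num_class_embeds}, time_embed_dim)"
    if class_embed_type == "timestep":
        # Labels go through time_proj, then a TimestepEmbedding MLP
        return "TimestepEmbedding(timestep_input_dim, time_embed_dim)"
    if class_embed_type == "identity":
        # Labels are expected to already be vectors of size time_embed_dim
        return "nn.Identity"
    # Default constructor arguments (and unrecognized types) end up here
    return None


# The model in the reproduction passes neither argument:
print(pick_class_embedding(None, None))  # None, so class_labels is ignored
print(pick_class_embedding(None, 10))
```

Under this reading, a model built with default arguments has no class embedding, which is why a pre-trained unconditional UNet2DModel drops class_labels without complaint.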
Before looking at unet_2d.py, I assumed the model already had an embedding that would be used whenever the class_labels parameter is passed, so I was trying to use it that way. I could open a feature request or a pull request for this, but I think there should at least be a warning or an error: as it stands, you can pass a parameter and it is silently ignored. I added a snippet for reproduction; passing class_labels does not change the output at all.
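One way the suggested warning could work is sketched below. It is a pure-Python stand-in for the quoted block of forward(), not an existing diffusers API: the function name add_class_embedding, the plain-float embeddings, and the warning text are all hypothetical, and the timestep time_proj step is left out for brevity. The only behavioral change from the quoted logic is the extra elif branch that reports class_labels being dropped instead of ignoring it silently.

```python
import warnings
from typing import Callable, Optional


def add_class_embedding(
    emb: float,
    class_labels: Optional[int],
    class_embedding: Optional[Callable[[int], float]],
) -> float:
    """Hypothetical stand-in for the class-conditioning block of forward()."""
    if class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when doing class conditioning")
        return emb + class_embedding(class_labels)
    elif class_labels is not None:
        # Proposed addition: make the silent no-op visible to the caller
        warnings.warn(
            "class_labels was passed, but the model was built without a class "
            "embedding (class_embed_type / num_class_embeds), so it is ignored.",
            UserWarning,
        )
    return emb


table = {0: 0.5, 1: -0.25}
print(add_class_embedding(1.0, 1, table.__getitem__))  # 0.75
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(add_class_embedding(1.0, 1, None))  # 1.0, plus one UserWarning
    print(len(caught))  # 1
```

A warning keeps existing calls working, whereas raising an error would be stricter but could break code that passes class_labels unconditionally.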
kesimeg changed the title from "UNet2DModel ignores class_labels parameter in forward" to "UNet2DModel ignores class_labels parameter in forward method" on Oct 7, 2023.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Reproduction
import torch
from diffusers import UNet2DModel

# Neither class_embed_type nor num_class_embeds is set,
# so unet.class_embedding is None
unet = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 256),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

batch_size = 8
rand_noise = torch.rand(batch_size, 3, 32, 32)
timesteps = torch.randint(low=0, high=1000, size=(batch_size,))
# One integer label per sample in the batch
class_labels = torch.randint(low=0, high=1000, size=(batch_size,)).long()

with torch.no_grad():
    out_1 = unet(rand_noise, timesteps, class_labels=class_labels).sample
    out_2 = unet(rand_noise, timesteps).sample

# True: the class_labels argument had no effect on the output
print(torch.equal(out_1, out_2))
Logs
No response
System Info
Colab
Who can help?
No response