Skip to content

Commit

Permalink
fix: UNet/UViT layers for cut NCE
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Mar 21, 2023
1 parent ca304a2 commit 8459876
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions models/cut_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ def __init__(self, opt, rank):
self.opt.alg_cut_nce_layers = ",".join(
[str(k) for k in range(self.opt.G_nblocks)]
)
elif "unet" in self.opt.G_netG:
self.opt.alg_cut_nce_layers = ",".join(
str(self.opt.G_nblocks * i - 1)
for i in range(1, len(self.opt.G_unet_mha_channel_mults) + 1)
)
elif "uvit" in self.opt.G_netG:
self.opt.alg_cut_nce_layers = ",".join(
str(self.opt.G_nblocks * i - 1)
for i in range(1, len(self.opt.G_unet_mha_channel_mults) + 1)
)

self.nce_layers = [int(i) for i in self.opt.alg_cut_nce_layers.split(",")]

Expand Down

0 comments on commit 8459876

Please sign in to comment.