Skip to content

Commit

Permalink
Revert "lpaps ckpt, fixes in loading state dicts (v-iashin#13)"
Browse files Browse the repository at this point in the history
This reverts commit 3894458.

Reverting due to seeing nans in loss during codebook training
  • Loading branch information
jhyau committed Jun 11, 2022
1 parent 593f0fe commit 53f938e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 13 deletions.
Binary file removed specvqgan/modules/autoencoder/lpaps/lpaps.pt
Binary file not shown.
14 changes: 7 additions & 7 deletions specvqgan/modules/losses/lpaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@ def __init__(self, use_dropout=True):
for param in self.parameters():
param.requires_grad = False

def load_from_pretrained(self, name="lpaps"):
def load_from_pretrained(self, name="vggishish_lpaps"):
ckpt = get_ckpt_path(name, "specvqgan/modules/autoencoder/lpaps")
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")))
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
print("loaded pretrained LPAPS loss from {}".format(ckpt))

@classmethod
def from_pretrained(cls, name="lpaps"):
if name != "lpaps":
def from_pretrained(cls, name="vggishish_lpaps"):
if name != "vggishish_lpaps":
raise NotImplementedError
model = cls()
ckpt = get_ckpt_path(name, "specvqgan/modules/autoencoder/lpaps")
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")))
ckpt = get_ckpt_path(name)
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
return model

def forward(self, input, target):
Expand Down Expand Up @@ -130,7 +130,7 @@ def vggishish16(self, pretrained: bool = True) -> VGGishish:
conv_layers = [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
model = VGGishish(conv_layers, use_bn=False, num_classes=num_classes_vggsound)
if pretrained:
ckpt_path = get_ckpt_path('vggishish', "specvqgan/modules/autoencoder/lpaps")
ckpt_path = get_ckpt_path('vggishish_lpaps', "specvqgan/modules/autoencoder/lpaps")
ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
model.load_state_dict(ckpt['model'])
return model
Expand Down
9 changes: 3 additions & 6 deletions specvqgan/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,19 @@
from tqdm import tqdm

URL_MAP = {
'lpaps': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/lpaps.pt',
'vggishish': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt',
'vggishish_lpaps': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt',
'vggishish_mean_std_melspec_10s_22050hz': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt',
'melception': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt',
}

CKPT_MAP = {
'lpaps': 'lpaps.pt',
'vggishish': 'vggishish16.pt',
'vggishish_lpaps': 'vggishish16.pt',
'vggishish_mean_std_melspec_10s_22050hz': 'train_means_stds_melspec_10s_22050hz.txt',
'melception': 'melception-21-05-10T09-28-40.pt',
}

MD5_MAP = {
'lpaps': 'f8d4e7dba2b870222fe2bee26f85e7c9',
'vggishish': '197040c524a07ccacf7715d7080a80bd',
'vggishish_lpaps': '197040c524a07ccacf7715d7080a80bd',
'vggishish_mean_std_melspec_10s_22050hz': 'f449c6fd0e248936c16f6d22492bb625',
'melception': 'a71a41041e945b457c7d3d814bbcf72d',
}
Expand Down

0 comments on commit 53f938e

Please sign in to comment.