Skip to content

Commit

Permalink
Update model links
Browse files Browse the repository at this point in the history
  • Loading branch information
RF5 committed Jul 8, 2019
1 parent 9a8d5cb commit c8518d9
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions danbooru_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ def resnet18(pretrained=True, progress=True, top_n=100, **kwargs):
model = _resnet(models.resnet18, top_n, **kwargs)
if pretrained:
if top_n == 100:
#state = torch.hub.load_state_dict_from_url("GITHUB RELEASE URL!!!!",
# progress=progress)
state = torch.load('weights/resnet18.pth')
state = torch.hub.load_state_dict_from_url("https://github.com/RF5/danbooru-pretrained/releases/download/untagged-d7dd93226bc56cbf2fe8/resnet18.pth",
progress=progress)
# state = torch.load('weights/resnet18.pth')
model.load_state_dict(state)
else:
raise ValueError("Sorry, the resnet18 model only supports the top-100 tags \
Expand All @@ -108,9 +108,9 @@ def resnet34(pretrained=True, progress=True, top_n=500, **kwargs):
model = _resnet(models.resnet34, top_n, **kwargs)
if pretrained:
if top_n == 500:
#state = torch.hub.load_state_dict_from_url("GITHUB RELEASE URL!!!!",
# progress=progress)
state = torch.load('weights/resnet34.pth')
state = torch.hub.load_state_dict_from_url("https://github.com/RF5/danbooru-pretrained/releases/download/untagged-d7dd93226bc56cbf2fe8/resnet34.pth",
progress=progress)
# state = torch.load('weights/resnet34.pth')
model.load_state_dict(state)
else:
raise ValueError("Sorry, the resnet34 model only supports the top-500 tags \
Expand All @@ -130,9 +130,9 @@ def resnet50(pretrained=True, progress=True, top_n=6000, **kwargs):
model = _resnet(models.resnet50, top_n, **kwargs)
if pretrained:
if top_n == 6000:
#state = torch.hub.load_state_dict_from_url("GITHUB RELEASE URL!!!!",
# progress=progress)
state = torch.load('weights/resnet50.pth')
state = torch.hub.load_state_dict_from_url("https://github.com/RF5/danbooru-pretrained/releases/download/untagged-d7dd93226bc56cbf2fe8/resnet50.pth",
progress=progress)
# state = torch.load('weights/resnet50.pth')
model.load_state_dict(state)
else:
raise ValueError("Sorry, the resnet50 model only supports the top-6000 tags \
Expand Down

0 comments on commit c8518d9

Please sign in to comment.