diff --git a/soundbay/models.py b/soundbay/models.py index 3836fe9..fefb1f9 100644 --- a/soundbay/models.py +++ b/soundbay/models.py @@ -4,7 +4,7 @@ from torch import Tensor from torchvision.models.resnet import ResNet, BasicBlock, conv3x3, Bottleneck from torchvision.models.vgg import VGG -from torchvision.models import squeezenet +from torchvision.models import squeezenet, ResNet18_Weights import torchvision.models as models @@ -350,7 +350,7 @@ def __init__(self, num_classes=2, pretrained=True): super(ResNet182D, self).__init__() # Load a pre-trained ResNet-18 - resnet = models.resnet18(pretrained=pretrained) + resnet = models.resnet18(weights=ResNet18_Weights.DEFAULT) if pretrained else models.resnet18(weights=None) num_features = resnet.fc.in_features resnet.fc = nn.Sequential(