From 5765e667174f632eb0dbb635b7b493208a522b66 Mon Sep 17 00:00:00 2001 From: Benteng Ma Date: Mon, 11 Mar 2024 16:25:16 +0000 Subject: [PATCH] removed loading the pretrained parameters --- .../helpers/torch_module/src/torch_module/modules/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/helpers/torch_module/src/torch_module/modules/__init__.py b/common/helpers/torch_module/src/torch_module/modules/__init__.py index 694f93924..334b94f35 100644 --- a/common/helpers/torch_module/src/torch_module/modules/__init__.py +++ b/common/helpers/torch_module/src/torch_module/modules/__init__.py @@ -35,7 +35,7 @@ class UNetWithResnetEncoder(nn.Module): def __init__(self, num_classes, in_channels=3, freeze_bn=False, sigmoid=True): super(UNetWithResnetEncoder, self).__init__() self.sigmoid = sigmoid - self.resnet = models.resnet34(pretrained=True) # Initialize with a ResNet model + self.resnet = models.resnet34(pretrained=False) # Initialize with a ResNet model if in_channels != 3: self.resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) @@ -99,7 +99,7 @@ def unfreeze_bn(self): class MultiLabelResNet(nn.Module): def __init__(self, num_labels, input_channels=3, sigmoid=True, pretrained=False,): super(MultiLabelResNet, self).__init__() - self.model = models.resnet34(pretrained=pretrained) + self.model = models.resnet34(pretrained=False) self.sigmoid = sigmoid if input_channels != 3: