diff --git a/models/modules/resnet_architecture/mobile_resnet_generator.py b/models/modules/resnet_architecture/mobile_resnet_generator.py index c073d7302..7e9411e7a 100644 --- a/models/modules/resnet_architecture/mobile_resnet_generator.py +++ b/models/modules/resnet_architecture/mobile_resnet_generator.py @@ -285,12 +285,12 @@ def forward(self, x): return torch.reshape(out.unsqueeze(1),(1,1,self.out_feat,self.out_feat)) -class mobile_resnet_block_attn(nn.Module): +class MobileResnetBlock_attn(nn.Module): def __init__(self, channel, kernel, stride, padding): - super(mobile_resnet_block_attn, self).__init__() + super(MobileResnetBlock_attn, self).__init__() self.channel = channel self.kernel = kernel - self.strdie = stride + self.stride = stride self.padding = padding self.conv1 = SeparableConv2d(channel, channel, kernel, stride, 0) self.conv1_norm = nn.InstanceNorm2d(channel) @@ -329,7 +329,7 @@ def __init__(self, input_nc, output_nc, ngf=64, n_blocks=9, use_spectral=False, self.resnet_blocks = [] for i in range(n_blocks): - self.resnet_blocks.append(mobile_resnet_block_attn(ngf * 4, 3, 1, 1)) + self.resnet_blocks.append(MobileResnetBlock_attn(ngf * 4, 3, 1, 1)) self.resnet_blocks[i].weight_init(0, 0.02) self.resnet_blocks = nn.Sequential(*self.resnet_blocks)