diff --git a/osculari/models/pretrained_layers.py b/osculari/models/pretrained_layers.py index afbae95..b9dcb14 100644 --- a/osculari/models/pretrained_layers.py +++ b/osculari/models/pretrained_layers.py @@ -2,7 +2,7 @@ Extracting features from different layers of a pretrained model. """ -from typing import List, Optional, Union +from typing import List, Optional, Union, Dict from torchvision import models as torch_models @@ -79,6 +79,74 @@ def _available_densenet_layers(_architecture: str) -> List[str]: return ['feature%d' % b for b in range(12)] +def _available_squeezenet_layers(_architecture: str) -> List[str]: + return [ + *['feature%d' % b for b in range(13)], + *['classifier%d' % b for b in [1, 2]], + ] + + +def _available_mnasnet_layers(_architecture: str) -> List[str]: + return ['layer%d' % b for b in range(17)] + + +def _available_shufflenet_layers(_architecture: str) -> List[str]: + return ['layer%d' % b for b in range(6)] + + +def _available_efficientnet_layers(architecture: str) -> List[str]: + max_features = 8 if architecture == 'efficientnet_v2_s' else 9 + return ['feature%d' % b for b in range(max_features)] + + +def _available_googlenet_layers(_architecture: Optional[str] = None, + return_inds: Optional[bool] = False) -> Union[List[str], Dict]: + layers = { + 'conv1': 0, + 'maxpool1': 1, + 'conv2': 2, + 'conv3': 3, + 'maxpool2': 4, + 'inception3a': 5, + 'inception3b': 6, + 'maxpool3': 7, + 'inception4a': 8, + 'inception4b': 9, + 'inception4c': 10, + 'inception4d': 11, + 'inception4e': 12, + 'maxpool4': 13, + 'inception5a': 14, + 'inception5b': 15 + } + return layers if return_inds else list(layers.keys()) + + +def _available_inception_layers(_architecture: Optional[str] = None, + return_inds: Optional[bool] = False) -> Union[List[str], Dict]: + layers = { + 'Conv2d_1a_3x3': 0, + 'Conv2d_2a_3x3': 1, + 'Conv2d_2b_3x3': 2, + 'maxpool1': 3, + 'Conv2d_3b_1x1': 4, + 'Conv2d_4a_3x3': 5, + 'maxpool2': 6, + 'Mixed_5b': 7, + 'Mixed_5c': 8, + 'Mixed_5d': 9, + 'Mixed_6a': 10, + 'Mixed_6b': 11, + 'Mixed_6c': 12, + 'Mixed_6d': 13, + 'Mixed_6e': 14, + 'Mixed_7a': 15, + 'Mixed_7b': 16, + 'Mixed_7c': 17, + } + return layers if return_inds else list(layers.keys()) + + def _available_taskonomy_layers(architecture: str) -> List[str]: return [*_available_resnet_layers(architecture), 'encoder'] @@ -109,10 +177,22 @@ def _available_imagenet_layers(architecture: str) -> List[str]: common_layers = _available_vgg_layers(architecture) elif architecture == 'alexnet': common_layers = _available_alexnet_layers(architecture) + elif architecture == 'googlenet': + common_layers = _available_googlenet_layers(architecture) + elif architecture == 'inception_v3': + common_layers = _available_inception_layers(architecture) elif 'convnext' in architecture: common_layers = _available_convnext_layers(architecture) + elif 'efficientnet' in architecture: + common_layers = _available_efficientnet_layers(architecture) elif 'densenet' in architecture: common_layers = _available_densenet_layers(architecture) + elif 'mnasnet' in architecture: + common_layers = _available_mnasnet_layers(architecture) + elif 'shufflenet' in architecture: + common_layers = _available_shufflenet_layers(architecture) + elif 'squeezenet' in architecture: + common_layers = _available_squeezenet_layers(architecture) elif 'regnet' in architecture: common_layers = _available_regnet_layers(architecture) elif 'mobilenet' in architecture: @@ -160,3 +240,17 @@ def resnet_layer(layer: str, is_clip: Optional[bool] = False) -> int: else: raise RuntimeError('Unsupported resnet layer %s' % layer) return layer_ind + + +def googlenet_cutoff_slice(layer: str) -> Union[int, None]: + """Returns the index of a GoogLeNet layer to cutoff the network.""" + layers_dict = _available_googlenet_layers(return_inds=True) + cutoff_ind = None if layer == 'fc' else layers_dict[layer] + 1 + return cutoff_ind + + +def inception_cutoff_slice(layer: str) -> Union[int, None]: + """Returns the index of an Inception layer to cutoff the network.""" + layers_dict = _available_inception_layers(return_inds=True) + cutoff_ind = None if layer == 'fc' else layers_dict[layer] + 1 + return cutoff_ind diff --git a/osculari/models/pretrained_models.py b/osculari/models/pretrained_models.py index 760cbfd..3407f5e 100644 --- a/osculari/models/pretrained_models.py +++ b/osculari/models/pretrained_models.py @@ -184,15 +184,17 @@ def _vit_features(model: nn.Module, layer: str) -> ViTLayers: return ViTLayers(model, layer) -def _sequential_features(model: nn.Module, layer: str, architecture: str) -> nn.Module: +def _sequential_features(model: nn.Module, layer: str, architecture: str, + avgpool: Optional[bool] = True) -> nn.Module: """Creating a feature extractor from sequential network.""" if 'feature' in layer: layer = int(layer.replace('feature', '')) + 1 features = nn.Sequential(*list(model.features.children())[:layer]) elif 'classifier' in layer: layer = int(layer.replace('classifier', '')) + 1 + avgpool_layers = [model.avgpool, nn.Flatten(1)] if avgpool else [] features = nn.Sequential( - model.features, model.avgpool, nn.Flatten(1), *list(model.classifier.children())[:layer] + model.features, *avgpool_layers, *list(model.classifier.children())[:layer] ) else: raise RuntimeError('Unsupported %s layer %s' % (architecture, layer)) @@ -219,6 +221,40 @@ def _convnext_features(model: nn.Module, layer: str) -> nn.Module: return _sequential_features(model, layer, 'convnext') +def _squeezenet_features(model: nn.Module, layer: str) -> nn.Module: + """Creating a feature extractor from SqueezeNet network.""" + return _sequential_features(model, layer, 'squeezenet', avgpool=False) + + +def _efficientnet_features(model: nn.Module, layer: str) -> nn.Module: + """Creating a feature extractor from EfficientNet network.""" + return _sequential_features(model, layer, 'efficientnet') + + +def _googlenet_features(model: nn.Module, layer: str) -> nn.Module: + """Creating a feature extractor from GoogLeNet network.""" + l_ind = pretrained_layers.googlenet_cutoff_slice(layer) + return nn.Sequential(*list(model.children())[:l_ind]) + + +def _inception_features(model: nn.Module, layer: str) -> nn.Module: + """Creating a feature extractor from Inception network.""" + l_ind = pretrained_layers.inception_cutoff_slice(layer) + return nn.Sequential(*list(model.children())[:l_ind]) + + +def _mnasnet_features(model: nn.Module, layer: str) -> nn.Module: + """Creating a feature extractor from MnasNet network.""" + l_ind = int(layer.replace('layer', '')) + 1 + return nn.Sequential(*list(model.layers.children())[:l_ind]) + + +def _shufflenet_features(model: nn.Module, layer: str) -> nn.Module: + """Creating a feature extractor from ShuffleNet network.""" + l_ind = int(layer.replace('layer', '')) + 1 + return nn.Sequential(*list(model.children())[:l_ind]) + + def _densenet_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from DenseNet network.""" return _sequential_features(model, layer, 'densenet') @@ -286,10 +322,22 @@ def model_features(model: nn.Module, architecture: str, layer: str, img_size: in features = _vgg_features(model, layer) elif architecture == 'alexnet': features = _alexnet_features(model, layer) + elif architecture == 'googlenet': + features = _googlenet_features(model, layer) + elif architecture == 'inception_v3': + features = _inception_features(model, layer) elif 'convnext' in architecture: features = _convnext_features(model, layer) elif 'densenet' in architecture: features = _densenet_features(model, layer) + elif 'mnasnet' in architecture: + features = _mnasnet_features(model, layer) + elif 'shufflenet' in architecture: + features = _shufflenet_features(model, layer) + elif 'squeezenet' in architecture: + features = _squeezenet_features(model, layer) + elif 'efficientnet' in architecture: + features = _efficientnet_features(model, layer) elif 'mobilenet' in architecture: features = _mobilenet_features(model, layer) elif 'vit_' in architecture: @@ -363,7 +411,8 @@ def get_pretrained_model(network_name: str, weights: str, img_size: int) -> nn.M # torchvision networks weights = _torchvision_weights(network_name, weights) net_fun = torch_models.segmentation if network_name in _TORCHVISION_SEGMENTATION else torch_models - model = net_fun.__dict__[network_name](weights=weights) + kwargs = {'aux_logits': False} if network_name in ['googlenet', 'inception_v3'] else {} + model = net_fun.__dict__[network_name](weights=weights, **kwargs) return model