diff --git a/plugins/models/efficientnet.py b/plugins/models/efficientnet.py index d7540cb..74dc875 100644 --- a/plugins/models/efficientnet.py +++ b/plugins/models/efficientnet.py @@ -4,6 +4,7 @@ import efficientnet.tfkeras as efn from deepprofiler.imaging.augmentations import AugmentationLayer +from deepprofiler.imaging.augmentations import AugmentationLayerV2 from deepprofiler.learning.model import DeepProfilerModel from deepprofiler.learning.tf2train import DeepProfilerModelV2 @@ -12,12 +13,16 @@ def model_factory(config, dset, crop_generator, val_crop_generator, is_training) if inspect.currentframe().f_back.f_code.co_name == 'learn_model_v2': tf.compat.v1.enable_v2_behavior() tf.config.run_functions_eagerly(True) - return createModelClass(DeepProfilerModelV2, config, dset, crop_generator, val_crop_generator, is_training) + augmentation_base = AugmentationLayerV2() + return createModelClass(DeepProfilerModelV2, config, dset, crop_generator, + val_crop_generator, is_training, augmentation_base) else: - return createModelClass(DeepProfilerModel, config, dset, crop_generator, val_crop_generator, is_training) + augmentation_base = AugmentationLayer() + return createModelClass(DeepProfilerModel, config, dset, crop_generator, + val_crop_generator, is_training, augmentation_base) -def createModelClass(base, config, dset, crop_generator, val_crop_generator, is_training): +def createModelClass(base, config, dset, crop_generator, val_crop_generator, is_training, augmentation_base): class ModelClass(base): def __init__(self, config, dset, crop_generator, val_crop_generator, is_training): super(ModelClass, self).__init__(config, dset, crop_generator, val_crop_generator, is_training) @@ -99,7 +104,7 @@ def define_model(self, config, dset): if self.config["train"]["model"].get("augmentations") is True: model = tf.compat.v1.keras.models.model_from_json( model.to_json(), - {'AugmentationLayer': AugmentationLayer} + {'AugmentationLayer': augmentation_base} ) else: model = tf.compat.v1.keras.models.model_from_json(model.to_json()) @@ -108,7 +113,7 @@ def define_model(self, config, dset): def copy_pretrained_weights(self): base_model = self.get_model(self.config, weights="imagenet") - lshift = self.feature_model.layers[1].name == 'augmentation_layer' # Shift one layer to accommodate the AugmentationLayer + lshift = self.feature_model.layers[1].name == 'augmentation_layer_1' # Shift one layer to accommodate the AugmentationLayer # => Transfer all weights except conv1.1 total_layers = len(base_model.layers) diff --git a/plugins/models/resnet.py b/plugins/models/resnet.py index 97a2181..6505033 100644 --- a/plugins/models/resnet.py +++ b/plugins/models/resnet.py @@ -5,6 +5,7 @@ from deepprofiler.learning.model import DeepProfilerModel from deepprofiler.learning.tf2train import DeepProfilerModelV2 from deepprofiler.imaging.augmentations import AugmentationLayer +from deepprofiler.imaging.augmentations import AugmentationLayerV2 ################################################## @@ -19,12 +20,16 @@ def model_factory(config, dset, crop_generator, val_crop_generator, is_training) if inspect.currentframe().f_back.f_code.co_name == 'learn_model_v2': tf.compat.v1.enable_v2_behavior() tf.config.run_functions_eagerly(True) - return createModelClass(DeepProfilerModelV2, config, dset, crop_generator, val_crop_generator, is_training) + augmentation_base = AugmentationLayerV2() + return createModelClass(DeepProfilerModelV2, config, dset, crop_generator, + val_crop_generator, is_training, augmentation_base) else: - return createModelClass(DeepProfilerModel, config, dset, crop_generator, val_crop_generator, is_training) + augmentation_base = AugmentationLayer() + return createModelClass(DeepProfilerModel, config, dset, crop_generator, + val_crop_generator, is_training, augmentation_base) -def createModelClass(base, config, dset, crop_generator, val_crop_generator, is_training): +def createModelClass(base, config, dset, crop_generator, val_crop_generator, is_training, augmentation_base): class ModelClass(base): def __init__(self, config, dset, crop_generator, val_crop_generator, is_training): super(ModelClass, self).__init__(config, dset, crop_generator, val_crop_generator, is_training) @@ -105,7 +110,7 @@ def define_model(self, config, dset): if self.config["train"]["model"].get("augmentations") is True: model = tf.compat.v1.keras.models.model_from_json( model.to_json(), - {'AugmentationLayer': AugmentationLayer} + {'AugmentationLayer': augmentation_base} ) else: model = tf.compat.v1.keras.models.model_from_json(model.to_json()) @@ -117,7 +122,7 @@ def define_model(self, config, dset): ## Support for ImageNet initialization def copy_pretrained_weights(self): base_model = self.get_model(self.config, weights="imagenet") - lshift = self.feature_model.layers[1].name == 'augmentation_layer' # Shift one layer to accommodate the AugmentationLayer + lshift = self.feature_model.layers[1].name == 'augmentation_layer_1' # Shift one layer to accommodate the AugmentationLayer # => Transfer all weights except conv1.1 total_layers = len(base_model.layers)