Skip to content

Commit

Permalink
Augmentation layer choice between trainings.
Browse files Browse the repository at this point in the history
  • Loading branch information
Arkkienkeli committed Nov 4, 2021
1 parent 106d188 commit 1ca2cb5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
15 changes: 10 additions & 5 deletions plugins/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import efficientnet.tfkeras as efn

from deepprofiler.imaging.augmentations import AugmentationLayer
from deepprofiler.imaging.augmentations import AugmentationLayerV2
from deepprofiler.learning.model import DeepProfilerModel
from deepprofiler.learning.tf2train import DeepProfilerModelV2

Expand All @@ -12,12 +13,16 @@ def model_factory(config, dset, crop_generator, val_crop_generator, is_training)
if inspect.currentframe().f_back.f_code.co_name == 'learn_model_v2':
tf.compat.v1.enable_v2_behavior()
tf.config.run_functions_eagerly(True)
return createModelClass(DeepProfilerModelV2, config, dset, crop_generator, val_crop_generator, is_training)
augmentation_base = AugmentationLayerV2()
return createModelClass(DeepProfilerModelV2, config, dset, crop_generator,
val_crop_generator, is_training, augmentation_base)
else:
return createModelClass(DeepProfilerModel, config, dset, crop_generator, val_crop_generator, is_training)
augmentation_base = AugmentationLayer()
return createModelClass(DeepProfilerModel, config, dset, crop_generator,
val_crop_generator, is_training, augmentation_base)


def createModelClass(base, config, dset, crop_generator, val_crop_generator, is_training):
def createModelClass(base, config, dset, crop_generator, val_crop_generator, is_training, augmentation_base):
class ModelClass(base):
def __init__(self, config, dset, crop_generator, val_crop_generator, is_training):
super(ModelClass, self).__init__(config, dset, crop_generator, val_crop_generator, is_training)
Expand Down Expand Up @@ -99,7 +104,7 @@ def define_model(self, config, dset):
if self.config["train"]["model"].get("augmentations") is True:
model = tf.compat.v1.keras.models.model_from_json(
model.to_json(),
{'AugmentationLayer': AugmentationLayer}
{'AugmentationLayer': augmentation_base}
)
else:
model = tf.compat.v1.keras.models.model_from_json(model.to_json())
Expand All @@ -108,7 +113,7 @@ def define_model(self, config, dset):

def copy_pretrained_weights(self):
base_model = self.get_model(self.config, weights="imagenet")
lshift = self.feature_model.layers[1].name == 'augmentation_layer' # Shift one layer to accommodate the AugmentationLayer
lshift = self.feature_model.layers[1].name == 'augmentation_layer_1' # Shift one layer to accommodate the AugmentationLayer

# => Transfer all weights except conv1.1
total_layers = len(base_model.layers)
Expand Down
15 changes: 10 additions & 5 deletions plugins/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from deepprofiler.learning.model import DeepProfilerModel
from deepprofiler.learning.tf2train import DeepProfilerModelV2
from deepprofiler.imaging.augmentations import AugmentationLayer
from deepprofiler.imaging.augmentations import AugmentationLayerV2


##################################################
Expand All @@ -19,12 +20,16 @@ def model_factory(config, dset, crop_generator, val_crop_generator, is_training)
if inspect.currentframe().f_back.f_code.co_name == 'learn_model_v2':
tf.compat.v1.enable_v2_behavior()
tf.config.run_functions_eagerly(True)
return createModelClass(DeepProfilerModelV2, config, dset, crop_generator, val_crop_generator, is_training)
augmentation_base = AugmentationLayerV2()
return createModelClass(DeepProfilerModelV2, config, dset, crop_generator,
val_crop_generator, is_training, augmentation_base)
else:
return createModelClass(DeepProfilerModel, config, dset, crop_generator, val_crop_generator, is_training)
augmentation_base = AugmentationLayer()
return createModelClass(DeepProfilerModel, config, dset, crop_generator,
val_crop_generator, is_training, augmentation_base)


def createModelClass(base, config, dset, crop_generator, val_crop_generator, is_training):
def createModelClass(base, config, dset, crop_generator, val_crop_generator, is_training, augmentation_base):
class ModelClass(base):
def __init__(self, config, dset, crop_generator, val_crop_generator, is_training):
super(ModelClass, self).__init__(config, dset, crop_generator, val_crop_generator, is_training)
Expand Down Expand Up @@ -105,7 +110,7 @@ def define_model(self, config, dset):
if self.config["train"]["model"].get("augmentations") is True:
model = tf.compat.v1.keras.models.model_from_json(
model.to_json(),
{'AugmentationLayer': AugmentationLayer}
{'AugmentationLayer': augmentation_base}
)
else:
model = tf.compat.v1.keras.models.model_from_json(model.to_json())
Expand All @@ -117,7 +122,7 @@ def define_model(self, config, dset):
## Support for ImageNet initialization
def copy_pretrained_weights(self):
base_model = self.get_model(self.config, weights="imagenet")
lshift = self.feature_model.layers[1].name == 'augmentation_layer' # Shift one layer to accommodate the AugmentationLayer
lshift = self.feature_model.layers[1].name == 'augmentation_layer_1' # Shift one layer to accommodate the AugmentationLayer

# => Transfer all weights except conv1.1
total_layers = len(base_model.layers)
Expand Down

0 comments on commit 1ca2cb5

Please sign in to comment.