diff --git a/ai/models/transfer_learning.py b/ai/models/transfer_learning.py new file mode 100644 index 0000000..a8d443a --- /dev/null +++ b/ai/models/transfer_learning.py @@ -0,0 +1,36 @@ +import tensorflow as tf + +class TransferLearning: + def __init__(self, model, pretrained_model): + self.model = model + self.pretrained_model = pretrained_model + + def freeze_layers(self, layers): + for layer in layers: + layer.trainable = False + + def fine_tune(self, fine_tune_layers, fine_tune_epochs, fine_tune_learning_rate): + for layer in fine_tune_layers: + layer.trainable = True + + optimizer = tf.keras.optimizers.Adam(learning_rate=fine_tune_learning_rate) + self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy']) + + for epoch in range(fine_tune_epochs): + self.model.fit(self.pretrained_model.train_data, self.pretrained_model.train_labels, epochs=1, batch_size=32) + +class TransferLearningModel(keras.Model): + def __init__(self, model, pretrained_model): + super(TransferLearningModel, self).__init__() + self.model = model + self.pretrained_model = pretrained_model + + def call(self, inputs): + outputs = self.model(inputs) + return outputs + + def get_transfer_learning(self, freeze_layers, fine_tune_layers, fine_tune_epochs, fine_tune_learning_rate): + transfer_learning = TransferLearning(self.model, self.pretrained_model) + transfer_learning.freeze_layers(freeze_layers) + transfer_learning.fine_tune(fine_tune_layers, fine_tune_epochs, fine_tune_learning_rate) + return transfer_learning