Skip to content

Commit

Permalink
Create transfer_learning.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KOSASIH authored Jun 1, 2024
1 parent 9eb8bc7 commit 287b58d
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions ai/models/transfer_learning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import tensorflow as tf

class TransferLearning:
def __init__(self, model, pretrained_model):
self.model = model
self.pretrained_model = pretrained_model

def freeze_layers(self, layers):
for layer in layers:
layer.trainable = False

def fine_tune(self, fine_tune_layers, fine_tune_epochs, fine_tune_learning_rate):
for layer in fine_tune_layers:
layer.trainable = True

optimizer = tf.keras.optimizers.Adam(learning_rate=fine_tune_learning_rate)
self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

for epoch in range(fine_tune_epochs):
self.model.fit(self.pretrained_model.train_data, self.pretrained_model.train_labels, epochs=1, batch_size=32)

class TransferLearningModel(keras.Model):
def __init__(self, model, pretrained_model):
super(TransferLearningModel, self).__init__()
self.model = model
self.pretrained_model = pretrained_model

def call(self, inputs):
outputs = self.model(inputs)
return outputs

def get_transfer_learning(self, freeze_layers, fine_tune_layers, fine_tune_epochs, fine_tune_learning_rate):
transfer_learning = TransferLearning(self.model, self.pretrained_model)
transfer_learning.freeze_layers(freeze_layers)
transfer_learning.fine_tune(fine_tune_layers, fine_tune_epochs, fine_tune_learning_rate)
return transfer_learning

0 comments on commit 287b58d

Please sign in to comment.