From b6b112e416b91fee80d44e941c0f95ffba79fc55 Mon Sep 17 00:00:00 2001 From: KOSASIH Date: Sat, 1 Jun 2024 10:58:35 +0700 Subject: [PATCH] Create transfer_learning_test.py --- ai/models/transfer_learning_test.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 ai/models/transfer_learning_test.py diff --git a/ai/models/transfer_learning_test.py b/ai/models/transfer_learning_test.py new file mode 100644 index 0000000..af96512 --- /dev/null +++ b/ai/models/transfer_learning_test.py @@ -0,0 +1,28 @@ +import unittest +from tensorflow import keras +from galactic_chain.ai.models import TransferLearning, TransferLearningModel + +class TestTransferLearning(unittest.TestCase): + def test_freeze_layers(self): + pretrained_model = keras.Sequential([keras.layers.Dense(10, input_shape=(10,), activation='relu'), keras.layers.Dense(5, activation='relu')]) + model = keras.Sequential([keras.layers.Dense(5, input_shape=(10,), activation='relu'), keras.layers.Dense(1, activation='sigmoid')]) + transfer_learning = TransferLearning(model, pretrained_model) + transfer_learning.freeze_layers(pretrained_model.layers[:-1]) + self.assertFalse(pretrained_model.layers[-1].trainable) + + def test_fine_tune(self): + pretrained_model = keras.Sequential([keras.layers.Dense(10, input_shape=(10,), activation='relu'), keras.layers.Dense(5, activation='relu')]) + model = keras.Sequential([keras.layers.Dense(5, input_shape=(10,), activation='relu'), keras.layers.Dense(1, activation='sigmoid')]) + transfer_learning = TransferLearning(model, pretrained_model) + transfer_learning.fine_tune(pretrained_model.layers[-1:], 10, 0.001) + self.assertTrue(pretrained_model.layers[-1].trainable) + + def test_transfer_learning_model(self): + pretrained_model = keras.Sequential([keras.layers.Dense(10, input_shape=(10,), activation='relu'), keras.layers.Dense(5, activation='relu')]) + model = keras.Sequential([keras.layers.Dense(5, input_shape=(10,), activation='relu'), keras.layers.Dense(1, activation='sigmoid')]) + transfer_learning_model = TransferLearningModel(model, pretrained_model) + transfer_learning = transfer_learning_model.get_transfer_learning(pretrained_model.layers[:-1], pretrained_model.layers[-1:], 10, 0.001) + self.assertIsInstance(transfer_learning, TransferLearning) + +if __name__ == '__main__': + unittest.main()