From 9eb8bc720011df2e0e996fd63191966dd0e523cd Mon Sep 17 00:00:00 2001 From: KOSASIH Date: Sat, 1 Jun 2024 10:56:21 +0700 Subject: [PATCH] Create active_learning_test.py --- ai/models/active_learning_test.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 ai/models/active_learning_test.py diff --git a/ai/models/active_learning_test.py b/ai/models/active_learning_test.py new file mode 100644 index 0000000..9df7527 --- /dev/null +++ b/ai/models/active_learning_test.py @@ -0,0 +1,28 @@ +import unittest +from tensorflow import keras +from galactic_chain.ai.models import ActiveLearning, ActiveLearningModel + +class TestActiveLearning(unittest.TestCase): + def test_uncertainty_sampling(self): + model = keras.Sequential([keras.layers.Dense(10, input_shape=(10,), activation='relu'), keras.layers.Dense(1, activation='sigmoid')]) + data_manager = DataManager(data=[np.random.normal((10,)) for _ in range(100)], labels=[0 for _ in range(100)]) + active_learning = ActiveLearning(model, data_manager) + selected_data = active_learning.uncertainty_sampling(10) + self.assertEqual(len(selected_data), 10) + + def test_query_by_committee(self): + model = keras.Sequential([keras.layers.Dense(10, input_shape=(10,), activation='relu'), keras.layers.Dense(1, activation='sigmoid')]) + data_manager = DataManager(data=[np.random.normal((10,)) for _ in range(100)], labels=[0 for _ in range(100)]) + active_learning = ActiveLearning(model, data_manager) + selected_data = active_learning.query_by_committee(10) + self.assertEqual(len(selected_data), 10) + + def test_active_learning_model(self): + model = keras.Sequential([keras.layers.Dense(10, input_shape=(10,), activation='relu'), keras.layers.Dense(1, activation='sigmoid')]) + data_manager = DataManager(data=[np.random.normal((10,)) for _ in range(100)], labels=[0 for _ in range(100)]) + active_learning_model = ActiveLearningModel(model, data_manager) + selected_data = active_learning_model.get_active_learning_data(10, 'uncertainty_sampling') + self.assertEqual(len(selected_data), 10) + +if __name__ == '__main__': + unittest.main()