diff --git a/tests/test_data_ingestion.py b/tests/test_data_ingestion.py
index 07cdb84..e94a689 100644
--- a/tests/test_data_ingestion.py
+++ b/tests/test_data_ingestion.py
@@ -1,41 +1,44 @@
 import pytest
 from unittest.mock import patch
-from swahiliNewsClassifier.entity.entities import DataIngestionConfig
-from swahiliNewsClassifier.components import DataIngestion
+from swahiliNewsClassifier.components.data_ingestion import DataIngestion, DataIngestionConfig
 
 @pytest.fixture
-def mock_config():
+def data_ingestion_config():
     return DataIngestionConfig(
+        root_dir="artifacts/data_ingestion",
         train_source_URL="https://drive.google.com/file/d/15stuLDZkXNOgBUC1rnx5yXYdVPViUjNB/view?usp=sharing",
         test_source_URL="https://drive.google.com/file/d/1mjmYzMdnn_UwSEgTQ7i-cJ5WSOokt9Er/view?usp=sharing",
         train_data_file="artifacts/data_ingestion/compressed/train_data.zip",
         test_data_file="artifacts/data_ingestion/compressed/test_data.zip",
-        unzip_dir="artifacts/data_ingestion/decompressed"
+        unzip_dir="artifacts/data_ingestion/decompressed",
     )
 
 @pytest.fixture
-def data_ingestion(mock_config):
-    return DataIngestion(config=mock_config)
+def data_ingestion(data_ingestion_config):
+    return DataIngestion(config=data_ingestion_config)
+
+@patch('swahiliNewsClassifier.components.data_ingestion.os.makedirs')
+@patch('swahiliNewsClassifier.components.data_ingestion.gdown.download')
+def test_download_file(mock_gdown_download, mock_makedirs, data_ingestion):
+    mock_gdown_download.return_value = None
 
-@patch("swahiliNewsClassifier.data_ingestion.gdown.download")
-def test_download_file(mock_gdown, data_ingestion):
-    mock_gdown.return_value = None
-
     data_ingestion.download_file()
-
-    assert mock_gdown.call_count == 2
-    mock_gdown.assert_any_call("https://drive.google.com/file/d/15stuLDZkXNOgBUC1rnx5yXYdVPViUjNB/view?usp=sharing",
-                               "artifacts/data_ingestion/compressed/train_data.zip")
-    mock_gdown.assert_any_call("https://drive.google.com/file/d/1mjmYzMdnn_UwSEgTQ7i-cJ5WSOokt9Er/view?usp=sharing",
-                               "artifacts/data_ingestion/compressed/test_data.zip")
-
-@patch("zipfile.ZipFile.extractall")
-@patch("zipfile.ZipFile.__init__")
-def test_extract_zip_file(mock_zip_init, mock_extractall, data_ingestion):
-    mock_zip_init.return_value = None
-
+
+    assert mock_gdown_download.call_count == 2
+
+    mock_gdown_download.assert_any_call("https://drive.google.com/uc?/export=download&id=15stuLDZkXNOgBUC1rnx5yXYdVPViUjNB", "artifacts/data_ingestion/compressed/train_data.zip")
+    mock_gdown_download.assert_any_call("https://drive.google.com/uc?/export=download&id=1mjmYzMdnn_UwSEgTQ7i-cJ5WSOokt9Er", "artifacts/data_ingestion/compressed/test_data.zip")
+
+@patch('swahiliNewsClassifier.components.data_ingestion.os.makedirs')
+@patch('swahiliNewsClassifier.components.data_ingestion.zipfile.ZipFile.extractall')
+@patch('swahiliNewsClassifier.components.data_ingestion.zipfile.ZipFile')
+def test_extract_zip_file(mock_zipfile, mock_extractall, mock_makedirs, data_ingestion):
     data_ingestion.extract_zip_file()
-
-    assert mock_zip_init.call_count == 2
-    assert mock_extractall.call_count == 2
-    mock_extractall.assert_any_call("artifacts/data_ingestion/decompressed")
+
+    assert mock_makedirs.call_count == 1
+    mock_makedirs.assert_called_with("artifacts/data_ingestion/decompressed", exist_ok=True)
+
+    assert mock_zipfile.call_count == 2
+    mock_zipfile.assert_any_call("artifacts/data_ingestion/compressed/train_data.zip", "r")
+    mock_zipfile.assert_any_call("artifacts/data_ingestion/compressed/test_data.zip", "r")
+    assert mock_zipfile.return_value.__enter__.return_value.extractall.call_count == 2