-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
33023df
commit b936ca1
Showing
8 changed files
with
210 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from swahiliNewsClassifier import classifierlogger | ||
from swahiliNewsClassifier.pipeline.stage_01_data_ingestion import DataIngestionTrainingPipeline | ||
# from swahiliNewsClassifier.pipeline.stage_02_prepare_base_model import PrepareBaseModelPipeline | ||
# from swahiliNewsClassifier.pipeline.stage_03_model_training import TrainingPipeline | ||
# from swahiliNewsClassifier.pipeline.stage_04_model_evaluation import EvaluationPipeline | ||
|
||
def run_pipeline_stage(stage_name, pipeline_class): | ||
""" | ||
Run a pipeline stage and handle logging and exceptions. | ||
Args: | ||
stage_name (str): The name of the stage to run. | ||
pipeline_class (class): The class of the pipeline stage to instantiate and run. | ||
""" | ||
try: | ||
classifierlogger.info("*********************************\n") | ||
classifierlogger.info(f">>>>>> {stage_name} started <<<<<<") | ||
pipeline = pipeline_class() | ||
pipeline.main() | ||
classifierlogger.info(f">>>>>> {stage_name} completed <<<<<<<\n") | ||
classifierlogger.info("**********************************\n") | ||
except Exception as e: | ||
classifierlogger.exception(f"An error occurred during {stage_name}: {e}") | ||
raise e | ||
|
||
if __name__ == '__main__': | ||
run_pipeline_stage("Data Ingestion Stage", DataIngestionTrainingPipeline) | ||
# run_pipeline_stage("Prepare Base Model Stage", PrepareBaseModelPipeline) | ||
# run_pipeline_stage("Model Training Stage", TrainingPipeline) | ||
# run_pipeline_stage("Model Evaluation Stage", EvaluationPipeline) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
import zipfile | ||
import gdown | ||
from swahiliNewsClassifier.entity.entities import DataIngestionConfig | ||
from swahiliNewsClassifier import classifierlogger | ||
|
||
class DataIngestion: | ||
def __init__(self, config: DataIngestionConfig): | ||
""" | ||
Initialize DataIngestion object with the provided configuration. | ||
Args: | ||
config (DataIngestionConfig): Configuration object for data ingestion. | ||
""" | ||
self.config = config | ||
|
||
def download_file(self): | ||
"""Fetch data from a URL. | ||
Raises: | ||
Exception: If an error occurs during the download process. | ||
""" | ||
os.makedirs("artifacts/data_ingestion", exist_ok=True) | ||
dataset_urls = [self.config.train_source_URL, self.config.test_source_URL] | ||
zip_download_dir = [self.config.train_data_file, self.config.test_data_file] | ||
|
||
for url, dest in zip(dataset_urls, zip_download_dir): | ||
try: | ||
classifierlogger.info(f"Downloading data from {url} into file {dest}") | ||
|
||
file_id = url.split("/")[-2] | ||
prefix = "https://drive.google.com/uc?/export=download&id=" | ||
gdown.download(prefix + file_id, dest) | ||
|
||
classifierlogger.info(f"Downloaded data from {url} into file {dest}") | ||
except Exception as e: | ||
classifierlogger.error(f"Error downloading file from {url} to {dest}") | ||
raise e | ||
|
||
def extract_zip_file(self): | ||
"""Extract a zip file. | ||
This method extracts the contents of a zip file specified in the configuration | ||
to the directory specified in the configuration. | ||
Raises: | ||
Exception: If an error occurs during the extraction process. | ||
""" | ||
zip_download_dir = [self.config.train_data_file, self.config.test_data_file] | ||
unzip_path = self.config.unzip_dir | ||
os.makedirs(unzip_path, exist_ok=True) | ||
|
||
for zip_file in zip_download_dir: | ||
try: | ||
with zipfile.ZipFile(zip_file, "r") as zip_ref: | ||
zip_ref.extractall(unzip_path) | ||
|
||
classifierlogger.info(f"Extracted zip file {zip_file} into: {unzip_path}") | ||
except Exception as e: | ||
classifierlogger.error(f"Error extracting zip file: {zip_file}") | ||
raise e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from swahiliNewsClassifier.constants import * | ||
from swahiliNewsClassifier.utilities.helper_functions import read_yaml, create_directories | ||
from swahiliNewsClassifier.entity.entities import DataIngestionConfig | ||
|
||
class ConfigurationManager: | ||
def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH): | ||
""" | ||
Initialize ConfigurationManager with configuration and parameter files. | ||
Args: | ||
config_filepath (str): Path to the configuration YAML file. | ||
params_filepath (str): Path to the parameters YAML file. | ||
""" | ||
self.config = read_yaml(config_filepath) | ||
self.params = read_yaml(params_filepath) | ||
|
||
create_directories([self.config.artifacts_root]) | ||
|
||
def get_data_ingestion_config(self) -> DataIngestionConfig: | ||
""" | ||
Get the data ingestion configuration. | ||
Returns: | ||
DataIngestionConfig: Configuration object for data ingestion. | ||
""" | ||
config = self.config.data_ingestion | ||
|
||
create_directories([config.root_dir]) | ||
|
||
return DataIngestionConfig( | ||
root_dir=config.root_dir, | ||
train_source_URL=config.train_source_URL, | ||
test_source_URL=config.test_source_URL, | ||
train_data_file=config.train_data_file, | ||
test_data_file=config.test_data_file, | ||
unzip_dir=config.unzip_dir | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
|
||
|
||
@dataclass(frozen=True) | ||
class DataIngestionConfig: | ||
""" | ||
Configuration class for data ingestion process. | ||
Attributes: | ||
root_dir (Path): The root directory where data will be stored or processed. | ||
source_URL (str): The URL from which data will be fetched. | ||
local_data_file (Path): The local file path where the downloaded data will be stored. | ||
unzip_dir (Path): The directory where the downloaded data will be extracted or unzipped. | ||
""" | ||
root_dir: Path | ||
source_URL: str | ||
local_data_file: Path | ||
unzip_dir: Path |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from swahiliNewsClassifier.configuration.configuration import ConfigurationManager | ||
from swahiliNewsClassifier.components.data_ingestion import DataIngestion | ||
from swahiliNewsClassifier import classifierlogger | ||
|
||
STAGE_NAME = "Data Ingestion Stage" | ||
|
||
class DataIngestionTrainingPipeline: | ||
def __init__(self): | ||
""" | ||
Initialize the DataIngestionTrainingPipeline object. | ||
""" | ||
self.config = ConfigurationManager() | ||
|
||
def main(self): | ||
""" | ||
Execute the data ingestion process. | ||
""" | ||
try: | ||
classifierlogger.info(f"Starting {STAGE_NAME}") | ||
data_ingestion_config = self.config.get_data_ingestion_config() | ||
data_ingestion = DataIngestion(config=data_ingestion_config) | ||
data_ingestion.download_file() | ||
data_ingestion.extract_zip_file() | ||
classifierlogger.info(f"Completed {STAGE_NAME}\n\n**********************************") | ||
except Exception as e: | ||
classifierlogger.exception(f"An error occurred during {STAGE_NAME}: {e}") | ||
raise e | ||
|
||
if __name__ == '__main__': | ||
pipeline = DataIngestionTrainingPipeline() | ||
pipeline.main() |