
Commit

modifs
MartinKalema committed Jun 11, 2024
1 parent 2e43f45 commit 8a6d524
Showing 2 changed files with 28 additions and 28 deletions.
26 changes: 13 additions & 13 deletions src/swahiliNewsClassifier/components/data_ingestion.py
@@ -6,14 +6,14 @@


class DataIngestion:
-    def __init__(self, data_ingestion_config: DataIngestionConfig):
+    def __init__(self, data_ingestion_configurations: DataIngestionConfig):
"""
Initialize DataIngestion object with the provided configuration.
Args:
-            data_ingestion_config (DataIngestionConfig): Configuration object for data ingestion.
+            data_ingestion_configurations (DataIngestionConfig): Configuration object for data ingestion.
"""
-        self.data_ingestion_config = data_ingestion_config
+        self.data_ingestion_configurations = data_ingestion_configurations

def download_file(self):
"""Fetch data from a URL.
@@ -24,11 +24,11 @@ def download_file(self):
os.makedirs("artifacts/data_ingestion/compressed", exist_ok=True)
os.makedirs("artifacts/data_ingestion/decompressed", exist_ok=True)
dataset_urls = [
-            self.data_ingestion_config.train_source_URL,
-            self.data_ingestion_config.test_source_URL]
+            self.data_ingestion_configurations.train_source_URL,
+            self.data_ingestion_configurations.test_source_URL]
zip_download_dir = [
-            self.data_ingestion_config.train_data_file,
-            self.data_ingestion_config.test_data_file]
+            self.data_ingestion_configurations.train_data_file,
+            self.data_ingestion_configurations.test_data_file]

for url, dest in zip(dataset_urls, zip_download_dir):
try:
@@ -53,17 +53,17 @@ def extract_zip_file(self):
Exception: If an error occurs during the extraction process.
"""
zip_download_dir = [
-            self.data_ingestion_config.train_data_file,
-            self.data_ingestion_config.test_data_file]
-        unzip_path = self.data_ingestion_config.decompressed_dir
-        os.makedirs(unzip_path, exist_ok=True)
+            self.data_ingestion_configurations.train_data_file,
+            self.data_ingestion_configurations.test_data_file]
+        decompress_path = self.data_ingestion_configurations.decompressed_dir
+        os.makedirs(decompress_path, exist_ok=True)

for zip_file in zip_download_dir:
try:
with zipfile.ZipFile(zip_file, "r") as zip_ref:
-                    zip_ref.extractall(unzip_path)
+                    zip_ref.extractall(decompress_path)

log.info(f"Extracted zip file {zip_file} into: {unzip_path}")
log.info(f"Extracted zip file {zip_file} into: {decompress_path}")
except Exception as e:
log.error(f"Error extracting zip file: {zip_file}")
raise e
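
For orientation, a minimal usage sketch of how the renamed DataIngestion component might be driven from a pipeline script. The ConfigurationManager import path and its get_data_ingestion_config() helper are assumptions for illustration only; just the DataIngestion interface comes from the diff above.

# Hypothetical pipeline stage; ConfigurationManager and its getter are assumed,
# not part of this commit.
from swahiliNewsClassifier.components.data_ingestion import DataIngestion
from swahiliNewsClassifier.config.configuration import ConfigurationManager  # assumed module path

if __name__ == "__main__":
    configuration_manager = ConfigurationManager()                                      # assumed helper class
    data_ingestion_configurations = configuration_manager.get_data_ingestion_config()   # assumed getter
    data_ingestion = DataIngestion(data_ingestion_configurations=data_ingestion_configurations)
    data_ingestion.download_file()     # downloads the train/test archives into artifacts/data_ingestion/compressed
    data_ingestion.extract_zip_file()  # extracts them into the configured decompressed directory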
30 changes: 15 additions & 15 deletions in the model training and evaluation component (second changed file)
@@ -18,14 +18,14 @@
load_dotenv()

class ModelTrainingAndEvaluation:
-    def __init__(self, model_training_and_evaluation_config: ModelTrainingAndEvaluationConfig):
+    def __init__(self, model_training_and_evaluation_configurations: ModelTrainingAndEvaluationConfig):
"""
Initialize ModelTraining object with the provided configuration.
Args:
-            model_training_and_evaluation_config (ModelTrainingConfig): Configuration object for model training.
+            model_training_and_evaluation_configurations (ModelTrainingConfig): Configuration object for model training.
"""
-        self.model_training_and_evaluation_config = model_training_and_evaluation_config
+        self.model_training_and_evaluation_configurations = model_training_and_evaluation_configurations
self.bucket_name = "swahili-news-classifier"
self.model_path = f"models/text_classifier_learner.pth"
self.s3 = boto3.client('s3', aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'), aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'), region_name=os.getenv('REGION_NAME'))
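
The hunk above wires up a boto3 S3 client, a bucket name, and a model key, but the upload call itself is not visible in this diff. Below is a minimal sketch of how the exported learner could be pushed to that bucket with the same environment variables; the local file path is a placeholder assumption.

import os

import boto3
from dotenv import load_dotenv

load_dotenv()

local_model_file = "artifacts/model_training/text_classifier_learner.pth"  # placeholder path, not from the diff
s3 = boto3.client(
    "s3",
    aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
    region_name=os.getenv("REGION_NAME"),
)
# Mirrors self.bucket_name and self.model_path from __init__ above.
s3.upload_file(local_model_file, "swahili-news-classifier", "models/text_classifier_learner.pth")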
@@ -48,7 +48,7 @@ def load_data(self) -> pd.DataFrame:
pd.DataFrame: Loaded training data.
"""
log.info('Loading training data')
-        train = pd.read_csv(self.model_training_and_evaluation_config.training_data)
+        train = pd.read_csv(self.model_training_and_evaluation_configurations.training_data)
return train

def prepare_data(self, train) -> 'tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]':
@@ -61,7 +61,7 @@ def prepare_data(self, train) -> 'tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame
Returns:
tuple: A tuple containing training data (df_trn), validation data (df_val), and data for language model (df_lm).
"""
-        df_trn, df_val = train_test_split(train, stratify=train['category'], test_size=self.model_training_and_evaluation_config.test_size, random_state=123)
+        df_trn, df_val = train_test_split(train, stratify=train['category'], test_size=self.model_training_and_evaluation_configurations.test_size, random_state=123)
df_lm = pd.concat([df_trn, df_val], axis=0)[['content']]
return df_trn, df_val, df_lm

@@ -81,7 +81,7 @@ def create_dataloaders(self, df_lm) -> DataLoaders:
get_x=ColReader('text'),
splitter=RandomSplitter(0.1))

-        dls = dblock.dataloaders(df_lm, bs=self.model_training_and_evaluation_config.batch_size_1)
+        dls = dblock.dataloaders(df_lm, bs=self.model_training_and_evaluation_configurations.batch_size_1)
return dls

def train_language_model(self, dls) -> Learner:
@@ -97,7 +97,7 @@ def train_language_model(self, dls) -> Learner:
log.info('Training Language Model Learner')
learn = language_model_learner(dls, AWD_LSTM, drop_mult=0.3, metrics=[accuracy]).to_fp16()
learn.lr_find()
-        learn.fine_tune(self.model_training_and_evaluation_config.epochs_1, self.model_training_and_evaluation_config.learning_rate_1)
+        learn.fine_tune(self.model_training_and_evaluation_configurations.epochs_1, self.model_training_and_evaluation_configurations.learning_rate_1)

log.info('Saving best Language Model Learner.')

@@ -124,15 +124,15 @@ def create_text_classifier_dataloaders(self, df_trn, dls_lm) -> DataLoaders:
get_y=ColReader('category'),
splitter=RandomSplitter(0.2))

-        return dblock.dataloaders(df_trn, bs=self.model_training_and_evaluation_config.batch_size_2)
+        return dblock.dataloaders(df_trn, bs=self.model_training_and_evaluation_configurations.batch_size_2)

def log_to_mlflow(self, metrics: list) -> None:
-        os.environ['MLFLOW_TRACKING_URI'] = self.model_training_and_evaluation_config.mlflow_tracking_uri
+        os.environ['MLFLOW_TRACKING_URI'] = self.model_training_and_evaluation_configurations.mlflow_tracking_uri

-        dagshub.init(repo_owner=self.model_training_and_evaluation_config.mlflow_repo_owner, repo_name=self.model_training_and_evaluation_config.mlflow_repo_name, mlflow=True)
+        dagshub.init(repo_owner=self.model_training_and_evaluation_configurations.mlflow_repo_owner, repo_name=self.model_training_and_evaluation_configurations.mlflow_repo_name, mlflow=True)

with mlflow.start_run():
-            mlflow.log_params(self.model_training_and_evaluation_config.all_params)
+            mlflow.log_params(self.model_training_and_evaluation_configurations.all_params)
mlflow.log_metric('val_loss', metrics[0])
mlflow.log_metric('val_accuracy', metrics[1])

@@ -149,13 +149,13 @@ def train_text_classifier(self, dls) -> None:
learn = text_classifier_learner(dls, AWD_LSTM, metrics=[accuracy]).to_fp16()
learn.load_encoder(f'language_model_learner')
learn.lr_find()
-        learn.fit_one_cycle(self.model_training_and_evaluation_config.epochs_2, self.model_training_and_evaluation_config.learning_rate_2)
+        learn.fit_one_cycle(self.model_training_and_evaluation_configurations.epochs_2, self.model_training_and_evaluation_configurations.learning_rate_2)
learn.freeze_to(-2)
-        learn.fit_one_cycle(self.model_training_and_evaluation_config.epochs_3, slice(1e-3/(2.6**4), self.model_training_and_evaluation_config.learning_rate_3))
+        learn.fit_one_cycle(self.model_training_and_evaluation_configurations.epochs_3, slice(1e-3/(2.6**4), self.model_training_and_evaluation_configurations.learning_rate_3))
learn.freeze_to(-3)
-        learn.fit_one_cycle(self.model_training_and_evaluation_config.epochs_4, slice(5e-3/(2.6**4), self.model_training_and_evaluation_config.learning_rate_4))
+        learn.fit_one_cycle(self.model_training_and_evaluation_configurations.epochs_4, slice(5e-3/(2.6**4), self.model_training_and_evaluation_configurations.learning_rate_4))
learn.unfreeze()
-        learn.fit_one_cycle(self.model_training_and_evaluation_config.epochs_5, slice(1e-3/(2.6**4), self.model_training_and_evaluation_config.learning_rate_5))
+        learn.fit_one_cycle(self.model_training_and_evaluation_configurations.epochs_5, slice(1e-3/(2.6**4), self.model_training_and_evaluation_configurations.learning_rate_5))
classifier_metrics = learn.validate()
self.log_to_mlflow(classifier_metrics)
learn.save_encoder(f'text_classifier_learner')
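
Finally, a sketch of how the methods of ModelTrainingAndEvaluation chain together in a training stage under the renamed configurations attribute. The ConfigurationManager plumbing and both import paths are assumptions for illustration; the method names and call order come from the class above.

# Hypothetical training stage; configuration plumbing is assumed, not part of this commit.
from swahiliNewsClassifier.components.model_training_and_evaluation import ModelTrainingAndEvaluation  # assumed module path
from swahiliNewsClassifier.config.configuration import ConfigurationManager                            # assumed module path

configuration_manager = ConfigurationManager()
model_training_and_evaluation_configurations = configuration_manager.get_model_training_and_evaluation_config()  # assumed getter

trainer = ModelTrainingAndEvaluation(
    model_training_and_evaluation_configurations=model_training_and_evaluation_configurations)

train = trainer.load_data()                                           # read the training CSV
df_trn, df_val, df_lm = trainer.prepare_data(train)                   # stratified split plus language-model corpus
dls_lm = trainer.create_dataloaders(df_lm)                            # language-model dataloaders
language_model = trainer.train_language_model(dls_lm)                 # fine-tune AWD_LSTM and save the encoder
dls_clf = trainer.create_text_classifier_dataloaders(df_trn, dls_lm)  # classifier dataloaders reusing the LM vocab
trainer.train_text_classifier(dls_clf)                                # gradual unfreezing, validation, MLflow logging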
