diff --git a/src/osc_transformer_based_extractor/kpi_detection/cli_kpi_detection.py b/src/osc_transformer_based_extractor/kpi_detection/cli_kpi_detection.py index 8cae7d4..e9344d1 100644 --- a/src/osc_transformer_based_extractor/kpi_detection/cli_kpi_detection.py +++ b/src/osc_transformer_based_extractor/kpi_detection/cli_kpi_detection.py @@ -1,4 +1,5 @@ import typer +import os from .train_kpi_detection import ( train_kpi_detection, check_output_dir, @@ -43,6 +44,7 @@ def fine_tune_qna( output_dir: str = typer.Argument( ..., help="Directory to save the fine-tuned model." ), + export_model_name: str = typer.Argument(..., help="Name of the model to export."), save_steps: int = typer.Argument( ..., help="Number of steps between saving model checkpoints." ), @@ -59,10 +61,13 @@ def fine_tune_qna( batch_size=batch_size, learning_rate=learning_rate, output_dir=output_dir, + export_model_name=export_model_name, save_steps=save_steps, ) - - typer.echo(f"Model '{model_name}' trained and saved successfully at {output_dir}") + saved_model_path = os.path.join(output_dir, f"{export_model_name}") + typer.echo( + f"Model '{model_name}' is trained and saved successfully at {saved_model_path}" + ) @kpi_detection_app.command("inference") @@ -76,6 +81,7 @@ def inference_qna( model_path: str = typer.Argument( ..., help="Path to the pre-trained model directory OR name on huggingface." ), + batch_size: int = typer.Argument(16, help="The batch size for inference."), ): """Perform inference using a pre-trained model on a dataset of kpis and contexts, saving an output Excel file.""" try: @@ -86,6 +92,7 @@ def inference_qna( data_file_path=data_file_path, output_path=output_path, model_path=model_path, + batch_size=batch_size, ) typer.echo("Inference completed successfully!")