Skip to content

Commit

Permalink
Update cli_kpi_detection.py
Browse files Browse the repository at this point in the history
Signed-off-by: tanishq-ids <166009643+tanishq-ids@users.noreply.github.com>
  • Loading branch information
tanishq-ids authored Nov 6, 2024
1 parent 3c740f4 commit e2ea30c
Showing 1 changed file with 9 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typer
import os
from .train_kpi_detection import (
train_kpi_detection,
check_output_dir,
Expand Down Expand Up @@ -43,6 +44,7 @@ def fine_tune_qna(
output_dir: str = typer.Argument(
..., help="Directory to save the fine-tuned model."
),
export_model_name: str = typer.Argument(..., help="Name of the model to export."),
save_steps: int = typer.Argument(
..., help="Number of steps between saving model checkpoints."
),
Expand All @@ -59,10 +61,13 @@ def fine_tune_qna(
batch_size=batch_size,
learning_rate=learning_rate,
output_dir=output_dir,
export_model_name=export_model_name,
save_steps=save_steps,
)

typer.echo(f"Model '{model_name}' trained and saved successfully at {output_dir}")
saved_model_path = os.path.join(output_dir, f"{export_model_name}")
typer.echo(
f"Model '{model_name}' is trained and saved successfully at {saved_model_path}"
)


@kpi_detection_app.command("inference")
Expand All @@ -76,6 +81,7 @@ def inference_qna(
model_path: str = typer.Argument(
..., help="Path to the pre-trained model directory OR name on huggingface."
),
batch_size: int = typer.Argument(16, help="The batch size for inference."),
):
"""Perform inference using a pre-trained model on a dataset of kpis and contexts, saving an output Excel file."""
try:
Expand All @@ -86,6 +92,7 @@ def inference_qna(
data_file_path=data_file_path,
output_path=output_path,
model_path=model_path,
batch_size=batch_size,
)

typer.echo("Inference completed successfully!")
Expand Down

0 comments on commit e2ea30c

Please sign in to comment.