Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: create inference chain & implement first working prototype #3

Merged
merged 6 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/chains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from core.template import PROMPT_TEMPLATE
from core.llm import LLM

HOW_CLI_CHAIN = PROMPT_TEMPLATE | LLM
9 changes: 9 additions & 0 deletions core/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from core.config import Config
from core.providers import LLM_PROVIDERS

llm_config = Config().values

LLM = LLM_PROVIDERS.get(llm_config.get("provider"))['provider'](
model = LLM_PROVIDERS.get(llm_config.get("provider"))['model'],
api_key = llm_config.get("api_key")
)
2 changes: 1 addition & 1 deletion core/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
os.environ["GRPC_VERBOSITY"] = "NONE"

LLM_PROVIDERS = {
"Gemini": { "provider": ChatGoogleGenerativeAI, "model": "gemini-1.5-pro" },
"Gemini": { "provider": ChatGoogleGenerativeAI, "model": "gemini-1.5-flash" },
}
36 changes: 36 additions & 0 deletions formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from rich.table import Table
from rich import box


def display_result(task: str, result: dict):
console = Console()

task_panel = Panel(
Text(task, style="bold magenta"),
title="Task",
border_style="cyan",
expand=False,
)
console.print(task_panel)

if result["status"] != "success":
console.print(f"Status: [bold red]{result['status']}[/bold red]")
return

command_table = Table(box=box.ROUNDED, expand=True, show_header=False)
command_table.add_column("Commands", style="green")
for command in result["commands"]:
command_table.add_row(command)

confidence_panel = Panel(
f"{result['confidence']:.2%}",
title="Confidence Score",
border_style="yellow",
expand=False,
)

console.print(Panel(command_table, title="Commands", border_style="green"))
console.print(confidence_panel)
5 changes: 4 additions & 1 deletion how.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing_extensions import Annotated
from rich.prompt import Prompt
from core.config import Config
from formatting import display_result

app = typer.Typer(
name="how",
Expand All @@ -21,7 +22,9 @@ def to(
typer.secho("Please setup the configuration first using `how setup`", fg="red", bold=True)
raise typer.Abort()

print(f"Send {task} to LLM. Test Finished.")
from infer import get_result
result = get_result(task)
display_result(task, result)


@app.command()
Expand Down
28 changes: 28 additions & 0 deletions infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import warnings
from core.chains import HOW_CLI_CHAIN
from core.parser import PARSER

def get_result(task: str) -> dict[str, str | list | float]:
"""Invokes the chain with the given task and returns the result.

Args:
task (str): The task to perform.

Returns:
dict: The result of the chain.
"""
tries = 3
parsed = {"status": "error", "commands": [], "confidence": 0.0}

while tries > 0:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
res = HOW_CLI_CHAIN.invoke({ "task": task })

try:
parsed = PARSER.invoke(res)
break
except Exception as e:
tries -= 1

return parsed