Merge pull request #3 from FireHead90544/integration
feat: create inference chain & implement first working prototype
FireHead90544 authored Jul 27, 2024
2 parents a3dba00 + b2e4b15 commit d7b2fe5
Showing 6 changed files with 82 additions and 2 deletions.
4 changes: 4 additions & 0 deletions core/chains.py
@@ -0,0 +1,4 @@
from core.template import PROMPT_TEMPLATE
from core.llm import LLM

HOW_CLI_CHAIN = PROMPT_TEMPLATE | LLM
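
The `|` operator is LangChain's runnable composition: the prompt template formats the task into messages, which are then passed to the model. `core/template.py` is not part of this diff, so the sketch below only illustrates the kind of template the chain assumes (the template text is hypothetical):

# Hypothetical sketch of core/template.py (not in this diff), shown only to
# illustrate how a prompt composes with the LLM via the | operator.
from langchain_core.prompts import ChatPromptTemplate

PROMPT_TEMPLATE = ChatPromptTemplate.from_template(
    "Suggest shell commands that accomplish the following task: {task}"
)

# The composed chain is invoked with the template variables as a dict:
# HOW_CLI_CHAIN.invoke({"task": "show the 10 largest files in this directory"})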
9 changes: 9 additions & 0 deletions core/llm.py
@@ -0,0 +1,9 @@
from core.config import Config
from core.providers import LLM_PROVIDERS

llm_config = Config().values

LLM = LLM_PROVIDERS.get(llm_config.get("provider"))['provider'](
    model = LLM_PROVIDERS.get(llm_config.get("provider"))['model'],
    api_key = llm_config.get("api_key")
)
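
The provider entry is looked up twice while building the model; a behaviour-preserving sketch that resolves it once (a readability suggestion, not part of this commit):

# Equivalent sketch, not in this commit: resolve the provider entry once.
provider = LLM_PROVIDERS[llm_config["provider"]]
LLM = provider["provider"](
    model=provider["model"],
    api_key=llm_config["api_key"],
)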
2 changes: 1 addition & 1 deletion core/providers.py
@@ -4,5 +4,5 @@
os.environ["GRPC_VERBOSITY"] = "NONE"

LLM_PROVIDERS = {
"Gemini": { "provider": ChatGoogleGenerativeAI, "model": "gemini-1.5-pro" },
"Gemini": { "provider": ChatGoogleGenerativeAI, "model": "gemini-1.5-flash" },
}
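
The registry maps a provider name to a LangChain chat-model class plus a default model id, which `core/llm.py` instantiates based on the configured provider. A hypothetical second entry, shown only to illustrate the shape (not part of this commit):

# Hypothetical extension of the registry, for illustration only.
from langchain_openai import ChatOpenAI

LLM_PROVIDERS = {
    "Gemini": { "provider": ChatGoogleGenerativeAI, "model": "gemini-1.5-flash" },
    "OpenAI": { "provider": ChatOpenAI, "model": "gpt-4o-mini" },
}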
36 changes: 36 additions & 0 deletions formatting.py
@@ -0,0 +1,36 @@
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from rich.table import Table
from rich import box


def display_result(task: str, result: dict):
    console = Console()

    task_panel = Panel(
        Text(task, style="bold magenta"),
        title="Task",
        border_style="cyan",
        expand=False,
    )
    console.print(task_panel)

    if result["status"] != "success":
        console.print(f"Status: [bold red]{result['status']}[/bold red]")
        return

    command_table = Table(box=box.ROUNDED, expand=True, show_header=False)
    command_table.add_column("Commands", style="green")
    for command in result["commands"]:
        command_table.add_row(command)

    confidence_panel = Panel(
        f"{result['confidence']:.2%}",
        title="Confidence Score",
        border_style="yellow",
        expand=False,
    )

    console.print(Panel(command_table, title="Commands", border_style="green"))
    console.print(confidence_panel)
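
A quick usage sketch with a hand-built result dict in the shape that `infer.get_result` returns (the task and values are illustrative):

# Illustrative call; the task and result values are made up.
display_result(
    "find files modified in the last 24 hours",
    {
        "status": "success",
        "commands": ["find . -type f -mtime -1"],
        "confidence": 0.92,
    },
)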
5 changes: 4 additions & 1 deletion how.py
@@ -2,6 +2,7 @@
from typing_extensions import Annotated
from rich.prompt import Prompt
from core.config import Config
+from formatting import display_result

app = typer.Typer(
name="how",
@@ -21,7 +22,9 @@ def to(
        typer.secho("Please setup the configuration first using `how setup`", fg="red", bold=True)
        raise typer.Abort()

-    print(f"Send {task} to LLM. Test Finished.")
+    from infer import get_result
+    result = get_result(task)
+    display_result(task, result)


@app.command()
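With the chain wired in, the `to` command can be exercised end to end. A minimal smoke-test sketch using Typer's test runner; it assumes the task is a positional argument and that `how setup` has already been run:

# Illustrative smoke test, not part of this commit; it drives the real chain,
# so a configured provider and API key are required.
from typer.testing import CliRunner
from how import app

runner = CliRunner()
result = runner.invoke(app, ["to", "list all listening ports"])
print(result.output)
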
28 changes: 28 additions & 0 deletions infer.py
@@ -0,0 +1,28 @@
import warnings
from core.chains import HOW_CLI_CHAIN
from core.parser import PARSER

def get_result(task: str) -> dict[str, str | list | float]:
    """Invokes the chain with the given task and returns the result.
    Args:
        task (str): The task to perform.
    Returns:
        dict: The result of the chain.
    """
    tries = 3
    parsed = {"status": "error", "commands": [], "confidence": 0.0}

    while tries > 0:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            res = HOW_CLI_CHAIN.invoke({ "task": task })

        try:
            parsed = PARSER.invoke(res)
            break
        except Exception as e:
            tries -= 1

    return parsed
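
`core/parser.py` is outside this diff; the retry loop only assumes that `PARSER.invoke` returns a dict with `status`, `commands`, and `confidence` keys, or raises when the model's reply cannot be parsed. A hypothetical sketch consistent with that behaviour:

# Hypothetical sketch of core/parser.py (not in this commit). A JSON output
# parser returns a plain dict, which matches how display_result indexes the
# result, and raises OutputParserException on malformed replies, which the
# retry loop above catches.
from langchain_core.output_parsers import JsonOutputParser

PARSER = JsonOutputParser()

# Example (illustrative): get_result("recursively delete all .pyc files")
# might yield {"status": "success",
#              "commands": ["find . -name '*.pyc' -delete"],
#              "confidence": 0.9}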
