Skip to content

Commit

Permalink
Remove dataclass and write __init__ method.
Browse files Browse the repository at this point in the history
Only call add_message in Client classes.
  • Loading branch information
yaph committed Oct 14, 2024
1 parent 2f5cdb6 commit 4ac0d3b
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
3 changes: 0 additions & 3 deletions charla/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ def run(argv: argparse.Namespace) -> None:
if not (user_input := session.prompt()):
continue

client.add_message(role='user', text=user_input)
#output.append(f'{session.message}{user_input}\n')

# Handle OPEN command input and continue to next prompt.
if session.message == ui.t_open:
open_location = user_input.strip()
Expand Down
12 changes: 6 additions & 6 deletions charla/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, NamedTuple


Expand All @@ -8,12 +7,13 @@ class ModelInfo(NamedTuple):
context_length: int


@dataclass
class Client(ABC):
model: str
context: Any = field(default_factory=list)
message_history: list[dict] = field(default_factory=list)
system: str = ''
def __init__(self, model: str, system: str = ''):
self.model = model
self.system = system
# For chatting with memory and writing output.
self.context: Any = []
self.message_history: list[dict] = []

@abstractmethod
def generate(self, prompt: str):
Expand Down
1 change: 1 addition & 0 deletions charla/client/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,4 @@ def generate(self, prompt: str):

self.context.append(AssistantMessage(content=text))
self.add_message(role='assistant', text=text)

2 changes: 2 additions & 0 deletions charla/client/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def set_info(self):


def generate(self, prompt: str):
self.add_message(role='user', text=prompt)

response = self.client.generate(
model=self.model, prompt=prompt, context=self.context, stream=True, system=self.system)

Expand Down

0 comments on commit 4ac0d3b

Please sign in to comment.