Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/more metrics #73

Merged
merged 11 commits into from
Jan 29, 2024
129 changes: 71 additions & 58 deletions src/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,66 +2,79 @@
from main import main, get_runtime_kwargs
from dotenv import load_dotenv
from core.analysis import CodeRetryLimitExceeded
from datetime import datetime

load_dotenv()
together_token = os.getenv("TOGETHER_API_KEY")
openai_token = os.getenv("OPENAI_API_KEY")

prompting_techniques = ["zero-shot", "few-shot"]
assistants = [
"llama-chat",
"openai",
"mixtral-8x7b",
]
if __name__ == "__main__":
load_dotenv()
together_token = os.getenv("TOGETHER_API_KEY")
openai_token = os.getenv("OPENAI_API_KEY")

analysis_message_limit = 8
runtime = "jupyter-notebook"
report_params_no = {
"zero-shot_llama-chat": 1,
"zero-shot_openai": 2,
"zero-shot_mixtral-8x7b": 3,
"few-shot_llama-chat": 4,
"few-shot_openai": 5,
"few-shot_mixtral-8x7b": 6,
}
prompting_techniques = [
"zero-shot",
"few-shot",
]
assistants = [
"openai",
"mixtral-8x7b",
"llama-chat",
]

dataset_path = "data/<dataset>.csv"
dataset_name = dataset_path.split("/")[-1].split(".")[0]
analysis_message_limit = 8
runtime = "jupyter-notebook"
report_params_no = {
"zero-shot_llama-chat": 1,
"zero-shot_openai": 2,
"zero-shot_mixtral-8x7b": 3,
"few-shot_llama-chat": 4,
"few-shot_openai": 5,
"few-shot_mixtral-8x7b": 6,
}

for assistant in assistants:
for prompting_technique in prompting_techniques:
kwargs = get_runtime_kwargs(
runtime,
prompting_technique,
assistant,
)
kwargs["analysis_assistant_kwargs"]["api_key"] = (
openai_token if assistant == "openai" else together_token
)
kwargs["code_assistant_kwargs"]["api_key"] = (
openai_token if assistant == "openai" else together_token
)
report_no = report_params_no[f"{prompting_technique}_{assistant}"]
try:
output_pdf_path, error_count, code_messages_missing_snippets = main(
dataset_name,
dataset_path,
runtime,
assistant,
assistant,
prompting_technique,
analysis_message_limit=analysis_message_limit,
output_pdf_path=f"../{dataset_name}_{report_no}.pdf",
**kwargs,
)
except CodeRetryLimitExceeded as e:
print(e)
print(output_pdf_path)
print("Error Count:", error_count)
print("Code Messages Missing Snippets:", code_messages_missing_snippets)
# create text file with error count and code messages missing snippets
with open(f"data/{dataset_name}_{report_no}.txt", "w") as f:
f.write(f"Error Count: {error_count}\n")
f.write(
f"Code Messages Missing Snippets: {code_messages_missing_snippets}\n"
)
dataset_path = "data/wine-quality.csv"
dataset_name = dataset_path.split("/")[-1].split(".")[0]
ITERATIONS = 100

for iteration in range(1, ITERATIONS):
print(f"Iteration: {iteration}")
for assistant in assistants:
for prompting_technique in prompting_techniques:
kwargs = get_runtime_kwargs(
runtime,
prompting_technique,
assistant,
)
kwargs["analysis_assistant_kwargs"]["api_key"] = (
openai_token if assistant == "openai" else together_token
)
kwargs["code_assistant_kwargs"]["api_key"] = (
openai_token if assistant == "openai" else together_token
)
report_no = report_params_no[f"{prompting_technique}_{assistant}"]
try:
output_pdf_path, error_count, code_messages_missing_snippets = main(
dataset_name,
dataset_path,
runtime,
assistant,
assistant,
prompting_technique,
analysis_message_limit=analysis_message_limit,
output_pdf_path=f"../{dataset_name}_{report_no}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.pdf",
**kwargs,
)
except Exception as e:
print(e)
continue
print(output_pdf_path)
print("Error Count:", error_count)
print("Code Messages Missing Snippets:", code_messages_missing_snippets)
# create text file with error count and code messages missing snippets
with open(
f"data/{dataset_name}_{report_no}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.txt",
"w",
) as f:
f.write(f"Error Count: {error_count}\n")
f.write(
f"Code Messages Missing Snippets: {code_messages_missing_snippets}\n"
)
74 changes: 70 additions & 4 deletions src/core/analysis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import csv
from datetime import datetime
from typing import Union

Expand All @@ -20,6 +21,45 @@ def __init__(self, message="Exceeded code retry limit"):
super().__init__(self.message)


def save_to_csv(data: list, filename: str = "results.csv") -> None:
    """
    Append one row of analysis results to a semicolon-delimited CSV file.

    Parameters:
    - data (list): Row values, in the same order as the header written below:
      code_assistant_type, prompt_type, dataset_name, report_path, error_count,
      code_messages_missing_snippets, msg_count, analysis_message_limit, exception.
    - filename (str): Path of the CSV file to append to. Defaults to
      "results.csv" (previous hard-coded behavior); the file is created with a
      header row if it does not exist yet.
    """
    # EAFP existence probe: if the file can be opened for reading it already
    # has a header; otherwise we must write one before the first data row.
    try:
        with open(filename, "r"):
            file_exists = True
    except FileNotFoundError:
        file_exists = False

    # Open in append mode ('a'); the file is created if it doesn't exist.
    with open(filename, "a", newline="") as f:
        writer = csv.writer(f, delimiter=";")
        if not file_exists:
            # First write to this file: emit the header row.
            writer.writerow(
                [
                    "code_assistant_type",
                    "prompt_type",
                    "dataset_name",
                    "report_path",
                    "error_count",
                    "code_messages_missing_snippets",
                    "msg_count",
                    "analysis_message_limit",
                    "exception",
                ]
            )
        # Derived metrics are intentionally not stored; they can be recomputed:
        # analyst_count = (msg_count + 1) // 2
        # code_count = msg_count // 2 + error_count
        # total_assistant_requests = msg_count + error_count
        writer.writerow(data)


def analyze(
dataset_path: str,
runtime: IRuntime,
Expand Down Expand Up @@ -66,16 +106,16 @@ def analyze(

conv = Conversation(runtime, code_assistant, analysis_assistant, prompt, conv_list)
error_count = 0
msg_count = 0
try:
while analysis_message_limit is None or analysis_message_limit > 0:
if analysis_message_limit is not None:
analysis_message_limit -= 1
elif "q" in input(
while analysis_message_limit is None or msg_count < analysis_message_limit:
if analysis_message_limit is None and "q" in input(
f"{Colors.BOLD_BLACK.value}Press 'q' to quit or any other key to continue: {Colors.END.value}"
):
break

msg = conv.perform_next_step()
msg_count += 1
code_retry_limit = 3
while conv.last_msg_contains_execution_errors():
error_count += 1
Expand All @@ -99,6 +139,19 @@ def analyze(
report_path = runtime.generate_report("reports", report_name)
except Exception as ex:
report_path = None
save_to_csv(
[
code_assistant.__class__.__name__,
prompt.__class__.__name__,
dataset_file_name,
report_path,
error_count,
conv.code_messages_missing_snippets,
msg_count,
analysis_message_limit,
e,
]
)

raise e

Expand All @@ -113,6 +166,19 @@ def analyze(
print(
f"{Colors.BOLD_BLUE.value}Code Assistant messages missing code snippets: {conv.code_messages_missing_snippets}{Colors.END.value}"
)
save_to_csv(
[
code_assistant.__class__.__name__,
prompt.__class__.__name__,
dataset_file_name,
report_path,
error_count,
conv.code_messages_missing_snippets,
msg_count,
analysis_message_limit,
None,
]
)

conv_json = conv.get_conversation_json()
conv_path = f"conversations/conversation-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.json"
Expand Down
6 changes: 6 additions & 0 deletions src/core/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,17 @@ def get_conversation(self) -> List[Message]:
def _add_to_conversation(
    self, role: ConversationRolesInternalEnum, content: str
) -> None:
    """Append a new message with the given role and content to the history."""
    message = Message(role=role, content=content)
    self._conversation.append(message)

def _get_last_message(self) -> Message:
    """Return the most recently appended message of the conversation."""
    newest_index = len(self._conversation) - 1
    return self._conversation[newest_index]

def _send_message_analysis(self) -> None:
"""
Generates output from the analysis assistant and adds it to the conversation history.
"""
analysis_conv = self._prompt.generate_conversation_context(
self._conversation, ConversationRolesInternalEnum.ANALYSIS, LLMType.GPT4
)
Expand All @@ -63,6 +68,7 @@ def _send_message_analysis(self) -> None:
self._runtime.add_description(analysis_response)

def _execute_python_snippet(self, code: str) -> int:
    """Add the python snippet to the runtime, run it, and return its cell index."""
    idx = self._runtime.add_code(code)
    self._runtime.execute_cell(idx)
    return idx
Expand Down
16 changes: 16 additions & 0 deletions src/llm_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@ def generate_together_completion(
top_k: int = 50,
top_p: float = 0.7,
) -> dict:
"""
Generate a completion using the Together API.

Parameters:
- prompt (str): The prompt to generate the completion from.
- model (str): The model to use for the completion.
- max_tokens (int): The maximum number of tokens to generate.
- temperature (float): The temperature parameter for the LLM.
- top_k (int): The top-k parameter for the LLM.
- top_p (float): The top-p parameter for the LLM.

Returns:
dict: The generated completion from the LLM.
"""
output = together.Complete.create(
prompt=prompt,
model=model,
Expand All @@ -23,6 +37,7 @@ def generate_together_completion(


def get_together_text(output: dict) -> str:
"""Get text response from the Together API response."""
try:
return output["output"]["choices"][0]["text"]
except KeyError:
Expand All @@ -31,6 +46,7 @@ def get_together_text(output: dict) -> str:


def conversation_prompt_to_instruct(conversation: list) -> str:
"""Convert a conversation prompt to an instruct prompt."""
prompt = ""
for number in range(len(conversation)):
if conversation[number].role == "system":
Expand Down
10 changes: 9 additions & 1 deletion src/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,30 @@


class ConversationRolesEnum(str, Enum):
    """Chat-message roles in the wire format used by the external LLM APIs."""

    # Standard chat-completion roles; values are the literal strings the APIs expect.
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    FUNCTION = "function"


class ConversationRolesInternalEnum(str, Enum):
    """Conversation roles used internally to tag which assistant produced a message."""

    # Assistant that generates code snippets.
    CODE = "code_generation"
    # Assistant that suggests analyses and interprets their results.
    ANALYSIS = "analysis_suggestion_interpretation"


class Message(BaseModel):
    """Messages are the basic building blocks of a conversation."""

    # Who produced the message: an external API role or an internal pipeline role.
    role: ConversationRolesEnum | ConversationRolesInternalEnum
    # Raw text content of the message.
    content: str


class LLMType(Enum):
    """The type of LLM to use."""

    GPT4 = "gpt4"
    # NOTE: this member was duplicated in the source (`LLAMA2` assigned twice);
    # a repeated name in an Enum class body raises TypeError at class creation,
    # so the duplicate assignment is removed here.
    LLAMA2 = "llama2"
Loading
Loading