Merge pull request #73 from jswistak/feature/more-metrics
Feature/more metrics
jswistak authored Jan 29, 2024
2 parents 8b5969d + dc196cd commit 54ab9f7
Showing 9 changed files with 1,982 additions and 69 deletions.
129 changes: 71 additions & 58 deletions src/batch.py
@@ -2,66 +2,79 @@
from main import main, get_runtime_kwargs
from dotenv import load_dotenv
from core.analysis import CodeRetryLimitExceeded
from datetime import datetime

load_dotenv()
together_token = os.getenv("TOGETHER_API_KEY")
openai_token = os.getenv("OPENAI_API_KEY")

prompting_techniques = ["zero-shot", "few-shot"]
assistants = [
"llama-chat",
"openai",
"mixtral-8x7b",
]
if __name__ == "__main__":
load_dotenv()
together_token = os.getenv("TOGETHER_API_KEY")
openai_token = os.getenv("OPENAI_API_KEY")

analysis_message_limit = 8
runtime = "jupyter-notebook"
report_params_no = {
"zero-shot_llama-chat": 1,
"zero-shot_openai": 2,
"zero-shot_mixtral-8x7b": 3,
"few-shot_llama-chat": 4,
"few-shot_openai": 5,
"few-shot_mixtral-8x7b": 6,
}
prompting_techniques = [
"zero-shot",
"few-shot",
]
assistants = [
"openai",
"mixtral-8x7b",
"llama-chat",
]

dataset_path = "data/<dataset>.csv"
dataset_name = dataset_path.split("/")[-1].split(".")[0]
analysis_message_limit = 8
runtime = "jupyter-notebook"
report_params_no = {
"zero-shot_llama-chat": 1,
"zero-shot_openai": 2,
"zero-shot_mixtral-8x7b": 3,
"few-shot_llama-chat": 4,
"few-shot_openai": 5,
"few-shot_mixtral-8x7b": 6,
}

for assistant in assistants:
for prompting_technique in prompting_techniques:
kwargs = get_runtime_kwargs(
runtime,
prompting_technique,
assistant,
)
kwargs["analysis_assistant_kwargs"]["api_key"] = (
openai_token if assistant == "openai" else together_token
)
kwargs["code_assistant_kwargs"]["api_key"] = (
openai_token if assistant == "openai" else together_token
)
report_no = report_params_no[f"{prompting_technique}_{assistant}"]
try:
output_pdf_path, error_count, code_messages_missing_snippets = main(
dataset_name,
dataset_path,
runtime,
assistant,
assistant,
prompting_technique,
analysis_message_limit=analysis_message_limit,
output_pdf_path=f"../{dataset_name}_{report_no}.pdf",
**kwargs,
)
except CodeRetryLimitExceeded as e:
print(e)
print(output_pdf_path)
print("Error Count:", error_count)
print("Code Messages Missing Snippets:", code_messages_missing_snippets)
# create text file with error count and code messages missing snippets
with open(f"data/{dataset_name}_{report_no}.txt", "w") as f:
f.write(f"Error Count: {error_count}\n")
f.write(
f"Code Messages Missing Snippets: {code_messages_missing_snippets}\n"
)
dataset_path = "data/wine-quality.csv"
dataset_name = dataset_path.split("/")[-1].split(".")[0]
ITERATIONS = 100

for iteration in range(1, ITERATIONS):
print(f"Iteration: {iteration}")
for assistant in assistants:
for prompting_technique in prompting_techniques:
kwargs = get_runtime_kwargs(
runtime,
prompting_technique,
assistant,
)
kwargs["analysis_assistant_kwargs"]["api_key"] = (
openai_token if assistant == "openai" else together_token
)
kwargs["code_assistant_kwargs"]["api_key"] = (
openai_token if assistant == "openai" else together_token
)
report_no = report_params_no[f"{prompting_technique}_{assistant}"]
try:
output_pdf_path, error_count, code_messages_missing_snippets = main(
dataset_name,
dataset_path,
runtime,
assistant,
assistant,
prompting_technique,
analysis_message_limit=analysis_message_limit,
output_pdf_path=f"../{dataset_name}_{report_no}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.pdf",
**kwargs,
)
except Exception as e:
print(e)
continue
print(output_pdf_path)
print("Error Count:", error_count)
print("Code Messages Missing Snippets:", code_messages_missing_snippets)
# create text file with error count and code messages missing snippets
with open(
f"data/{dataset_name}_{report_no}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.txt",
"w",
) as f:
f.write(f"Error Count: {error_count}\n")
f.write(
f"Code Messages Missing Snippets: {code_messages_missing_snippets}\n"
)
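
A minimal sketch, not part of the commit, of the timestamped output naming that replaces the fixed per-report filenames above; the dataset name and report number below are hypothetical placeholders.

from datetime import datetime

# Hypothetical values; report_no would come from the report_params_no mapping above.
dataset_name, report_no = "wine-quality", 3
timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
output_pdf_path = f"../{dataset_name}_{report_no}_{timestamp}.pdf"
# e.g. ../wine-quality_3_2024-01-29-13-45-07.pdf -- each of the 100 iterations gets a
# distinct path, so later runs no longer overwrite earlier reports.
print(output_pdf_path)
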
74 changes: 70 additions & 4 deletions src/core/analysis.py
@@ -1,4 +1,5 @@
import json
import csv
from datetime import datetime
from typing import Union

@@ -20,6 +21,45 @@ def __init__(self, message="Exceeded code retry limit"):
super().__init__(self.message)


def save_to_csv(data: list) -> None:
"""
Save the results of the analysis to a CSV file.
"""
filename = "results.csv"
# Check if file exists
file_exists = False
try:
with open(filename, "r") as f:
file_exists = True
except FileNotFoundError:
file_exists = False

# Open the file in append mode ('a'). If the file doesn't exist, it will be created.
with open(filename, "a", newline="") as f:
writer = csv.writer(f, delimiter=";")
if not file_exists:
# If file does not exist, write the header
writer.writerow(
[
"code_assistant_type",
"prompt_type",
"dataset_name",
"report_path",
"error_count",
"code_messages_missing_snippets",
"msg_count",
"analysis_message_limit",
"exception",
]
)
# Rest can be calculated from these formulas:
# analyst_count = (msg_count + 1) // 2
# code_count = msg_count // 2 + error_count
# total_assistant_requests = msg_count + error_count

writer.writerow(data)


def analyze(
dataset_path: str,
runtime: IRuntime,
@@ -66,16 +106,16 @@ def analyze(

conv = Conversation(runtime, code_assistant, analysis_assistant, prompt, conv_list)
error_count = 0
msg_count = 0
try:
while analysis_message_limit is None or analysis_message_limit > 0:
if analysis_message_limit is not None:
analysis_message_limit -= 1
elif "q" in input(
while analysis_message_limit is None or msg_count < analysis_message_limit:
if analysis_message_limit is None and "q" in input(
f"{Colors.BOLD_BLACK.value}Press 'q' to quit or any other key to continue: {Colors.END.value}"
):
break

msg = conv.perform_next_step()
msg_count += 1
code_retry_limit = 3
while conv.last_msg_contains_execution_errors():
error_count += 1
@@ -99,6 +139,19 @@
report_path = runtime.generate_report("reports", report_name)
except Exception as ex:
report_path = None
save_to_csv(
[
code_assistant.__class__.__name__,
prompt.__class__.__name__,
dataset_file_name,
report_path,
error_count,
conv.code_messages_missing_snippets,
msg_count,
analysis_message_limit,
e,
]
)

raise e

@@ -113,6 +166,19 @@
print(
f"{Colors.BOLD_BLUE.value}Code Assistant messages missing code snippets: {conv.code_messages_missing_snippets}{Colors.END.value}"
)
save_to_csv(
[
code_assistant.__class__.__name__,
prompt.__class__.__name__,
dataset_file_name,
report_path,
error_count,
conv.code_messages_missing_snippets,
msg_count,
analysis_message_limit,
None,
]
)

conv_json = conv.get_conversation_json()
conv_path = f"conversations/conversation-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.json"
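
A minimal sketch, not part of the commit, of reading back the results.csv written by save_to_csv above and applying the derived-count formulas noted in its comments; the pandas approach and the aggregation shown are assumptions, not the repository's own tooling.

import pandas as pd

# Column names and the ";" delimiter match the header written by save_to_csv above.
results = pd.read_csv("results.csv", delimiter=";")

# Derived metrics, using the formulas from the comment in save_to_csv.
results["analyst_count"] = (results["msg_count"] + 1) // 2
results["code_count"] = results["msg_count"] // 2 + results["error_count"]
results["total_assistant_requests"] = results["msg_count"] + results["error_count"]

# Example aggregation: mean error count per assistant/prompt combination.
print(results.groupby(["code_assistant_type", "prompt_type"])["error_count"].mean())
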
6 changes: 6 additions & 0 deletions src/core/conversation.py
@@ -47,12 +47,17 @@ def get_conversation(self) -> List[Message]:
def _add_to_conversation(
self, role: ConversationRolesInternalEnum, content: str
) -> None:
"""Add message to the conversation."""
self._conversation.append(Message(role=role, content=content))

def _get_last_message(self) -> Message:
"""Get the last message in the conversation."""
return self._conversation[-1]

def _send_message_analysis(self) -> None:
"""
Generates output from the analysis assistant and adds it to the conversation history.
"""
analysis_conv = self._prompt.generate_conversation_context(
self._conversation, ConversationRolesInternalEnum.ANALYSIS, LLMType.GPT4
)
@@ -63,6 +68,7 @@ def _send_message_analysis(self) -> None:
self._runtime.add_description(analysis_response)

def _execute_python_snippet(self, code: str) -> int:
"""Execute python code snippet in the runtime."""
cell_idx = self._runtime.add_code(code)
self._runtime.execute_cell(cell_idx)
return cell_idx
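
A hypothetical stub, not part of the commit, illustrating only the runtime calls visible in the Conversation methods above (add_code, execute_cell, add_description); the real IRuntime interface is assumed to be larger than what is shown here.

class StubRuntime:
    """Hypothetical stand-in exposing only the calls shown in this diff."""

    def __init__(self):
        self.cells = []

    def add_code(self, code: str) -> int:
        # Mirrors _execute_python_snippet: store the snippet and return its cell index.
        self.cells.append(("code", code))
        return len(self.cells) - 1

    def execute_cell(self, cell_idx: int) -> None:
        print(f"executing cell {cell_idx}: {self.cells[cell_idx][1]!r}")

    def add_description(self, text: str) -> None:
        # Mirrors _send_message_analysis: analysis text is added alongside code cells.
        self.cells.append(("markdown", text))
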
16 changes: 16 additions & 0 deletions src/llm_api/utils.py
@@ -9,6 +9,20 @@ def generate_together_completion(
top_k: int = 50,
top_p: float = 0.7,
) -> dict:
"""
Generate a completion using the Together API.
Parameters:
- prompt (str): The prompt to generate the completion from.
- model (str): The model to use for the completion.
- max_tokens (int): The maximum number of tokens to generate.
- temperature (float): The temperature parameter for the LLM.
- top_k (int): The top-k parameter for the LLM.
- top_p (float): The top-p parameter for the LLM.
Returns:
dict: The generated completion from the LLM.
"""
output = together.Complete.create(
prompt=prompt,
model=model,
@@ -23,6 +37,7 @@


def get_together_text(output: dict) -> str:
"""Get text response from the Together API response."""
try:
return output["output"]["choices"][0]["text"]
except KeyError:
@@ -31,6 +46,7 @@


def conversation_prompt_to_instruct(conversation: list) -> str:
"""Convert a conversation prompt to an instruct prompt."""
prompt = ""
for number in range(len(conversation)):
if conversation[number].role == "system":
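
A hedged usage sketch, not part of the commit, of the two helpers documented above; the import path, model string, and prompt are placeholder assumptions, and the Together API key is assumed to be picked up from the environment as in batch.py.

from dotenv import load_dotenv

# Assumed import path relative to src/, matching the file shown above.
from llm_api.utils import generate_together_completion, get_together_text

load_dotenv()  # assumed to expose TOGETHER_API_KEY to the together client

# Placeholder model name -- substitute whichever Together-hosted model is configured.
output = generate_together_completion(
    prompt="Suggest one exploratory analysis step for a tabular wine-quality dataset.",
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    max_tokens=256,
    temperature=0.7,
)
print(get_together_text(output))
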
10 changes: 9 additions & 1 deletion src/models/models.py
@@ -4,22 +4,30 @@


class ConversationRolesEnum(str, Enum):
"""Conversation roles for the LLM API."""

SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
FUNCTION = "function"


class ConversationRolesInternalEnum(str, Enum):
"""Conversation roles for the internal usage."""

CODE = "code_generation"
ANALYSIS = "analysis_suggestion_interpretation"


class Message(BaseModel):
"""Messages are the basic building blocks of a conversation."""

role: ConversationRolesEnum | ConversationRolesInternalEnum
content: str


class LLMType(Enum):
"""The type of LLM to use."""

GPT4 = "gpt4"
LLAMA2 = "llama2"
LLAMA2 = "llama2"