Skip to content

Commit

Permalink
Update AI: fix save-interaction for goldens; add together provider (#965
Browse files Browse the repository at this point in the history
)
  • Loading branch information
wwwillchen committed Sep 14, 2024
1 parent 595cb24 commit 329dcca
Show file tree
Hide file tree
Showing 15 changed files with 175 additions and 58 deletions.
12 changes: 12 additions & 0 deletions ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,15 @@ The common package contains the code that is shared between multiple Python bina
## Docbot

Docbot is a standalone app that creates a RAG-chatbot for the Mesop docs.

## Fine-tuning

### Together

> Pre-requisite: install together CLI: `pip install --upgrade together`
1. [Upload dataset](https://docs.together.ai/reference/files)

```sh
together files upload data/golden_datasets/udiff-2024-09-13_llama3.jsonl
```
5 changes: 5 additions & 0 deletions ai/src/ai/common/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def apply_patch(original_code: str, patch: str) -> ApplyPatchResult:


def apply_udiff(original_text: str, udiff: str) -> ApplyPatchResult:
# Remove the edit markers.
original_text = original_text.replace(EDIT_HERE_MARKER, "")
udiff = udiff.replace(EDIT_HERE_MARKER, "")
# Check if the udiff contains code blocks
if "```" in udiff:
udiff = extract_code_blocks(udiff)
Expand Down Expand Up @@ -86,6 +89,8 @@ def apply_udiff(original_text: str, udiff: str) -> ApplyPatchResult:
True,
"[AI-001] No changes were applied. Please try again.",
)
if original_text.endswith("\n"):
patched_code = patched_code.rstrip() + "\n"
return ApplyPatchResult(False, patched_code)


Expand Down
9 changes: 9 additions & 0 deletions ai/src/ai/common/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import os
import shutil
from datetime import datetime
from typing import Generic, Literal, TypeVar

from pydantic import BaseModel, field_validator, model_validator
Expand Down Expand Up @@ -55,6 +56,8 @@ def is_valid_line_number(self):
class BaseExample(BaseModel):
id: str
input: ExampleInput
created_at: datetime = None
updated_at: datetime = None

@field_validator("id", mode="after")
@classmethod
Expand Down Expand Up @@ -162,6 +165,12 @@ def save(self, entity: T, overwrite: bool = False):
id = entity.id
dir_path = os.path.join(self.directory_path, id)

# Update timestamps
current_time = datetime.utcnow()
if entity.created_at is None:
entity.created_at = current_time
entity.updated_at = current_time

if not overwrite:
if os.path.exists(dir_path):
raise ValueError(
Expand Down
15 changes: 15 additions & 0 deletions ai/src/ai/common/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,20 @@ def execute(self, input: ExampleInput, seed: int | None = None) -> str:
return response.choices[0].message.content or ""


class TogetherExecutor(OpenaiExecutor):
def __init__(
self,
model_name: str,
prompt_fragments: list[PromptFragment],
temperature: float,
):
super().__init__(model_name, prompt_fragments, temperature)
self.client = OpenAI(
base_url="https://api.together.xyz/v1",
api_key=getenv("TOGETHER_API_KEY"),
)


class GeminiExecutor(ProviderExecutor):
def __init__(
self,
Expand Down Expand Up @@ -196,6 +210,7 @@ def _format_gemini_messages(
provider_executors: dict[str, type[ProviderExecutor]] = {
"openai": OpenaiExecutor,
"fireworks": FireworksExecutor,
"together": TogetherExecutor,
"gemini": GeminiExecutor,
}

Expand Down
13 changes: 11 additions & 2 deletions ai/src/ai/console/pages/evals_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ def evals_page():
grid_template_columns="1fr 1fr 96px 96px 96px",
gap=16,
align_items="center",
height="100%",
overflow_y="auto",
padding=me.Padding(bottom=12),
)
):
# Header
Expand All @@ -36,6 +35,16 @@ def evals_page():
me.text("State", style=me.Style(font_weight="bold"))
me.text("Score", style=me.Style(font_weight="bold"))
me.text("Examples Succeeded / Run", style=me.Style(font_weight="bold"))
with me.box(
style=me.Style(
display="grid",
grid_template_columns="1fr 1fr 96px 96px 96px",
gap=16,
align_items="center",
height="100%",
overflow_y="auto",
)
):
# Body
for eval in evals:
me.button(
Expand Down
25 changes: 22 additions & 3 deletions ai/src/ai/console/pages/expected_examples_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,29 @@ def expected_examples_page():
with me.box(
style=me.Style(
display="grid",
grid_template_columns="repeat(4, 1fr)",
grid_template_columns="240px 1fr 140px 140px 48px 48px",
gap=16,
align_items="center",
overflow_y="auto",
height="100%",
padding=me.Padding(bottom=12),
)
):
# Header
me.text("ID", style=me.Style(font_weight="bold"))
me.text("Prompt", style=me.Style(font_weight="bold"))
me.text("Created at", style=me.Style(font_weight="bold"))
me.text("Updated at", style=me.Style(font_weight="bold"))
me.text("Has input code", style=me.Style(font_weight="bold"))
me.text("Has line # target", style=me.Style(font_weight="bold"))
with me.box(
style=me.Style(
display="grid",
grid_template_columns="240px 1fr 140px 140px 48px 48px",
gap=16,
align_items="center",
overflow_y="auto",
height="100%",
)
):
# Body
for example in examples:
me.button(
Expand All @@ -47,5 +58,13 @@ def expected_examples_page():
style=me.Style(font_size=16),
)
me.text(example.input.prompt)
if example.created_at:
me.text(example.created_at.strftime("%Y-%m-%d"))
else:
me.text("")
if example.updated_at:
me.text(example.updated_at.strftime("%Y-%m-%d"))
else:
me.text("")
me.text(str(bool(example.input.input_code)))
me.text(str(bool(example.input.line_number_target)))
25 changes: 22 additions & 3 deletions ai/src/ai/console/pages/golden_examples_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,29 @@ def golden_examples_page():
with me.box(
style=me.Style(
display="grid",
grid_template_columns="200px 1fr 48px 48px",
grid_template_columns="200px 1fr 140px 140px 48px 48px",
gap=12,
align_items="center",
overflow_y="auto",
height="100%",
padding=me.Padding(bottom=12),
)
):
# Header
me.text("ID", style=me.Style(font_weight="bold"))
me.text("Prompt", style=me.Style(font_weight="bold"))
me.text("Created at", style=me.Style(font_weight="bold"))
me.text("Updated at", style=me.Style(font_weight="bold"))
me.text("Has input code", style=me.Style(font_weight="bold"))
me.text("Has line # target", style=me.Style(font_weight="bold"))
with me.box(
style=me.Style(
display="grid",
grid_template_columns="200px 1fr 140px 140px 48px 48px",
gap=12,
align_items="center",
overflow_y="auto",
height="100%",
)
):
# Body
for example in examples:
me.button(
Expand All @@ -57,5 +68,13 @@ def golden_examples_page():
style=me.Style(font_size=16),
)
me.text(example.input.prompt)
if example.created_at:
me.text(example.created_at.strftime("%Y-%m-%d"))
else:
me.text("")
if example.updated_at:
me.text(example.updated_at.strftime("%Y-%m-%d"))
else:
me.text("")
me.text(str(bool(example.input.input_code)))
me.text(str(bool(example.input.line_number_target)))
14 changes: 7 additions & 7 deletions ai/src/ai/console/pages/models_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ def on_load(e: me.LoadEvent):
@me.page(title="Mesop AI Console - Models", path="/models", on_load=on_load)
def models_page():
with page_scaffold(current_path="/models", title="Models"):
with me.box(style=me.Style(padding=me.Padding(bottom=16))):
me.button(
"Add Model",
on_click=lambda e: me.navigate("/models/add"),
type="flat",
color="accent",
)
models = store.get_all()
with me.box(
style=me.Style(
Expand All @@ -35,10 +42,3 @@ def models_page():
)
me.text(model.name)
me.text(model.provider)
with me.box(style=me.Style(padding=me.Padding(top=32))):
me.button(
"Add Model",
on_click=lambda e: me.navigate("/models/add"),
type="flat",
color="accent",
)
24 changes: 17 additions & 7 deletions ai/src/ai/console/pages/producers_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ def on_load(e: me.LoadEvent):
def producers_page():
with page_scaffold(current_path="/producers", title="Producers"):
producers = store.get_all()
with me.box(style=me.Style(padding=me.Padding(bottom=16))):
me.button(
"Add Producer",
on_click=lambda e: me.navigate("/producers/add"),
type="flat",
color="accent",
)

with me.box(
style=me.Style(
display="grid",
Expand All @@ -28,6 +36,15 @@ def producers_page():
me.text("Prompt Context", style=me.Style(font_weight="bold"))
me.text("Output Format", style=me.Style(font_weight="bold"))
me.text("Temp", style=me.Style(font_weight="bold"))
with me.box(
style=me.Style(
display="grid",
grid_template_columns="repeat(3, 1fr) 64px 64px",
gap=16,
align_items="center",
overflow_y="auto",
)
):
# Body
for producer in producers:
me.button(
Expand Down Expand Up @@ -57,10 +74,3 @@ def producers_page():
)
me.text(producer.output_format)
me.text(str(producer.temperature))
with me.box(style=me.Style(padding=me.Padding(top=32))):
me.button(
"Add Producer",
on_click=lambda e: me.navigate("/producers/add"),
type="flat",
color="accent",
)
14 changes: 7 additions & 7 deletions ai/src/ai/console/pages/prompt_contexts_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ def on_load(e: me.LoadEvent):
def prompt_contexts_page():
with page_scaffold(current_path="/prompt-contexts", title="Prompt Contexts"):
prompt_contexts = prompt_context_store.get_all()
with me.box(style=me.Style(padding=me.Padding(bottom=16))):
me.button(
"Add Prompt Context",
on_click=lambda e: me.navigate("/prompt-contexts/add"),
type="flat",
color="accent",
)
with me.box(
style=me.Style(
display="grid",
Expand Down Expand Up @@ -46,10 +53,3 @@ def prompt_contexts_page():
key=fragment_id,
style=me.Style(font_size=16),
)
with me.box(style=me.Style(padding=me.Padding(top=32))):
me.button(
"Add Prompt Context",
on_click=lambda e: me.navigate("/prompt-contexts/add"),
type="flat",
color="accent",
)
14 changes: 7 additions & 7 deletions ai/src/ai/console/pages/prompt_fragments_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ def prompt_fragments_page():
with page_scaffold(
current_path="/prompt-fragments", title="Prompt Fragments"
):
with me.box(style=me.Style(padding=me.Padding(bottom=16))):
me.button(
"Add Prompt Fragment",
on_click=lambda e: me.navigate("/prompt-fragments/add"),
type="flat",
color="accent",
)
prompt_fragments = prompt_fragment_store.get_all()
with me.box(
style=me.Style(
Expand Down Expand Up @@ -48,10 +55,3 @@ def prompt_fragments_page():

me.text(prompt_fragment.role)
me.text(str(prompt_fragment.chain_of_thought))
with me.box(style=me.Style(padding=me.Padding(top=32))):
me.button(
"Add Prompt Fragment",
on_click=lambda e: me.navigate("/prompt-fragments/add"),
type="flat",
color="accent",
)
19 changes: 19 additions & 0 deletions ai/src/ai/offline_common/golden_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,23 @@ def create_golden_dataset(*, producer_id: str, dataset_name: str) -> str:
)
f.write(json.dumps({"messages": messages}) + "\n")
print("created golden dataset", golden_dataset_path)
convert_openai_format_to_llama3_format(
golden_dataset_path,
golden_dataset_path.replace(".jsonl", "_llama3.jsonl"),
)
return golden_dataset_path


def convert_openai_format_to_llama3_format(input_file: str, output_file: str):
with open(input_file) as in_f, open(output_file, "w") as out_f:
for line in in_f:
data = json.loads(line.strip())
output = "<|begin_of_text|>"

for message in data["messages"]:
role, content = message["role"], message["content"]
output += (
f"<|start_header_id|>{role}<|end_header_id|> {content}<|eot_id|>"
)

out_f.write(json.dumps({"text": output}) + "\n")
Loading

0 comments on commit 329dcca

Please sign in to comment.