Add a DatasetFormatter that can dump a DatasetSplit to JSONL
scosman committed Nov 23, 2024
1 parent 0c1832c commit eec9b17
Showing 3 changed files with 385 additions and 3 deletions.
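In brief: the new DatasetFormatter takes a DatasetSplit plus a system message and writes one JSONL line per task run in a chosen split. A minimal usage sketch (the `dataset` variable and the "train" split are illustrative assumptions, not part of this commit):

    from kiln_ai.adapters.fine_tune.dataset_formatter import (
        DatasetFormat,
        DatasetFormatter,
    )

    # `dataset` is assumed to be an existing DatasetSplit with a parent task
    formatter = DatasetFormatter(dataset, system_message="You are a helpful assistant.")

    # With no explicit path, the file lands in the system temp directory
    path = formatter.dump_to_file("train", DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL)
    print(path)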
6 changes: 3 additions & 3 deletions checks.sh
@@ -39,8 +39,8 @@ else
 	echo "Skipping Web UI: no files changed"
 fi
 
-echo "${headerStart}Running Python Tests${headerEnd}"
-python3 -m pytest -q .
-
 echo "${headerStart}Checking Types${headerEnd}"
 pyright .
+
+echo "${headerStart}Running Python Tests${headerEnd}"
+python3 -m pytest -q .
128 changes: 128 additions & 0 deletions libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py
@@ -0,0 +1,128 @@
import json
import tempfile
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Protocol

from kiln_ai.datamodel import DatasetSplit, TaskRun


class DatasetFormat(str, Enum):
    """Format types for dataset generation"""

    CHAT_MESSAGE_RESPONSE_JSONL = "chat_message_response_jsonl"
    CHAT_MESSAGE_TOOLCALL_JSONL = "chat_message_toolcall_jsonl"


class FormatGenerator(Protocol):
    """Protocol for format generators"""

    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...


def generate_chat_message_response(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with plaintext response"""
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "content": task_run.output.output},
        ]
    }


def generate_chat_message_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with tool call response"""
    try:
        arguments = json.loads(task_run.output.output)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in task run output: {e}") from e

    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "arguments": arguments,
                        },
                    }
                ],
            },
        ]
    }


FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
    DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL: generate_chat_message_response,
    DatasetFormat.CHAT_MESSAGE_TOOLCALL_JSONL: generate_chat_message_toolcall,
}


class DatasetFormatter:
    """Handles formatting of datasets into various output formats"""

    def __init__(self, dataset: DatasetSplit, system_message: str):
        self.dataset = dataset
        self.system_message = system_message

        task = dataset.parent_task()
        if task is None:
            raise ValueError("Dataset has no parent task")
        self.task = task

    def dump_to_file(
        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
    ) -> Path:
        """
        Format the dataset into the specified format.

        Args:
            split_name: Name of the split to dump
            format_type: Format to generate the dataset in
            path: Optional path to write to. If None, writes to temp directory

        Returns:
            Path to the generated file
        """
        if format_type not in FORMAT_GENERATORS:
            raise ValueError(f"Unsupported format: {format_type}")
        if split_name not in self.dataset.split_contents:
            raise ValueError(f"Split {split_name} not found in dataset")

        generator = FORMAT_GENERATORS[format_type]

        # Write to a temp file if no path is provided
        output_path = (
            path
            or Path(tempfile.gettempdir())
            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
        )

        runs = self.task.runs()
        runs_by_id = {run.id: run for run in runs}

        # Generate formatted output with UTF-8 encoding
        with open(output_path, "w", encoding="utf-8") as f:
            for run_id in self.dataset.split_contents[split_name]:
                # Look up with .get() so a missing run triggers the
                # ValueError below instead of an unhandled KeyError
                task_run = runs_by_id.get(run_id)
                if task_run is None:
                    raise ValueError(
                        f"Task run {run_id} not found. This is required by this dataset."
                    )

                example = generator(task_run, self.system_message)
                f.write(json.dumps(example) + "\n")

        return output_path
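
For reference, each line of the generated file is a self-contained JSON object in the OpenAI chat fine-tuning shape. A line produced by the chat_message_response format would look like this (illustrative values):

    {"messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}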
254 changes: 254 additions & 0 deletions libs/core/kiln_ai/adapters/fine_tune/test_dataset_formatter.py
@@ -0,0 +1,254 @@
import json
import tempfile
from pathlib import Path
from unittest.mock import Mock

import pytest

from kiln_ai.adapters.fine_tune.dataset_formatter import (
    DatasetFormat,
    DatasetFormatter,
    generate_chat_message_response,
    generate_chat_message_toolcall,
)
from kiln_ai.datamodel import (
    DatasetSplit,
    DataSource,
    DataSourceType,
    Task,
    TaskOutput,
    TaskRun,
)


@pytest.fixture
def mock_task():
    task = Mock(spec=Task)
    task_runs = [
        TaskRun(
            id=f"run{i}",
            input='{"test": "input"}',
            input_source=DataSource(
                type=DataSourceType.human, properties={"created_by": "test"}
            ),
            output=TaskOutput(
                output='{"test": "output"}',
                source=DataSource(
                    type=DataSourceType.synthetic,
                    properties={
                        "model_name": "test",
                        "model_provider": "test",
                        "adapter_name": "test",
                    },
                ),
            ),
        )
        for i in range(1, 4)
    ]
    task.runs.return_value = task_runs
    return task


@pytest.fixture
def mock_dataset(mock_task):
    dataset = Mock(spec=DatasetSplit)
    dataset.name = "test_dataset"
    dataset.parent_task.return_value = mock_task
    dataset.split_contents = {"train": ["run1", "run2"], "test": ["run3"]}
    return dataset


def test_generate_chat_message_response():
    task_run = TaskRun(
        id="run1",
        input="test input",
        input_source=DataSource(
            type=DataSourceType.human, properties={"created_by": "test"}
        ),
        output=TaskOutput(
            output="test output",
            source=DataSource(
                type=DataSourceType.synthetic,
                properties={
                    "model_name": "test",
                    "model_provider": "test",
                    "adapter_name": "test",
                },
            ),
        ),
    )

    result = generate_chat_message_response(task_run, "system message")

    assert result == {
        "messages": [
            {"role": "system", "content": "system message"},
            {"role": "user", "content": "test input"},
            {"role": "assistant", "content": "test output"},
        ]
    }


def test_generate_chat_message_toolcall():
    task_run = TaskRun(
        id="run1",
        input="test input",
        input_source=DataSource(
            type=DataSourceType.human, properties={"created_by": "test"}
        ),
        output=TaskOutput(
            output='{"key": "value"}',
            source=DataSource(
                type=DataSourceType.synthetic,
                properties={
                    "model_name": "test",
                    "model_provider": "test",
                    "adapter_name": "test",
                },
            ),
        ),
    )

    result = generate_chat_message_toolcall(task_run, "system message")

    assert result == {
        "messages": [
            {"role": "system", "content": "system message"},
            {"role": "user", "content": "test input"},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "arguments": {"key": "value"},
                        },
                    }
                ],
            },
        ]
    }


def test_generate_chat_message_toolcall_invalid_json():
    task_run = TaskRun(
        id="run1",
        input="test input",
        input_source=DataSource(
            type=DataSourceType.human, properties={"created_by": "test"}
        ),
        output=TaskOutput(
            output="invalid json",
            source=DataSource(
                type=DataSourceType.synthetic,
                properties={
                    "model_name": "test",
                    "model_provider": "test",
                    "adapter_name": "test",
                },
            ),
        ),
    )

    with pytest.raises(ValueError, match="Invalid JSON in task run output"):
        generate_chat_message_toolcall(task_run, "system message")


def test_dataset_formatter_init_no_parent_task(mock_dataset):
    mock_dataset.parent_task.return_value = None

    with pytest.raises(ValueError, match="Dataset has no parent task"):
        DatasetFormatter(mock_dataset, "system message")


def test_dataset_formatter_dump_invalid_format(mock_dataset):
    formatter = DatasetFormatter(mock_dataset, "system message")

    with pytest.raises(ValueError, match="Unsupported format"):
        formatter.dump_to_file("train", "invalid_format")  # type: ignore


def test_dataset_formatter_dump_invalid_split(mock_dataset):
    formatter = DatasetFormatter(mock_dataset, "system message")

    with pytest.raises(ValueError, match="Split invalid_split not found in dataset"):
        formatter.dump_to_file(
            "invalid_split", DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL
        )


def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
    formatter = DatasetFormatter(mock_dataset, "system message")
    output_path = tmp_path / "output.jsonl"

    result_path = formatter.dump_to_file(
        "train", DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL, output_path
    )

    assert result_path == output_path
    assert output_path.exists()

    # Verify file contents
    with open(output_path) as f:
        lines = f.readlines()
        assert len(lines) == 2  # Should have 2 entries for train split
        for line in lines:
            data = json.loads(line)
            assert "messages" in data
            assert len(data["messages"]) == 3
            assert data["messages"][0]["content"] == "system message"
            assert data["messages"][1]["content"] == '{"test": "input"}'
            assert data["messages"][2]["content"] == '{"test": "output"}'


def test_dataset_formatter_dump_to_temp_file(mock_dataset):
    formatter = DatasetFormatter(mock_dataset, "system message")

    result_path = formatter.dump_to_file(
        "train", DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL
    )

    assert result_path.exists()
    assert result_path.parent == Path(tempfile.gettempdir())
    assert result_path.name.startswith("test_dataset_train_")
    assert result_path.name.endswith(".jsonl")
    # Verify file contents
    with open(result_path) as f:
        lines = f.readlines()
        assert len(lines) == 2


def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
    formatter = DatasetFormatter(mock_dataset, "system message")
    output_path = tmp_path / "output.jsonl"

    result_path = formatter.dump_to_file(
        "train", DatasetFormat.CHAT_MESSAGE_TOOLCALL_JSONL, output_path
    )

    assert result_path == output_path
    assert output_path.exists()

    # Verify file contents
    with open(output_path) as f:
        lines = f.readlines()
        assert len(lines) == 2  # Should have 2 entries for train split
        for line in lines:
            data = json.loads(line)
            assert "messages" in data
            assert len(data["messages"]) == 3
            # Check system and user messages
            assert data["messages"][0]["content"] == "system message"
            assert data["messages"][1]["content"] == '{"test": "input"}'
            # Check tool call format
            assistant_msg = data["messages"][2]
            assert assistant_msg["content"] is None
            assert "tool_calls" in assistant_msg
            assert len(assistant_msg["tool_calls"]) == 1
            tool_call = assistant_msg["tool_calls"][0]
            assert tool_call["type"] == "function"
            assert tool_call["function"]["name"] == "task_response"
            assert tool_call["function"]["arguments"] == {"test": "output"}
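
To run just these tests locally, invoking pytest on the new file's path should work (checks.sh runs the same tool over the whole tree):

    python3 -m pytest -q libs/core/kiln_ai/adapters/fine_tune/test_dataset_formatter.py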
