Add a DatasetFormatter that can dump a DatasetSplit to JSONL
Showing 3 changed files with 385 additions and 3 deletions.
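For orientation, here is a minimal usage sketch of the API this commit adds. It is a hedged example: `my_dataset_split` is assumed to be an existing DatasetSplit with a parent task and a "train" split, and the system message is a placeholder.

from kiln_ai.adapters.fine_tune.dataset_formatter import (
    DatasetFormat,
    DatasetFormatter,
)

# my_dataset_split is an assumed, pre-existing DatasetSplit with a parent task
formatter = DatasetFormatter(my_dataset_split, "You are a helpful assistant.")

# With no explicit path, the JSONL file is written to the system temp directory;
# the file name combines the dataset name, split name, and format
jsonl_path = formatter.dump_to_file("train", DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL)
print(jsonl_path)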
libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py (128 additions, 0 deletions)
@@ -0,0 +1,128 @@
import json
import tempfile
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Protocol

from kiln_ai.datamodel import DatasetSplit, TaskRun


class DatasetFormat(str, Enum):
    """Format types for dataset generation"""

    CHAT_MESSAGE_RESPONSE_JSONL = "chat_message_response_jsonl"
    CHAT_MESSAGE_TOOLCALL_JSONL = "chat_message_toolcall_jsonl"


class FormatGenerator(Protocol):
    """Protocol for format generators"""

    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...


def generate_chat_message_response(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with plaintext response"""
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {"role": "assistant", "content": task_run.output.output},
        ]
    }


def generate_chat_message_toolcall(
    task_run: TaskRun, system_message: str
) -> Dict[str, Any]:
    """Generate OpenAI chat format with tool call response"""
    try:
        arguments = json.loads(task_run.output.output)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in task run output: {e}") from e

    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": task_run.input},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "arguments": arguments,
                        },
                    }
                ],
            },
        ]
    }


FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
    DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL: generate_chat_message_response,
    DatasetFormat.CHAT_MESSAGE_TOOLCALL_JSONL: generate_chat_message_toolcall,
}


class DatasetFormatter:
    """Handles formatting of datasets into various output formats"""

    def __init__(self, dataset: DatasetSplit, system_message: str):
        self.dataset = dataset
        self.system_message = system_message

        task = dataset.parent_task()
        if task is None:
            raise ValueError("Dataset has no parent task")
        self.task = task

    def dump_to_file(
        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
    ) -> Path:
        """
        Format the dataset into the specified format.

        Args:
            split_name: Name of the split to dump
            format_type: Format to generate the dataset in
            path: Optional path to write to. If None, writes to temp directory

        Returns:
            Path to the generated file
        """
        if format_type not in FORMAT_GENERATORS:
            raise ValueError(f"Unsupported format: {format_type}")
        if split_name not in self.dataset.split_contents:
            raise ValueError(f"Split {split_name} not found in dataset")

        generator = FORMAT_GENERATORS[format_type]

        # Write to a temp file if no path is provided
        output_path = (
            path
            or Path(tempfile.gettempdir())
            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
        )

        runs = self.task.runs()
        runs_by_id = {run.id: run for run in runs}

        # Generate formatted output with UTF-8 encoding
        with open(output_path, "w", encoding="utf-8") as f:
            for run_id in self.dataset.split_contents[split_name]:
                # Use .get() so a missing run raises the descriptive error below
                # instead of an unhandled KeyError
                task_run = runs_by_id.get(run_id)
                if task_run is None:
                    raise ValueError(
                        f"Task run {run_id} not found. This is required by this dataset."
                    )

                example = generator(task_run, self.system_message)
                f.write(json.dumps(example) + "\n")

        return output_path
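Because DatasetFormatter looks formats up in the FORMAT_GENERATORS registry, and FormatGenerator is a Protocol, any callable with the (task_run, system_message) -> dict signature can serve as a generator. A hedged sketch of that extension point follows; the prompt/completion generator is hypothetical and not part of this commit, and a genuinely new format would also need its own DatasetFormat member.

from typing import Any, Dict

from kiln_ai.adapters.fine_tune.dataset_formatter import (
    FORMAT_GENERATORS,
    DatasetFormat,
)
from kiln_ai.datamodel import TaskRun


def generate_prompt_completion(task_run: TaskRun, system_message: str) -> Dict[str, Any]:
    # Hypothetical generator: emits prompt/completion pairs instead of chat messages
    return {
        "prompt": f"{system_message}\n\n{task_run.input}",
        "completion": task_run.output.output,
    }


# Re-pointing an existing DatasetFormat key at the new callable is enough for a quick
# experiment, because the callable satisfies the FormatGenerator protocol
FORMAT_GENERATORS[DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL] = generate_prompt_completion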
libs/core/kiln_ai/adapters/fine_tune/test_dataset_formatter.py (254 additions, 0 deletions)
@@ -0,0 +1,254 @@
import json
import tempfile
from pathlib import Path
from unittest.mock import Mock

import pytest

from kiln_ai.adapters.fine_tune.dataset_formatter import (
    DatasetFormat,
    DatasetFormatter,
    generate_chat_message_response,
    generate_chat_message_toolcall,
)
from kiln_ai.datamodel import (
    DatasetSplit,
    DataSource,
    DataSourceType,
    Task,
    TaskOutput,
    TaskRun,
)


@pytest.fixture
def mock_task():
    task = Mock(spec=Task)
    task_runs = [
        TaskRun(
            id=f"run{i}",
            input='{"test": "input"}',
            input_source=DataSource(
                type=DataSourceType.human, properties={"created_by": "test"}
            ),
            output=TaskOutput(
                output='{"test": "output"}',
                source=DataSource(
                    type=DataSourceType.synthetic,
                    properties={
                        "model_name": "test",
                        "model_provider": "test",
                        "adapter_name": "test",
                    },
                ),
            ),
        )
        for i in range(1, 4)
    ]
    task.runs.return_value = task_runs
    return task


@pytest.fixture
def mock_dataset(mock_task):
    dataset = Mock(spec=DatasetSplit)
    dataset.name = "test_dataset"
    dataset.parent_task.return_value = mock_task
    dataset.split_contents = {"train": ["run1", "run2"], "test": ["run3"]}
    return dataset


def test_generate_chat_message_response():
    task_run = TaskRun(
        id="run1",
        input="test input",
        input_source=DataSource(
            type=DataSourceType.human, properties={"created_by": "test"}
        ),
        output=TaskOutput(
            output="test output",
            source=DataSource(
                type=DataSourceType.synthetic,
                properties={
                    "model_name": "test",
                    "model_provider": "test",
                    "adapter_name": "test",
                },
            ),
        ),
    )

    result = generate_chat_message_response(task_run, "system message")

    assert result == {
        "messages": [
            {"role": "system", "content": "system message"},
            {"role": "user", "content": "test input"},
            {"role": "assistant", "content": "test output"},
        ]
    }


def test_generate_chat_message_toolcall():
    task_run = TaskRun(
        id="run1",
        input="test input",
        input_source=DataSource(
            type=DataSourceType.human, properties={"created_by": "test"}
        ),
        output=TaskOutput(
            output='{"key": "value"}',
            source=DataSource(
                type=DataSourceType.synthetic,
                properties={
                    "model_name": "test",
                    "model_provider": "test",
                    "adapter_name": "test",
                },
            ),
        ),
    )

    result = generate_chat_message_toolcall(task_run, "system message")

    assert result == {
        "messages": [
            {"role": "system", "content": "system message"},
            {"role": "user", "content": "test input"},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "task_response",
                            "arguments": {"key": "value"},
                        },
                    }
                ],
            },
        ]
    }


def test_generate_chat_message_toolcall_invalid_json():
    task_run = TaskRun(
        id="run1",
        input="test input",
        input_source=DataSource(
            type=DataSourceType.human, properties={"created_by": "test"}
        ),
        output=TaskOutput(
            output="invalid json",
            source=DataSource(
                type=DataSourceType.synthetic,
                properties={
                    "model_name": "test",
                    "model_provider": "test",
                    "adapter_name": "test",
                },
            ),
        ),
    )

    with pytest.raises(ValueError, match="Invalid JSON in task run output"):
        generate_chat_message_toolcall(task_run, "system message")


def test_dataset_formatter_init_no_parent_task(mock_dataset):
    mock_dataset.parent_task.return_value = None

    with pytest.raises(ValueError, match="Dataset has no parent task"):
        DatasetFormatter(mock_dataset, "system message")


def test_dataset_formatter_dump_invalid_format(mock_dataset):
    formatter = DatasetFormatter(mock_dataset, "system message")

    with pytest.raises(ValueError, match="Unsupported format"):
        formatter.dump_to_file("train", "invalid_format")  # type: ignore


def test_dataset_formatter_dump_invalid_split(mock_dataset):
    formatter = DatasetFormatter(mock_dataset, "system message")

    with pytest.raises(ValueError, match="Split invalid_split not found in dataset"):
        formatter.dump_to_file(
            "invalid_split", DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL
        )


def test_dataset_formatter_dump_to_file(mock_dataset, tmp_path):
    formatter = DatasetFormatter(mock_dataset, "system message")
    output_path = tmp_path / "output.jsonl"

    result_path = formatter.dump_to_file(
        "train", DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL, output_path
    )

    assert result_path == output_path
    assert output_path.exists()

    # Verify file contents
    with open(output_path) as f:
        lines = f.readlines()
        assert len(lines) == 2  # Should have 2 entries for train split
        for line in lines:
            data = json.loads(line)
            assert "messages" in data
            assert len(data["messages"]) == 3
            assert data["messages"][0]["content"] == "system message"
            assert data["messages"][1]["content"] == '{"test": "input"}'
            assert data["messages"][2]["content"] == '{"test": "output"}'


def test_dataset_formatter_dump_to_temp_file(mock_dataset):
    formatter = DatasetFormatter(mock_dataset, "system message")

    result_path = formatter.dump_to_file(
        "train", DatasetFormat.CHAT_MESSAGE_RESPONSE_JSONL
    )

    assert result_path.exists()
    assert result_path.parent == Path(tempfile.gettempdir())
    assert result_path.name.startswith("test_dataset_train_")
    assert result_path.name.endswith(".jsonl")
    # Verify file contents
    with open(result_path) as f:
        lines = f.readlines()
        assert len(lines) == 2


def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
    formatter = DatasetFormatter(mock_dataset, "system message")
    output_path = tmp_path / "output.jsonl"

    result_path = formatter.dump_to_file(
        "train", DatasetFormat.CHAT_MESSAGE_TOOLCALL_JSONL, output_path
    )

    assert result_path == output_path
    assert output_path.exists()

    # Verify file contents
    with open(output_path) as f:
        lines = f.readlines()
        assert len(lines) == 2  # Should have 2 entries for train split
        for line in lines:
            data = json.loads(line)
            assert "messages" in data
            assert len(data["messages"]) == 3
            # Check system and user messages
            assert data["messages"][0]["content"] == "system message"
            assert data["messages"][1]["content"] == '{"test": "input"}'
            # Check tool call format
            assistant_msg = data["messages"][2]
            assert assistant_msg["content"] is None
            assert "tool_calls" in assistant_msg
            assert len(assistant_msg["tool_calls"]) == 1
            tool_call = assistant_msg["tool_calls"][0]
            assert tool_call["type"] == "function"
            assert tool_call["function"]["name"] == "task_response"
            assert tool_call["function"]["arguments"] == {"test": "output"}