Showing 3 changed files with 245 additions and 0 deletions.
libs/core/kiln_ai/adapters/fine_tune/base_finetune.py (71 additions, 0 deletions)
@@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Literal

from pydantic import Field

from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel

if TYPE_CHECKING:
    from kiln_ai.datamodel import Task


class FineTuneStatus(str, Enum):
    pending = "pending"
    running = "running"
    completed = "completed"
    failed = "failed"


@dataclass
class FineTuneParameter:
    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


class BaseFinetune(KilnParentedModel, ABC):
    name: str = NAME_FIELD
    description: str | None = Field(
        default=None,
        description="A description of the fine-tune for you and your team. Not used in training.",
    )
    provider: str = Field(
        description="The provider to use for the fine-tune (e.g. 'openai')."
    )
    base_model_id: str = Field(
        description="The id of the base model to use for the fine-tune. This string relates to the provider's IDs for their own models, not Kiln IDs."
    )
    provider_id: str | None = Field(
        default=None,
        description="The ID of the fine-tuned model on the provider's side.",
    )
    parameters: dict[str, str | int | float | bool] = Field(
        default_factory=dict,
        description="The parameters to use for this fine-tune. These are provider-specific.",
    )

    def parent_task(self) -> "Task | None":
        # inline import to avoid circular import
        from kiln_ai.datamodel import Task

        if not isinstance(self.parent, Task):
            return None
        return self.parent

    @abstractmethod
    def start(self) -> None:
        pass

    @abstractmethod
    def status(self) -> FineTuneStatus:
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune.
        """
        return []
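
For orientation, a minimal sketch of how a concrete adapter slots into this base class and how parent_task() resolves through KilnParentedModel. The subclass name and the Task constructor arguments are illustrative, inferred from the tests later in this commit rather than from this file:

from pathlib import Path

from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetune, FineTuneStatus
from kiln_ai.datamodel import Task


class DummyFinetune(BaseFinetune):
    # Only the two abstract methods need an implementation; the fields and
    # parent wiring come from BaseFinetune / KilnParentedModel.
    def start(self) -> None:
        pass  # a real adapter would launch a provider job here

    def status(self) -> FineTuneStatus:
        return FineTuneStatus.pending


task = Task(name="Example Task", path=Path("task.kiln"), instruction="Example instruction")
finetune = DummyFinetune(
    name="example_finetune",
    parent=task,  # attaches the fine-tune to its parent task
    provider="openai",
    base_model_id="some-provider-model-id",
)
assert finetune.parent_task() == task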
@@ -0,0 +1,31 @@
from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetune, FineTuneParameter


class OpenAIFinetune(BaseFinetune):
    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="batch_size",
                type="int",
                description="Number of examples in each batch. A larger batch size means that model parameters are updated less frequently, but with lower variance. Defaults to 'auto'",
            ),
            FineTuneParameter(
                name="learning_rate_multiplier",
                type="float",
                description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid overfitting. Defaults to 'auto'",
                optional=True,
            ),
            FineTuneParameter(
                name="n_epochs",
                type="int",
                description="The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset. Defaults to 'auto'",
                optional=True,
            ),
            FineTuneParameter(
                name="seed",
                type="int",
                description="The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases. If a seed is not specified, one will be generated for you.",
                optional=True,
            ),
        ]
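
The parameter names above mirror the hyperparameters of OpenAI's fine-tuning jobs API. This commit does not yet implement start() or status() for OpenAIFinetune, so as a sketch of where these values would flow, assuming the openai v1 Python SDK and an already-uploaded JSONL training file (start_openai_job and its training_file_id argument are illustrative, not part of this commit):

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment


def start_openai_job(finetune: "OpenAIFinetune", training_file_id: str) -> str:
    # Forward any user-supplied parameters, falling back to the API's
    # 'auto' defaults when a value was not provided.
    params = finetune.parameters
    job = client.fine_tuning.jobs.create(
        model=finetune.base_model_id,
        training_file=training_file_id,
        hyperparameters={
            "n_epochs": params.get("n_epochs", "auto"),
            "batch_size": params.get("batch_size", "auto"),
            "learning_rate_multiplier": params.get("learning_rate_multiplier", "auto"),
        },
        seed=params.get("seed"),
    )
    return job.id  # a caller would store this back as provider_id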
libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py (143 additions, 0 deletions)
@@ -0,0 +1,143 @@
# ruff: noqa: I001 - Import order matters here. Need datamodel before finetune

import pytest
from pydantic import ValidationError

from kiln_ai.datamodel import Task
from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetune,
    FineTuneParameter,
    FineTuneStatus,
)


class MockFinetune(BaseFinetune):
    """Mock implementation of BaseFinetune for testing"""

    def start(self) -> None:
        pass

    def status(self) -> FineTuneStatus:
        return FineTuneStatus.pending

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="learning_rate",
                type="float",
                description="Learning rate for training",
            ),
            FineTuneParameter(
                name="epochs",
                type="int",
                description="Number of training epochs",
                optional=False,
            ),
        ]


@pytest.fixture
def sample_task(tmp_path):
    task_path = tmp_path / "task.kiln"
    task = Task(
        name="Test Task",
        path=task_path,
        description="Test task for fine-tuning",
        instruction="Test instruction",
    )
    task.save_to_file()
    return task


@pytest.fixture
def basic_finetune(sample_task):
    return MockFinetune(
        name="test_finetune",
        parent=sample_task,
        provider="test_provider",
        base_model_id="test_model",
        provider_id="model_1234",
    )


def test_finetune_basic_properties(basic_finetune):
    assert basic_finetune.name == "test_finetune"
    assert basic_finetune.provider == "test_provider"
    assert basic_finetune.base_model_id == "test_model"
    assert basic_finetune.provider_id == "model_1234"
    assert basic_finetune.parameters == {}
    assert basic_finetune.description is None


def test_finetune_parameters_validation():
    with pytest.raises(ValidationError):
        MockFinetune(
            name="test",
            provider="test_provider",
            base_model_id="test_model",
parameters="invalid", # Should be a dict | ||
) | ||
|
||
|
||
def test_finetune_parent_task(sample_task, basic_finetune): | ||
assert basic_finetune.parent_task() == sample_task | ||
|
||
# Test with no parent | ||
orphan_finetune = MockFinetune( | ||
name="orphan", | ||
provider="test_provider", | ||
base_model_id="test_model", | ||
) | ||
assert orphan_finetune.parent_task() is None | ||
|
||
|
||
def test_finetune_status(basic_finetune): | ||
assert basic_finetune.status() == FineTuneStatus.pending | ||
assert isinstance(basic_finetune.status(), FineTuneStatus) | ||
|
||
|
||
def test_available_parameters(): | ||
params = MockFinetune.available_parameters() | ||
assert len(params) == 2 | ||
|
||
learning_rate_param = params[0] | ||
assert learning_rate_param.name == "learning_rate" | ||
assert learning_rate_param.type == "float" | ||
assert learning_rate_param.optional is True | ||
|
||
epochs_param = params[1] | ||
assert epochs_param.name == "epochs" | ||
assert epochs_param.type == "int" | ||
assert epochs_param.optional is False | ||
|
||
|
||
def test_finetune_status_enum(): | ||
assert set(FineTuneStatus) == { | ||
FineTuneStatus.pending, | ||
FineTuneStatus.running, | ||
FineTuneStatus.completed, | ||
FineTuneStatus.failed, | ||
} | ||
|
||
|
||
def test_finetune_with_parameters(sample_task): | ||
finetune = MockFinetune( | ||
name="test_with_params", | ||
parent=sample_task, | ||
provider="test_provider", | ||
base_model_id="test_model", | ||
parameters={ | ||
"learning_rate": 0.001, | ||
"epochs": 10, | ||
"batch_size": 32, | ||
"fast": True, | ||
"prefix": "test_prefix", | ||
}, | ||
) | ||
|
||
assert finetune.parameters["learning_rate"] == 0.001 | ||
assert finetune.parameters["epochs"] == 10 | ||
assert finetune.parameters["batch_size"] == 32 | ||
assert finetune.parameters["fast"] is True | ||
assert finetune.parameters["prefix"] == "test_prefix" |