Missing files from last commit
scosman committed Nov 22, 2024
1 parent 1817aec commit 9c442d3
Showing 3 changed files with 245 additions and 0 deletions.
71 changes: 71 additions & 0 deletions libs/core/kiln_ai/adapters/fine_tune/base_finetune.py
@@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Literal

from pydantic import Field

from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel

if TYPE_CHECKING:
    from kiln_ai.datamodel import Task


class FineTuneStatus(str, Enum):
    pending = "pending"
    running = "running"
    completed = "completed"
    failed = "failed"


@dataclass
class FineTuneParameter:
    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


class BaseFinetune(KilnParentedModel, ABC):
    name: str = NAME_FIELD
    description: str | None = Field(
        default=None,
        description="A description of the fine-tune for you and your team. Not used in training.",
    )
    provider: str = Field(
        description="The provider to use for the fine-tune (e.g. 'openai')."
    )
    base_model_id: str = Field(
        description="The id of the base model to use for the fine-tune. This string relates to the provider's IDs for their own models, not Kiln IDs."
    )
    provider_id: str | None = Field(
        default=None,
        description="The ID of the fine-tuned model on the provider's side.",
    )
    parameters: dict[str, str | int | float | bool] = Field(
        default_factory=dict,
        description="The parameters to use for this fine-tune. These are provider-specific.",
    )

    def parent_task(self) -> "Task | None":
        # inline import to avoid circular import
        from kiln_ai.datamodel import Task

        if not isinstance(self.parent, Task):
            return None
        return self.parent

    @abstractmethod
    def start(self) -> None:
        pass

    @abstractmethod
    def status(self) -> FineTuneStatus:
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune.
        """
        return []
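
The base class fixes the contract every provider adapter implements: declare tunable options via available_parameters(), create the provider-side job in start(), and report progress through status(). A minimal sketch of how a concrete subclass and its caller might look (EchoFinetune and the values below are illustrative only, not part of this commit):

# Illustrative sketch, not part of this commit: a hypothetical adapter that
# "completes" immediately, plus the calling pattern the base class implies.
from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetune,
    FineTuneParameter,
    FineTuneStatus,
)


class EchoFinetune(BaseFinetune):
    def start(self) -> None:
        # A real adapter would upload training data and create a job here,
        # then record the provider's job/model ID.
        self.provider_id = "job_echo_123"

    def status(self) -> FineTuneStatus:
        return FineTuneStatus.completed

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="n_epochs",
                type="int",
                description="Number of training epochs",
            )
        ]


tune = EchoFinetune(
    name="example_tune",
    provider="echo",
    base_model_id="base-model-v1",
    parameters={"n_epochs": 3},
)
tune.start()
assert tune.status() == FineTuneStatus.completed
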
31 changes: 31 additions & 0 deletions libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py
@@ -0,0 +1,31 @@
from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetune, FineTuneParameter


class OpenAIFinetune(BaseFinetune):
    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="batch_size",
                type="int",
                description="Number of examples in each batch. A larger batch size means that model parameters are updated less frequently, but with lower variance. Defaults to 'auto'",
            ),
            FineTuneParameter(
                name="learning_rate_multiplier",
                type="float",
                description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid overfitting. Defaults to 'auto'",
                optional=True,
            ),
            FineTuneParameter(
                name="n_epochs",
                type="int",
                description="The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset. Defaults to 'auto'",
                optional=True,
            ),
            FineTuneParameter(
                name="seed",
                type="int",
                description="The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases. If a seed is not specified, one will be generated for you.",
                optional=True,
            ),
        ]
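
As of this commit, OpenAIFinetune only declares its tunable parameters; start() and status() are still abstract, so the class cannot be instantiated yet. A hedged sketch of how those two methods might later map onto OpenAI's fine-tuning jobs API (assuming the OpenAI Python SDK v1 endpoints and an already-uploaded training file; the helper names and status mapping below are assumptions, not code from this repository):

# Sketch only: assumes openai>=1.0 and a training file uploaded ahead of time.
from openai import OpenAI

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatus

client = OpenAI()


def start_openai_job(base_model_id: str, training_file_id: str, parameters: dict) -> str:
    # Forward batch_size / learning_rate_multiplier / n_epochs as hyperparameters;
    # seed is a separate top-level argument in the API, omitted here for brevity.
    job = client.fine_tuning.jobs.create(
        model=base_model_id,
        training_file=training_file_id,
        hyperparameters={k: v for k, v in parameters.items() if k != "seed"},
    )
    return job.id  # a BaseFinetune subclass would store this as provider_id


def openai_job_status(provider_id: str) -> FineTuneStatus:
    job = client.fine_tuning.jobs.retrieve(provider_id)
    # Map OpenAI's job states onto Kiln's FineTuneStatus enum.
    if job.status == "succeeded":
        return FineTuneStatus.completed
    if job.status in ("failed", "cancelled"):
        return FineTuneStatus.failed
    if job.status == "running":
        return FineTuneStatus.running
    return FineTuneStatus.pending  # queued / validating_files, etc.
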
143 changes: 143 additions & 0 deletions libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py
@@ -0,0 +1,143 @@
# ruff: noqa: I001 - Import order matters here. Need datamodel before finetune

import pytest
from pydantic import ValidationError

from kiln_ai.datamodel import Task
from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetune,
    FineTuneParameter,
    FineTuneStatus,
)


class MockFinetune(BaseFinetune):
    """Mock implementation of BaseFinetune for testing"""

    def start(self) -> None:
        pass

    def status(self) -> FineTuneStatus:
        return FineTuneStatus.pending

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="learning_rate",
                type="float",
                description="Learning rate for training",
            ),
            FineTuneParameter(
                name="epochs",
                type="int",
                description="Number of training epochs",
                optional=False,
            ),
        ]


@pytest.fixture
def sample_task(tmp_path):
    task_path = tmp_path / "task.kiln"
    task = Task(
        name="Test Task",
        path=task_path,
        description="Test task for fine-tuning",
        instruction="Test instruction",
    )
    task.save_to_file()
    return task


@pytest.fixture
def basic_finetune(sample_task):
    return MockFinetune(
        name="test_finetune",
        parent=sample_task,
        provider="test_provider",
        base_model_id="test_model",
        provider_id="model_1234",
    )


def test_finetune_basic_properties(basic_finetune):
    assert basic_finetune.name == "test_finetune"
    assert basic_finetune.provider == "test_provider"
    assert basic_finetune.base_model_id == "test_model"
    assert basic_finetune.provider_id == "model_1234"
    assert basic_finetune.parameters == {}
    assert basic_finetune.description is None


def test_finetune_parameters_validation():
    with pytest.raises(ValidationError):
        MockFinetune(
            name="test",
            provider="test_provider",
            base_model_id="test_model",
            parameters="invalid",  # Should be a dict
        )


def test_finetune_parent_task(sample_task, basic_finetune):
    assert basic_finetune.parent_task() == sample_task

    # Test with no parent
    orphan_finetune = MockFinetune(
        name="orphan",
        provider="test_provider",
        base_model_id="test_model",
    )
    assert orphan_finetune.parent_task() is None


def test_finetune_status(basic_finetune):
    assert basic_finetune.status() == FineTuneStatus.pending
    assert isinstance(basic_finetune.status(), FineTuneStatus)


def test_available_parameters():
    params = MockFinetune.available_parameters()
    assert len(params) == 2

    learning_rate_param = params[0]
    assert learning_rate_param.name == "learning_rate"
    assert learning_rate_param.type == "float"
    assert learning_rate_param.optional is True

    epochs_param = params[1]
    assert epochs_param.name == "epochs"
    assert epochs_param.type == "int"
    assert epochs_param.optional is False


def test_finetune_status_enum():
    assert set(FineTuneStatus) == {
        FineTuneStatus.pending,
        FineTuneStatus.running,
        FineTuneStatus.completed,
        FineTuneStatus.failed,
    }


def test_finetune_with_parameters(sample_task):
    finetune = MockFinetune(
        name="test_with_params",
        parent=sample_task,
        provider="test_provider",
        base_model_id="test_model",
        parameters={
            "learning_rate": 0.001,
            "epochs": 10,
            "batch_size": 32,
            "fast": True,
            "prefix": "test_prefix",
        },
    )

    assert finetune.parameters["learning_rate"] == 0.001
    assert finetune.parameters["epochs"] == 10
    assert finetune.parameters["batch_size"] == 32
    assert finetune.parameters["fast"] is True
    assert finetune.parameters["prefix"] == "test_prefix"
