diff --git a/libs/core/kiln_ai/adapters/__init__.py b/libs/core/kiln_ai/adapters/__init__.py
index a72cb407..57d31b62 100644
--- a/libs/core/kiln_ai/adapters/__init__.py
+++ b/libs/core/kiln_ai/adapters/__init__.py
@@ -12,26 +12,22 @@
 The repair submodule contains an adapter for the repair task.
 """
 
-import sys
-
-# Avoid circular import since we use datamodel in some tests
-if "pytest" not in sys.modules:
-    from . import (
-        base_adapter,
-        data_gen,
-        fine_tune,
-        langchain_adapters,
-        ml_model_list,
-        prompt_builders,
-        repair,
-    )
-
-    __all__ = [
-        "base_adapter",
-        "langchain_adapters",
-        "ml_model_list",
-        "prompt_builders",
-        "repair",
-        "data_gen",
-        "fine_tune",
-    ]
+from . import (
+    base_adapter,
+    data_gen,
+    fine_tune,
+    langchain_adapters,
+    ml_model_list,
+    prompt_builders,
+    repair,
+)
+
+__all__ = [
+    "base_adapter",
+    "langchain_adapters",
+    "ml_model_list",
+    "prompt_builders",
+    "repair",
+    "data_gen",
+    "fine_tune",
+]
diff --git a/libs/core/kiln_ai/adapters/fine_tune/__init__.py b/libs/core/kiln_ai/adapters/fine_tune/__init__.py
index 4e6c3e2e..be3cc9ba 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/__init__.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/__init__.py
@@ -4,10 +4,6 @@
 A set of classes for fine-tuning models.
 """
 
-import sys
+# from . import base_finetune, openai_finetune
 
-# Avoid circular import since we use datamodel in some tests
-if "pytest" not in sys.modules:
-    from . import dataset_split
-
-    __all__ = ["dataset_split"]
+# __all__ = ["base_finetune", "openai_finetune"]
diff --git a/libs/core/kiln_ai/adapters/fine_tune/base_finetune.py b/libs/core/kiln_ai/adapters/fine_tune/base_finetune.py
index e38bb35b..7b2edc59 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/base_finetune.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/base_finetune.py
@@ -1,14 +1,10 @@
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Literal
+from typing import Literal
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
-from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel
-
-if TYPE_CHECKING:
-    from kiln_ai.datamodel import Task
+from kiln_ai.datamodel import Finetune as FinetuneModel
 
 
 class FineTuneStatusType(str, Enum):
@@ -28,42 +24,16 @@ class FineTuneStatus(BaseModel):
     message: str | None = None
 
 
-@dataclass
-class FineTuneParameter:
+class FineTuneParameter(BaseModel):
     name: str
     type: Literal["string", "int", "float", "bool"]
     description: str
     optional: bool = True
 
 
-class BaseFinetune(KilnParentedModel, ABC):
-    name: str = NAME_FIELD
-    description: str | None = Field(
-        default=None,
-        description="A description of the fine-tune for you and your team. Not used in training.",
-    )
-    provider: str = Field(
-        description="The provider to use for the fine-tune (e.g. 'openai')."
-    )
-    base_model_id: str = Field(
-        description="The id of the base model to use for the fine-tune. This string relates to the provider's IDs for their own models, not Kiln IDs."
-    )
-    provider_id: str | None = Field(
-        default=None,
-        description="The ID of the fine-tuned model on the provider's side.",
-    )
-    parameters: dict[str, str | int | float | bool] = Field(
-        default_factory=dict,
-        description="The parameters to use for this fine-tune. These are provider-specific.",
-    )
-    )
-
-    def parent_task(self) -> "Task | None":
-        # inline import to avoid circular import
-        from kiln_ai.datamodel import Task
-
-        if not isinstance(self.parent, Task):
-            return None
-        return self.parent
+class BaseFinetuneAdapter(ABC):
+    def __init__(self, model: FinetuneModel):
+        self.model = model
 
     @abstractmethod
     def start(self) -> None:
diff --git a/libs/core/kiln_ai/adapters/fine_tune/dataset_split.py b/libs/core/kiln_ai/adapters/fine_tune/dataset_split.py
deleted file mode 100644
index b558d4cf..00000000
--- a/libs/core/kiln_ai/adapters/fine_tune/dataset_split.py
+++ /dev/null
@@ -1,158 +0,0 @@
-import math
-import random
-from typing import TYPE_CHECKING, Callable
-
-from pydantic import BaseModel, Field, model_validator
-
-from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel
-
-if TYPE_CHECKING:
-    from kiln_ai.datamodel import Task, TaskRun
-# Define the type alias for clarity
-DatasetFilter = Callable[["TaskRun"], bool]
-
-
-def AllDatasetFilter(_: "TaskRun") -> bool:
-    return True
-
-
-def HighRatingDatasetFilter(task_run: "TaskRun") -> bool:
-    if task_run.output is None or task_run.output.rating is None:
-        return False
-    return task_run.output.rating.is_high_quality()
-
-
-class DatasetSplitDefinition(BaseModel):
-    """
-    A definition of a split in a dataset.
-
-    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
-    """
-
-    name: str = NAME_FIELD
-    description: str | None = Field(
-        default=None,
-        description="A description of the dataset for you and your team. Not used in training.",
-    )
-    percentage: float = Field(
-        ge=0.0,
-        le=1.0,
-        description="The percentage of the dataset that this split represents (between 0 and 1).",
-    )
-
-
-AllSplitDefinition: list[DatasetSplitDefinition] = [
-    DatasetSplitDefinition(name="all", percentage=1.0)
-]
-Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
-    DatasetSplitDefinition(name="train", percentage=0.8),
-    DatasetSplitDefinition(name="test", percentage=0.2),
-]
-Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
-    DatasetSplitDefinition(name="train", percentage=0.6),
-    DatasetSplitDefinition(name="test", percentage=0.2),
-    DatasetSplitDefinition(name="val", percentage=0.2),
-]
-
-
-class DatasetSplit(KilnParentedModel):
-    """
-    A collection of task runs, with optional splits (train, test, validation)
-
-    You probably want to use DatasetSplit class from the datamodel module, which is has relationships to the Task and TaskRun models.
-    """
-
-    # TODO: NAME_FIELD
-    name: str
-    description: str | None = Field(
-        default=None,
-        description="A description of the dataset for you and your team. Not used in training.",
-    )
-    splits: list[DatasetSplitDefinition] = Field(
-        default_factory=list,
-        description="The splits in the dataset.",
-    )
-    split_contents: dict[str, list[str]] = Field(
-        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
-    )
-
-    @model_validator(mode="after")
-    def validate_split_percentages(self) -> "DatasetSplit":
-        total = sum(split.percentage for split in self.splits)
-        if not math.isclose(total, 1.0, rel_tol=1e-9):
-            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
-        return self
-
-    @classmethod
-    def from_task(
-        cls,
-        name: str,
-        task: "Task",
-        splits: list[DatasetSplitDefinition],
-        filter: DatasetFilter = AllDatasetFilter,
-        description: str | None = None,
-    ):
-        split_contents = cls.build_split_contents(task, splits, filter)
-        return cls(
-            parent=task,
-            name=name,
-            description=description,
-            splits=splits,
-            split_contents=split_contents,
-        )
-
-    @classmethod
-    def build_split_contents(
-        cls,
-        task: "Task",
-        splits: list[DatasetSplitDefinition],
-        filter: DatasetFilter,
-    ) -> dict[str, list[str]]:
-        valid_ids = []
-        for task_run in task.runs():
-            if filter(task_run):
-                valid_ids.append(task_run.id)
-
-        # Shuffle and split by split percentage
-        random.shuffle(valid_ids)
-        split_contents = {}
-        start_idx = 0
-        remaining_items = len(valid_ids)
-
-        # Handle all splits except the last one
-        for split in splits[:-1]:
-            split_size = round(len(valid_ids) * split.percentage)
-            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
-            start_idx += split_size
-            remaining_items -= split_size
-
-        # Last split gets all remaining items (for rounding)
-        if splits:
-            split_contents[splits[-1].name] = valid_ids[start_idx:]
-
-        return split_contents
-
-    def parent_task(self) -> "Task | None":
-        # inline import to avoid circular import
-        from kiln_ai.datamodel import Task
-
-        if not isinstance(self.parent, Task):
-            return None
-        return self.parent
-
-    def missing_count(self) -> int:
-        """
-        Returns:
-            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
-        """
-        parent = self.parent_task()
-        if parent is None:
-            raise ValueError("DatasetSplit has no parent task")
-
-        runs = parent.runs()
-        all_ids = set(run.id for run in runs)
-        all_ids_in_splits = set()
-        for ids in self.split_contents.values():
-            all_ids_in_splits.update(ids)
-        missing = all_ids_in_splits - all_ids
-        return len(missing)
diff --git a/libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py b/libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py
index 9e26b061..29c8c37e 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py
@@ -4,7 +4,7 @@
 from openai.types.fine_tuning import FineTuningJob
 
 from kiln_ai.adapters.fine_tune.base_finetune import (
-    BaseFinetune,
+    BaseFinetuneAdapter,
     FineTuneParameter,
     FineTuneStatus,
     FineTuneStatusType,
@@ -16,9 +16,9 @@
 )
 
 
-class OpenAIFinetune(BaseFinetune):
+class OpenAIFinetune(BaseFinetuneAdapter):
     def status(self) -> FineTuneStatus:
-        if not self.provider_id:
+        if not self.model or not self.model.provider_id:
             return FineTuneStatus(
                 status=FineTuneStatusType.pending,
                 message="This fine-tune has not been started or has not been assigned a provider ID.",
@@ -26,7 +26,7 @@ def status(self) -> FineTuneStatus:
 
         try:
             # Will raise an error if the job is not found, or for other issues
-            response = oai_client.fine_tuning.jobs.retrieve(self.provider_id)
+            response = oai_client.fine_tuning.jobs.retrieve(self.model.provider_id)
         except openai.APIConnectionError:
             return FineTuneStatus(
                 status=FineTuneStatusType.unknown, message="Server connection error"
@@ -67,11 +67,7 @@ def status(self) -> FineTuneStatus:
             return FineTuneStatus(
                 status=FineTuneStatusType.failed, message="Job cancelled"
             )
-        if (
-            status in ["validating_files", "running", "queued"]
-            or response.finished_at is None
-            or response.estimated_finish is not None
-        ):
+        if status in ["validating_files", "running", "queued"]:
             time_to_finish_msg: str | None = None
             if response.estimated_finish is not None:
                 time_to_finish_msg = f"Estimated finish time: {int(response.estimated_finish - time.time())} seconds."
diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py b/libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py
index b8f999e8..0d6a23b3 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py
@@ -1,18 +1,16 @@
-# ruff: noqa: I001 - Import order matters here. Need datamodel before finetune
-
 import pytest
-from pydantic import ValidationError
 
-from kiln_ai.datamodel import Task
 from kiln_ai.adapters.fine_tune.base_finetune import (
-    BaseFinetune,
+    BaseFinetuneAdapter,
     FineTuneParameter,
     FineTuneStatus,
     FineTuneStatusType,
 )
+from kiln_ai.datamodel import Finetune as FinetuneModel
+from kiln_ai.datamodel import Task
 
 
-class MockFinetune(BaseFinetune):
+class MockFinetune(BaseFinetuneAdapter):
     """Mock implementation of BaseFinetune for testing"""
 
     def start(self) -> None:
@@ -54,43 +52,14 @@ def sample_task(tmp_path):
 @pytest.fixture
 def basic_finetune(sample_task):
     return MockFinetune(
-        name="test_finetune",
-        parent=sample_task,
-        provider="test_provider",
-        base_model_id="test_model",
-        provider_id="model_1234",
-    )
-
-
-def test_finetune_basic_properties(basic_finetune):
-    assert basic_finetune.name == "test_finetune"
-    assert basic_finetune.provider == "test_provider"
-    assert basic_finetune.base_model_id == "test_model"
-    assert basic_finetune.provider_id == "model_1234"
-    assert basic_finetune.parameters == {}
-    assert basic_finetune.description is None
-
-
-def test_finetune_parameters_validation():
-    with pytest.raises(ValidationError):
-        MockFinetune(
-            name="test",
+        model=FinetuneModel(
+            parent=sample_task,
+            name="test_finetune",
             provider="test_provider",
-            base_model_provider_id="test_model",
-            parameters="invalid",  # Should be a dict
-        )
-
-
-def test_finetune_parent_task(sample_task, basic_finetune):
-    assert basic_finetune.parent_task() == sample_task
-
-    # Test with no parent
-    orphan_finetune = MockFinetune(
-        name="orphan",
-        provider="test_provider",
-        base_model_id="test_model",
+            provider_id="model_1234",
+            base_model_id="test_model",
+        ),
     )
-    assert orphan_finetune.parent_task() is None
 
 
 def test_finetune_status(basic_finetune):
@@ -112,25 +81,3 @@ def test_available_parameters():
     assert epochs_param.name == "epochs"
     assert epochs_param.type == "int"
     assert epochs_param.optional is False
-
-
-def test_finetune_with_parameters(sample_task):
-    finetune = MockFinetune(
-        name="test_with_params",
-        parent=sample_task,
-        provider="test_provider",
-        base_model_id="test_model",
-        parameters={
-            "learning_rate": 0.001,
-            "epochs": 10,
-            "batch_size": 32,
-            "fast": True,
-            "prefix": "test_prefix",
-        },
-    )
-
-    assert finetune.parameters["learning_rate"] == 0.001
-    assert finetune.parameters["epochs"] == 10
-    assert finetune.parameters["batch_size"] == 32
-    assert finetune.parameters["fast"] is True
-    assert finetune.parameters["prefix"] == "test_prefix"
diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py b/libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py
index a98cc33d..577a3596 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py
@@ -1,31 +1,32 @@
-# ruff: noqa: I001 - Import order matters here. Need datamodel before finetune
-
 import time
 from unittest.mock import MagicMock, patch
 
 import openai
 import pytest
+from openai.types.fine_tuning import FineTuningJob
 
-from kiln_ai.utils.config import Config
-from kiln_ai.datamodel import Task
 from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatusType
 from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
+from kiln_ai.datamodel import Finetune as FinetuneModel
+from kiln_ai.utils.config import Config
 
 
 @pytest.fixture
 def openai_finetune():
     finetune = OpenAIFinetune(
-        name="test-finetune",
-        provider="openai",
-        provider_id="openai-123",
-        base_model_id="gpt-4o",
+        model=FinetuneModel(
+            name="test-finetune",
+            provider="openai",
+            provider_id="openai-123",
+            base_model_id="gpt-4o",
+        ),
     )
     return finetune
 
 
 @pytest.fixture
 def mock_response():
-    response = MagicMock()
+    response = MagicMock(spec=FineTuningJob)
     response.error = None
     response.status = "succeeded"
     response.finished_at = time.time()
diff --git a/libs/core/kiln_ai/datamodel/__init__.py b/libs/core/kiln_ai/datamodel/__init__.py
index 33f47b47..db2a2e20 100644
--- a/libs/core/kiln_ai/datamodel/__init__.py
+++ b/libs/core/kiln_ai/datamodel/__init__.py
@@ -1,16 +1,16 @@
 from __future__ import annotations
 
 import json
+import math
+import random
 from enum import Enum, IntEnum
-from typing import TYPE_CHECKING, Dict, List, Type, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Type, Union
 
 import jsonschema
 import jsonschema.exceptions
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Self
 
-from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetune
-from kiln_ai.adapters.fine_tune.dataset_split import DatasetSplit
 from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
 
 from .basemodel import (
@@ -140,6 +140,33 @@ def validate_output_format(self, task: Task) -> Self:
         return self
 
 
+class Finetune(KilnParentedModel):
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the fine-tune for you and your team. Not used in training.",
+    )
+    provider: str = Field(
+        description="The provider to use for the fine-tune (e.g. 'openai')."
+    )
+    base_model_id: str = Field(
+        description="The id of the base model to use for the fine-tune. This string relates to the provider's IDs for their own models, not Kiln IDs."
+    )
+    provider_id: str | None = Field(
+        default=None,
+        description="The ID of the fine-tuned model on the provider's side.",
+    )
+    parameters: dict[str, str | int | float | bool] = Field(
+        default_factory=dict,
+        description="The parameters to use for this fine-tune. These are provider-specific.",
+    )
+
+    def parent_task(self) -> Task | None:
+        if not isinstance(self.parent, Task):
+            return None
+        return self.parent
+
+
 class DataSourceType(str, Enum):
     """
     The source type of a piece of data.
@@ -328,6 +355,153 @@ def validate_repaired_output(self) -> Self:
         return self
 
 
+# Define the type alias for clarity
+DatasetFilter = Callable[[TaskRun], bool]
+
+
+def AllDatasetFilter(_: TaskRun) -> bool:
+    return True
+
+
+def HighRatingDatasetFilter(task_run: TaskRun) -> bool:
+    if task_run.output is None or task_run.output.rating is None:
+        return False
+    return task_run.output.rating.is_high_quality()
+
+
+class DatasetSplitDefinition(BaseModel):
+    """
+    A definition of a split in a dataset.
+
+    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
+    """
+
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the dataset for you and your team. Not used in training.",
+    )
+    percentage: float = Field(
+        ge=0.0,
+        le=1.0,
+        description="The percentage of the dataset that this split represents (between 0 and 1).",
+    )
+
+
+AllSplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="all", percentage=1.0)
+]
+Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.8),
+    DatasetSplitDefinition(name="test", percentage=0.2),
+]
+Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.6),
+    DatasetSplitDefinition(name="test", percentage=0.2),
+    DatasetSplitDefinition(name="val", percentage=0.2),
+]
+
+
+class DatasetSplit(KilnParentedModel):
+    """
+    A collection of task runs, with optional splits (train, test, validation)
+    """
+
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the dataset for you and your team. Not used in training.",
+    )
+    splits: list[DatasetSplitDefinition] = Field(
+        default_factory=list,
+        description="The splits in the dataset.",
+    )
+    split_contents: dict[str, list[str]] = Field(
+        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
+    )
+
+    @model_validator(mode="after")
+    def validate_split_percentages(self) -> "DatasetSplit":
+        total = sum(split.percentage for split in self.splits)
+        if not math.isclose(total, 1.0, rel_tol=1e-9):
+            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
+        return self
+
+    @classmethod
+    def from_task(
+        cls,
+        name: str,
+        task: "Task",
+        splits: list[DatasetSplitDefinition],
+        filter: DatasetFilter = AllDatasetFilter,
+        description: str | None = None,
+    ):
+        split_contents = cls.build_split_contents(task, splits, filter)
+        return cls(
+            parent=task,
+            name=name,
+            description=description,
+            splits=splits,
+            split_contents=split_contents,
+        )
+
+    @classmethod
+    def build_split_contents(
+        cls,
+        task: "Task",
+        splits: list[DatasetSplitDefinition],
+        filter: DatasetFilter,
+    ) -> dict[str, list[str]]:
+        valid_ids = []
+        for task_run in task.runs():
+            if filter(task_run):
+                valid_ids.append(task_run.id)
+
+        # Shuffle and split by split percentage
+        random.shuffle(valid_ids)
+        split_contents = {}
+        start_idx = 0
+        remaining_items = len(valid_ids)
+
+        # Handle all splits except the last one
+        for split in splits[:-1]:
+            split_size = round(len(valid_ids) * split.percentage)
+            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
+            start_idx += split_size
+            remaining_items -= split_size
+
+        # Last split gets all remaining items (for rounding)
+        if splits:
+            split_contents[splits[-1].name] = valid_ids[start_idx:]
+
+        return split_contents
+
+    def parent_task(self) -> "Task | None":
+        # inline import to avoid circular import
+        from kiln_ai.datamodel import Task
+
+        if not isinstance(self.parent, Task):
+            return None
+        return self.parent
+
+    def missing_count(self) -> int:
+        """
+        Returns:
+            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
+        """
+        parent = self.parent_task()
+        if parent is None:
+            raise ValueError("DatasetSplit has no parent task")
+
+        runs = parent.runs()
+        all_ids = set(run.id for run in runs)
+        all_ids_in_splits = set()
+        for ids in self.split_contents.values():
+            all_ids_in_splits.update(ids)
+        missing = all_ids_in_splits - all_ids
+        return len(missing)
+
+
 class TaskRequirement(BaseModel):
     """
     Defines a specific requirement that should be met by task outputs.
@@ -363,7 +537,7 @@ class Task(
     parent_of={
         "runs": TaskRun,
         "dataset_splits": DatasetSplit,
-        "finetunes": BaseFinetune,
+        "finetunes": Finetune,
     },
 ):
     """
diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_dataset_split.py b/libs/core/kiln_ai/datamodel/test_dataset_split.py
similarity index 97%
rename from libs/core/kiln_ai/adapters/fine_tune/test_dataset_split.py
rename to libs/core/kiln_ai/datamodel/test_dataset_split.py
index 577d3b31..6e777e35 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/test_dataset_split.py
+++ b/libs/core/kiln_ai/datamodel/test_dataset_split.py
@@ -1,28 +1,22 @@
-# ruff: noqa: I001 - Import order matters here. Need datamodel before dataset_split
-
 import pytest
 from pydantic import ValidationError
 
 # import datamodel first or we get circular import errors
 from kiln_ai.datamodel import (
+    AllDatasetFilter,
+    AllSplitDefinition,
+    DatasetSplit,
+    DatasetSplitDefinition,
     DataSource,
     DataSourceType,
+    HighRatingDatasetFilter,
     Task,
     TaskOutput,
     TaskOutputRating,
    TaskOutputRatingType,
     TaskRun,
-)
-
-# import dataset_split last
-from kiln_ai.adapters.fine_tune.dataset_split import (
-    AllDatasetFilter,
-    DatasetSplit,
-    DatasetSplitDefinition,
-    HighRatingDatasetFilter,
-    Train80Test20SplitDefinition,
-    AllSplitDefinition,
     Train60Test20Val20SplitDefinition,
+    Train80Test20SplitDefinition,
 )
 
 
diff --git a/libs/core/kiln_ai/datamodel/test_example_models.py b/libs/core/kiln_ai/datamodel/test_example_models.py
index 2e396fcd..b35a8817 100644
--- a/libs/core/kiln_ai/datamodel/test_example_models.py
+++ b/libs/core/kiln_ai/datamodel/test_example_models.py
@@ -5,10 +5,10 @@
 from pydantic import ValidationError
 
 from kiln_ai.datamodel import (
-    BaseFinetune,
     DatasetSplit,
     DataSource,
     DataSourceType,
+    Finetune,
     Project,
     Task,
     TaskDeterminism,
@@ -105,8 +105,8 @@ def test_dataset_split_relationship():
 
 
 def test_base_finetune_relationship():
-    assert BaseFinetune.relationship_name() == "finetunes"
-    assert BaseFinetune.parent_type().__name__ == "Task"
+    assert Finetune.relationship_name() == "finetunes"
+    assert Finetune.parent_type().__name__ == "Task"
 
 
 def test_structured_output_workflow(tmp_path):
diff --git a/libs/core/kiln_ai/datamodel/test_models.py b/libs/core/kiln_ai/datamodel/test_models.py
index 209f2e13..3c816af7 100644
--- a/libs/core/kiln_ai/datamodel/test_models.py
+++ b/libs/core/kiln_ai/datamodel/test_models.py
@@ -6,6 +6,7 @@
 from kiln_ai.datamodel import (
     DataSource,
     DataSourceType,
+    Finetune,
     Project,
     Task,
     TaskOutput,
@@ -225,3 +226,77 @@ def test_task_run_intermediate_outputs():
         "cot": "chain of thought output",
         "draft": "draft output",
     }
+
+
+def test_finetune_basic():
+    # Test basic initialization
+    finetune = Finetune(
+        name="test-finetune",
+        provider="openai",
+        base_model_id="gpt-3.5-turbo",
+    )
+    assert finetune.name == "test-finetune"
+    assert finetune.provider == "openai"
+    assert finetune.base_model_id == "gpt-3.5-turbo"
+    assert finetune.provider_id is None
+    assert finetune.parameters == {}
+    assert finetune.description is None
+
+
+def test_finetune_full():
+    # Test with all fields populated
+    finetune = Finetune(
+        name="test-finetune",
+        description="Test description",
+        provider="openai",
+        base_model_id="gpt-3.5-turbo",
+        provider_id="ft-abc123",
+        parameters={
+            "epochs": 3,
+            "learning_rate": 0.1,
+            "batch_size": 4,
+            "use_fp16": True,
+            "model_suffix": "-v1",
+        },
+    )
+    assert finetune.description == "Test description"
+    assert finetune.provider_id == "ft-abc123"
+    assert finetune.parameters == {
+        "epochs": 3,
+        "learning_rate": 0.1,
+        "batch_size": 4,
+        "use_fp16": True,
+        "model_suffix": "-v1",
+    }
+
+
+def test_finetune_parent_task():
+    # Test parent_task() method
+    task = Task(name="Test Task", instruction="Test instruction")
+    finetune = Finetune(
+        name="test-finetune",
+        provider="openai",
+        base_model_id="gpt-3.5-turbo",
+        parent=task,
+    )
+
+    assert finetune.parent_task() == task
+
+    # Test with no parent
+    finetune_no_parent = Finetune(
+        name="test-finetune",
+        provider="openai",
+        base_model_id="gpt-3.5-turbo",
+    )
+    assert finetune_no_parent.parent_task() is None
+
+
+def test_finetune_parameters_validation():
+    # Test that parameters only accept valid types
+    with pytest.raises(ValidationError):
+        Finetune(
+            name="test-finetune",
+            provider="openai",
+            base_model_id="gpt-3.5-turbo",
+            parameters={"invalid": [1, 2, 3]},  # Lists are not allowed
+        )
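
Reviewer note: the sketch below is illustrative only, assembled from the classes this diff introduces (Finetune, DatasetSplit, the split definitions and filters in kiln_ai.datamodel, and BaseFinetuneAdapter/OpenAIFinetune in the adapters package). The task setup and parameter values are hypothetical, and OpenAIFinetune.start() is not shown in this change, so it is left commented out; only status() is exercised.

# Hypothetical wiring of the refactored API; not part of this change.
from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
from kiln_ai.datamodel import (
    DatasetSplit,
    Finetune,
    HighRatingDatasetFilter,
    Task,
    Train80Test20SplitDefinition,
)

# Hypothetical task; assumed to already exist with saved runs.
task = Task(name="Example Task", instruction="...")

# Dataset splits now live in the datamodel package rather than adapters.fine_tune.
split = DatasetSplit.from_task(
    name="high_quality_80_20",
    task=task,
    splits=Train80Test20SplitDefinition,
    filter=HighRatingDatasetFilter,
)

# The Finetune record stores configuration; the adapter only wraps it.
finetune_record = Finetune(
    parent=task,
    name="example-finetune",
    provider="openai",
    base_model_id="gpt-4o",
    parameters={"epochs": 3},  # provider-specific, per the field description
)
adapter = OpenAIFinetune(model=finetune_record)
# adapter.start()  # start() is abstract on BaseFinetuneAdapter; the OpenAI implementation is not shown in this diff
print(adapter.status())  # returns FineTuneStatus; pending until a provider_id is assigned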