Commit 151fe79

Refactor fine tune. Separate model from adapter. Much cleaner. Remove circular ref hacks, as they are gone.

scosman committed Nov 23, 2024
1 parent 0a957f4 commit 151fe79
Showing 11 changed files with 315 additions and 324 deletions.
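
For context on the new shape (not part of the diff): the fine-tune configuration now lives in a kiln_ai.datamodel.Finetune object, and adapters wrap that object instead of subclassing a datamodel base class. The sketch below is a minimal, self-contained illustration of that pattern using stand-in classes; the Finetune field names are assumptions mirroring the fields removed from BaseFinetune further down, and LoggingFinetuneAdapter is a hypothetical provider adapter, not a class from this commit.

# Minimal sketch of the model/adapter split, using stand-in classes rather than
# the real kiln_ai imports. Field names are assumptions based on the fields
# removed from BaseFinetune in this diff.
from abc import ABC, abstractmethod

from pydantic import BaseModel, Field


class Finetune(BaseModel):  # stand-in for the kiln_ai.datamodel.Finetune datamodel
    name: str
    provider: str = Field(description="The provider to use, e.g. 'openai'.")
    base_model_id: str = Field(description="The provider's own ID for the base model.")
    provider_id: str | None = None  # set once the provider has created the job
    parameters: dict[str, str | int | float | bool] = Field(default_factory=dict)


class BaseFinetuneAdapter(ABC):  # mirrors the new adapter base class in this commit
    def __init__(self, model: Finetune):
        # The adapter holds a reference to the datamodel instead of inheriting
        # from KilnParentedModel, which is what made the circular-import guards unnecessary.
        self.model = model

    @abstractmethod
    def start(self) -> None: ...


class LoggingFinetuneAdapter(BaseFinetuneAdapter):  # hypothetical provider adapter
    def start(self) -> None:
        print(f"Would start a {self.model.provider} fine-tune of {self.model.base_model_id}")


if __name__ == "__main__":
    tune = Finetune(name="demo", provider="openai", base_model_id="gpt-4o-mini-2024-07-18")
    LoggingFinetuneAdapter(tune).start()
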
42 changes: 19 additions & 23 deletions libs/core/kiln_ai/adapters/__init__.py
@@ -12,26 +12,22 @@
 The repair submodule contains an adapter for the repair task.
 """
 
-import sys
-
-# Avoid circular import since we use datamodel in some tests
-if "pytest" not in sys.modules:
-    from . import (
-        base_adapter,
-        data_gen,
-        fine_tune,
-        langchain_adapters,
-        ml_model_list,
-        prompt_builders,
-        repair,
-    )
-
-    __all__ = [
-        "base_adapter",
-        "langchain_adapters",
-        "ml_model_list",
-        "prompt_builders",
-        "repair",
-        "data_gen",
-        "fine_tune",
-    ]
+from . import (
+    base_adapter,
+    data_gen,
+    fine_tune,
+    langchain_adapters,
+    ml_model_list,
+    prompt_builders,
+    repair,
+)
+
+__all__ = [
+    "base_adapter",
+    "langchain_adapters",
+    "ml_model_list",
+    "prompt_builders",
+    "repair",
+    "data_gen",
+    "fine_tune",
+]
8 changes: 2 additions & 6 deletions libs/core/kiln_ai/adapters/fine_tune/__init__.py
@@ -4,10 +4,6 @@
 A set of classes for fine-tuning models.
 """
 
-import sys
+# from . import base_finetune, openai_finetune
 
-# Avoid circular import since we use datamodel in some tests
-if "pytest" not in sys.modules:
-    from . import dataset_split
-
-__all__ = ["dataset_split"]
+# __all__ = ["base_finetune", "openai_finetune"]
44 changes: 7 additions & 37 deletions libs/core/kiln_ai/adapters/fine_tune/base_finetune.py
@@ -1,14 +1,10 @@
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Literal
+from typing import Literal
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
-from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel
-
-if TYPE_CHECKING:
-    from kiln_ai.datamodel import Task
+from kiln_ai.datamodel import Finetune as FinetuneModel
 
 
 class FineTuneStatusType(str, Enum):
@@ -28,42 +24,16 @@ class FineTuneStatus(BaseModel):
     message: str | None = None
 
 
-@dataclass
-class FineTuneParameter:
+class FineTuneParameter(BaseModel):
     name: str
    type: Literal["string", "int", "float", "bool"]
     description: str
     optional: bool = True
 
 
-class BaseFinetune(KilnParentedModel, ABC):
-    name: str = NAME_FIELD
-    description: str | None = Field(
-        default=None,
-        description="A description of the fine-tune for you and your team. Not used in training.",
-    )
-    provider: str = Field(
-        description="The provider to use for the fine-tune (e.g. 'openai')."
-    )
-    base_model_id: str = Field(
-        description="The id of the base model to use for the fine-tune. This string relates to the provider's IDs for their own models, not Kiln IDs."
-    )
-    provider_id: str | None = Field(
-        default=None,
-        description="The ID of the fine-tuned model on the provider's side.",
-    )
-    parameters: dict[str, str | int | float | bool] = Field(
-        default_factory=dict,
-        description="The parameters to use for this fine-tune. These are provider-specific.",
-    )
-
-    def parent_task(self) -> "Task | None":
-        # inline import to avoid circular import
-        from kiln_ai.datamodel import Task
-
-        if not isinstance(self.parent, Task):
-            return None
-        return self.parent
+class BaseFinetuneAdapter(ABC):
+    def __init__(self, model: FinetuneModel):
+        self.model = model
 
     @abstractmethod
     def start(self) -> None:
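
One side effect of FineTuneParameter switching from a dataclass to a pydantic BaseModel is that instances are now validated on construction. A quick illustration using the field definitions from the hunk above (the "epochs" parameter itself is just a made-up example, not part of the commit):

from typing import Literal

from pydantic import BaseModel, ValidationError


class FineTuneParameter(BaseModel):  # same fields as in the diff above
    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


# Valid: "int" is one of the allowed type literals.
FineTuneParameter(name="epochs", type="int", description="Number of training epochs")

try:
    # Invalid: "integer" is not in the Literal, so pydantic raises at construction,
    # where the old dataclass would have accepted it silently.
    FineTuneParameter(name="epochs", type="integer", description="Bad type value")
except ValidationError as err:
    print(err)
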
158 changes: 0 additions & 158 deletions libs/core/kiln_ai/adapters/fine_tune/dataset_split.py

This file was deleted.

14 changes: 5 additions & 9 deletions libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py
@@ -4,7 +4,7 @@
 from openai.types.fine_tuning import FineTuningJob
 
 from kiln_ai.adapters.fine_tune.base_finetune import (
-    BaseFinetune,
+    BaseFinetuneAdapter,
     FineTuneParameter,
     FineTuneStatus,
     FineTuneStatusType,
@@ -16,17 +16,17 @@
 )
 
 
-class OpenAIFinetune(BaseFinetune):
+class OpenAIFinetune(BaseFinetuneAdapter):
     def status(self) -> FineTuneStatus:
-        if not self.provider_id:
+        if not self.model or not self.model.provider_id:
             return FineTuneStatus(
                 status=FineTuneStatusType.pending,
                 message="This fine-tune has not been started or has not been assigned a provider ID.",
             )
 
         try:
             # Will raise an error if the job is not found, or for other issues
-            response = oai_client.fine_tuning.jobs.retrieve(self.provider_id)
+            response = oai_client.fine_tuning.jobs.retrieve(self.model.provider_id)
         except openai.APIConnectionError:
             return FineTuneStatus(
                 status=FineTuneStatusType.unknown, message="Server connection error"
@@ -67,11 +67,7 @@ def status(self) -> FineTuneStatus:
             return FineTuneStatus(
                 status=FineTuneStatusType.failed, message="Job cancelled"
             )
-        if (
-            status in ["validating_files", "running", "queued"]
-            or response.finished_at is None
-            or response.estimated_finish is not None
-        ):
+        if status in ["validating_files", "running", "queued"]:
             time_to_finish_msg: str | None = None
             if response.estimated_finish is not None:
                 time_to_finish_msg = f"Estimated finish time: {int(response.estimated_finish - time.time())} seconds."
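
The last hunk above narrows the "still in progress" check in OpenAIFinetune.status(): previously any job with finished_at unset or estimated_finish set fell into that branch regardless of its reported status; now only the three in-flight statuses do. Below is a small sketch of the two predicates as standalone functions for illustration, not code from the commit; whether a finished job could actually reach this branch depends on the earlier checks in status(), which are collapsed in this view.

def in_progress_old(status: str, finished_at: float | None, estimated_finish: float | None) -> bool:
    # The removed condition: also true whenever finished_at is missing or an
    # estimated_finish is present, whatever the status string says.
    return (
        status in ["validating_files", "running", "queued"]
        or finished_at is None
        or estimated_finish is not None
    )


def in_progress_new(status: str) -> bool:
    # The replacement condition: only the provider-reported status matters.
    return status in ["validating_files", "running", "queued"]


# A job reporting "succeeded" but with no finished_at timestamp:
print(in_progress_old("succeeded", None, None))  # True
print(in_progress_new("succeeded"))              # False
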