Skip to content

Commit

Permalink
require finetunes to specify the training split name, and optionally …
Browse files Browse the repository at this point in the history
…the test split name
  • Loading branch information
scosman committed Nov 23, 2024
1 parent eec9b17 commit 8727c36
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 6 deletions.
11 changes: 9 additions & 2 deletions libs/core/kiln_ai/adapters/fine_tune/base_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,15 @@ class BaseFinetuneAdapter(ABC):
A base class for fine-tuning adapters.
"""

def __init__(self, model: FinetuneModel):
self.model = model
def __init__(
self,
datamodel: FinetuneModel,
train_split_name: str,
test_split_name: str | None = None,
):
self.datamodel = datamodel
self.train_split_name = train_split_name
self.test_split_name = test_split_name

@abstractmethod
def start(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ class OpenAIFinetune(BaseFinetuneAdapter):
"""

def status(self) -> FineTuneStatus:
if not self.model or not self.model.provider_id:
if not self.datamodel or not self.datamodel.provider_id:
return FineTuneStatus(
status=FineTuneStatusType.pending,
message="This fine-tune has not been started or has not been assigned a provider ID.",
)

try:
# Will raise an error if the job is not found, or for other issues
response = oai_client.fine_tuning.jobs.retrieve(self.model.provider_id)
response = oai_client.fine_tuning.jobs.retrieve(self.datamodel.provider_id)
except openai.APIConnectionError:
return FineTuneStatus(
status=FineTuneStatusType.unknown, message="Server connection error"
Expand Down
3 changes: 2 additions & 1 deletion libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,14 @@ def sample_task(tmp_path):
@pytest.fixture
def basic_finetune(sample_task):
return MockFinetune(
model=FinetuneModel(
datamodel=FinetuneModel(
parent=sample_task,
name="test_finetune",
provider="test_provider",
provider_id="model_1234",
base_model_id="test_model",
),
train_split_name="train",
)


Expand Down
3 changes: 2 additions & 1 deletion libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
@pytest.fixture
def openai_finetune():
finetune = OpenAIFinetune(
model=FinetuneModel(
datamodel=FinetuneModel(
name="test-finetune",
provider="openai",
provider_id="openai-123",
base_model_id="gpt-4o",
),
train_split_name="train",
)
return finetune

Expand Down

0 comments on commit 8727c36

Please sign in to comment.