From a5714c5931ee4f7f2d40762ef6efb2eba18a8f48 Mon Sep 17 00:00:00 2001
From: scosman
Date: Fri, 22 Nov 2024 13:42:45 -0500
Subject: [PATCH] Add missing_count

---
 .../adapters/fine_tune/dataset_split.py      | 15 ++++++++
 .../adapters/fine_tune/test_dataset_split.py | 38 +++++++++++++++++++
 2 files changed, 53 insertions(+)

diff --git a/libs/core/kiln_ai/adapters/fine_tune/dataset_split.py b/libs/core/kiln_ai/adapters/fine_tune/dataset_split.py
index f2b1f7c..387adee 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/dataset_split.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/dataset_split.py
@@ -131,3 +131,18 @@ def build_split_contents(
         split_contents[splits[-1].name] = valid_ids[start_idx:]
 
         return split_contents
+
+    def missing_count(self) -> int:
+        """
+        Returns:
+            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
+        """
+        if TYPE_CHECKING and not isinstance(self.parent, Task):
+            raise ValueError("DatasetSplit has no parent task")
+        runs = self.parent.runs()
+        all_ids = set(run.id for run in runs)
+        all_ids_in_splits = set()
+        for ids in self.split_contents.values():
+            all_ids_in_splits.update(ids)
+        missing = all_ids_in_splits - all_ids
+        return len(missing)
diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_dataset_split.py b/libs/core/kiln_ai/adapters/fine_tune/test_dataset_split.py
index f6d48d5..577d3b3 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/test_dataset_split.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/test_dataset_split.py
@@ -200,3 +200,41 @@ def test_dataset_split_with_single_split(sample_task, sample_task_runs):
     dataset = DatasetSplit.from_task("Split Name", sample_task, splits)
 
     assert len(dataset.split_contents["all"]) == len(sample_task_runs)
+
+
+def test_missing_count(sample_task, sample_task_runs):
+    assert sample_task_runs is not None
+    # Create a dataset split with all task runs
+    dataset = DatasetSplit.from_task(
+        "Split Name", sample_task, Train80Test20SplitDefinition
+    )
+
+    # Initially there should be no missing runs
+    assert dataset.missing_count() == 0
+
+    # Add some IDs to the split, that aren't on disk
+    dataset.split_contents["test"].append("1")
+    dataset.split_contents["test"].append("2")
+    dataset.split_contents["test"].append("3")
+    # shouldn't happen, but should not double count if it does
+    dataset.split_contents["train"].append("3")
+
+    # Now we should have 3 missing runs
+    assert dataset.missing_count() == 3
+
+
+def test_smaller_sample(sample_task, sample_task_runs):
+    assert sample_task_runs is not None
+    # Create a dataset split with all task runs
+    dataset = DatasetSplit.from_task(
+        "Split Name", sample_task, Train80Test20SplitDefinition
+    )
+
+    # Initially there should be no missing runs
+    assert dataset.missing_count() == 0
+
+    dataset.split_contents["test"].pop()
+    dataset.split_contents["train"].pop()
+
+    # Now we should have 0 missing runs. It's okay that dataset has newer data.
+    assert dataset.missing_count() == 0
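
For reviewers who want to sanity-check the counting behavior outside of Kiln, the sketch below reproduces the same set-difference logic with plain dictionaries and a stand-in collection of run IDs. The names missing_count_sketch, split_contents, and existing_run_ids are illustrative only; nothing here depends on DatasetSplit, Task, or the pytest fixtures, and it is not part of the change itself.

    # Standalone sketch of the set-difference logic behind missing_count().
    # split_contents maps split names to lists of persisted run IDs;
    # existing_run_ids stands in for the IDs of runs still present in the dataset.
    def missing_count_sketch(
        split_contents: dict[str, list[str]], existing_run_ids: set[str]
    ) -> int:
        # Collect every ID referenced by any split. Using a set collapses
        # duplicates, so an ID shared by "train" and "test" is counted once.
        ids_in_splits: set[str] = set()
        for ids in split_contents.values():
            ids_in_splits.update(ids)
        # IDs referenced by a split but no longer present are "missing".
        return len(ids_in_splits - existing_run_ids)

    if __name__ == "__main__":
        splits = {"train": ["a", "b", "3"], "test": ["c", "1", "2", "3"]}
        existing = {"a", "b", "c"}
        # "1", "2" and "3" are referenced but gone; "3" appears in two splits yet counts once.
        assert missing_count_sketch(splits, existing) == 3
        # Newer runs on disk that no split references are fine: nothing is missing.
        assert missing_count_sketch({"train": ["a"]}, {"a", "b"}) == 0
        print("sketch checks passed")

This mirrors the two cases exercised by test_missing_count (no double counting of a duplicated ID) and test_smaller_sample (extra data in the dataset does not count as missing).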