Skip to content

Commit

Permalink
Add missing_count
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Nov 22, 2024
1 parent b4fc215 commit a5714c5
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
15 changes: 15 additions & 0 deletions libs/core/kiln_ai/adapters/fine_tune/dataset_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,18 @@ def build_split_contents(
split_contents[splits[-1].name] = valid_ids[start_idx:]

return split_contents

def missing_count(self) -> int:
"""
Returns:
int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
"""
if TYPE_CHECKING and not isinstance(self.parent, Task):
raise ValueError("DatasetSplit has no parent task")
runs = self.parent.runs()
all_ids = set(run.id for run in runs)
all_ids_in_splits = set()
for ids in self.split_contents.values():
all_ids_in_splits.update(ids)
missing = all_ids_in_splits - all_ids
return len(missing)
38 changes: 38 additions & 0 deletions libs/core/kiln_ai/adapters/fine_tune/test_dataset_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,41 @@ def test_dataset_split_with_single_split(sample_task, sample_task_runs):
dataset = DatasetSplit.from_task("Split Name", sample_task, splits)

assert len(dataset.split_contents["all"]) == len(sample_task_runs)


def test_missing_count(sample_task, sample_task_runs):
assert sample_task_runs is not None
# Create a dataset split with all task runs
dataset = DatasetSplit.from_task(
"Split Name", sample_task, Train80Test20SplitDefinition
)

# Initially there should be no missing runs
assert dataset.missing_count() == 0

# Add some IDs to the split, that aren't on disk
dataset.split_contents["test"].append("1")
dataset.split_contents["test"].append("2")
dataset.split_contents["test"].append("3")
# shouldn't happen, but should not double count if it does
dataset.split_contents["train"].append("3")

# Now we should have 3 missing runs
assert dataset.missing_count() == 3


def test_smaller_sample(sample_task, sample_task_runs):
assert sample_task_runs is not None
# Create a dataset split with all task runs
dataset = DatasetSplit.from_task(
"Split Name", sample_task, Train80Test20SplitDefinition
)

# Initially there should be no missing runs
assert dataset.missing_count() == 0

dataset.split_contents["test"].pop()
dataset.split_contents["train"].pop()

# Now we should have 0 missing runs. It's okay that dataset has newer data.
assert dataset.missing_count() == 0

0 comments on commit a5714c5

Please sign in to comment.