
Chore: First class predictions #187

Merged · 35 commits · Nov 19, 2024

Commits
6d1ab81
add base benchmark predictions class, move tests
kirahowe Aug 31, 2024
195b938
wip, validations working
kirahowe Aug 31, 2024
7ae809e
trying to make types work with arbitrary incoming values
kirahowe Aug 31, 2024
7b4687c
fix equality checking issue in test. closes #169.
kirahowe Aug 31, 2024
739aea7
add serializer to predictions
kirahowe Aug 31, 2024
7beaf05
remove separate competition predictions
kirahowe Aug 31, 2024
64fea34
update evaluation usage to work with benchmark predictions instance
kirahowe Aug 31, 2024
ca37ecc
update test set generation for evaluation
kirahowe Sep 1, 2024
eed1c53
run ruff autoformatting
kirahowe Sep 1, 2024
dbbdb54
Update polaris/utils/types.py
kiramclean Sep 4, 2024
40f9eb1
Update polaris/utils/types.py
kiramclean Sep 4, 2024
5eb5d45
wip
kirahowe Sep 5, 2024
07516f3
Merge branch 'main' into first-class-predictions
kirahowe Sep 5, 2024
5ec8959
add small docstring, allow string predictions
kirahowe Sep 9, 2024
41e9a80
safely get predictions in evaluation if available
kirahowe Sep 10, 2024
4084cf3
Merge branch 'main' into first-class-predictions
kirahowe Sep 10, 2024
907b77d
pass test set names to predictions and check for validity
kirahowe Sep 10, 2024
dd78c2d
simplify safe_mask
kirahowe Sep 10, 2024
14311a1
fix bad docstring path
kirahowe Sep 10, 2024
45189c3
Merge branch 'main' into first-class-predictions
cwognum Sep 11, 2024
e968a9b
Reintroduce the CompetitionPredictions class because it includes addi…
cwognum Sep 11, 2024
67c8f71
Add back the CompetitonPredictions to the docs
cwognum Sep 11, 2024
5170818
Reordered docs
cwognum Sep 11, 2024
cc2223c
Improved documentation and changed logic to disallow some edge cases
cwognum Sep 12, 2024
8a545d7
Fixed docs
cwognum Sep 12, 2024
175f779
Remove print statement
cwognum Sep 12, 2024
05af664
Reorganize code
cwognum Sep 12, 2024
74f4612
Simplified evaluation logic
cwognum Sep 12, 2024
76ec79b
Address all PR feedback
cwognum Sep 19, 2024
87cbfd3
Merge branch 'main' into first-class-predictions
cwognum Sep 19, 2024
ddf39a3
Add extra test case
cwognum Sep 19, 2024
b8a5bc4
Merge branch 'main' into first-class-predictions
cwognum Nov 19, 2024
5d90b5e
Addressed PR feedback
cwognum Nov 19, 2024
f926b3d
Fix type hint and fix model validator definition
cwognum Nov 19, 2024
d732a08
Fix import
cwognum Nov 19, 2024
6 changes: 5 additions & 1 deletion docs/api/evaluation.md
@@ -1,3 +1,7 @@
::: polaris.evaluate.BenchmarkPredictions

---

::: polaris.evaluate.ResultsMetadata
options:
filters: ["!^_"]
@@ -25,4 +29,4 @@
::: polaris.evaluate.metrics.generic_metrics
::: polaris.evaluate.metrics.docking_metrics

---
---
3 changes: 3 additions & 0 deletions docs/tutorials/competition.participate.ipynb
@@ -190,6 +190,9 @@
"competition_predictions = CompetitionPredictions(\n",
" name=\"hello-world-result\",\n",
" predictions=predictions,\n",
" target_labels=competition.target_cols,\n",
" test_set_labels=competition.test_set_labels,\n",
" test_set_sizes=competition.test_set_sizes,\n",
" github_url=\"https://github.com/polaris-hub/polaris-hub\",\n",
" paper_url=\"https://polarishub.io/\",\n",
" description=\"Hello, World!\",\n",
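The three new keyword arguments are the substance of this notebook change: the predictions object now carries the competition's own target and test-set metadata rather than inferring it later. A minimal sketch of the updated call, assuming `competition` and `predictions` are defined in earlier notebook cells and that the attribute names match the diff above:

```python
from polaris.evaluate import CompetitionPredictions

# `competition` and `predictions` are assumed to come from earlier cells in the tutorial notebook.
competition_predictions = CompetitionPredictions(
    name="hello-world-result",
    predictions=predictions,
    target_labels=competition.target_cols,        # which target column(s) the values belong to
    test_set_labels=competition.test_set_labels,  # names of the test partitions, e.g. ["test"]
    test_set_sizes=competition.test_set_sizes,    # number of rows per test partition
    description="Hello, World!",
)
```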
14 changes: 8 additions & 6 deletions mkdocs.yml
@@ -24,19 +24,21 @@ nav:
- PDB Datasets: tutorials/dataset_pdb.ipynb
- SDF Datasets: tutorials/dataset_sdf.ipynb
- Optimization: tutorials/optimization.ipynb
- Competitions:
- tutorials/competition.participate.ipynb
# NOTE (cwognum): Competitions are currently gated.
# - Competitions:
# - tutorials/competition.participate.ipynb
- API Reference:
- Load: api/load.md
- Core:
- Dataset: api/dataset.md
- Benchmark: api/benchmark.md
- Subset: api/subset.md
- Evaluation: api/evaluation.md
- Competitions:
- Competition Dataset: api/competition.dataset.md
- Competition: api/competition.md
- Competiton Evaluation: api/competition.evaluation.md
# NOTE (cwognum): Competitions are currently gated.
# - Competitions:
# - Competition Dataset: api/competition.dataset.md
# - Competition: api/competition.md
# - Competiton Evaluation: api/competition.evaluation.md
- Hub:
- Client: api/hub.client.md
- External Auth Client: api/hub.external_client.md
6 changes: 5 additions & 1 deletion polaris/benchmark/__init__.py
@@ -4,4 +4,8 @@
SingleTaskBenchmarkSpecification,
)

__all__ = ["BenchmarkSpecification", "SingleTaskBenchmarkSpecification", "MultiTaskBenchmarkSpecification"]
__all__ = [
"BenchmarkSpecification",
"SingleTaskBenchmarkSpecification",
"MultiTaskBenchmarkSpecification",
]
120 changes: 55 additions & 65 deletions polaris/benchmark/_base.py
@@ -30,7 +30,7 @@
from polaris.utils.types import (
AccessType,
HubOwner,
PredictionsType,
IncomingPredictionsType,
SplitType,
TargetType,
TaskType,
@@ -131,7 +131,9 @@ def _validate_cols(cls, v, info: ValidationInfo):
if info.data.get("dataset") is not None and not all(
c in info.data["dataset"].table.columns for c in v
):
raise InvalidBenchmarkError("Not all specified target columns were found in the dataset.")
raise InvalidBenchmarkError("Not all specified columns were found in the dataset.")
if len(set(v)) != len(v):
raise InvalidBenchmarkError("The task specifies duplicate columns")
return v

@field_validator("metrics")
@@ -173,19 +175,18 @@ def _validate_split(self) -> Self:
4) There is no overlap between the train and test set
5) No row exists in the test set where all labels are missing/empty
"""

if not isinstance(self.split[1], dict):
self.split = self.split[0], {"test": self.split[1]}
split = self.split

# Train partition can be empty (zero-shot)
# Test partitions cannot be empty
if (isinstance(split[1], dict) and any(len(v) == 0 for v in split[1].values())) or (
not isinstance(split[1], dict) and len(split[1]) == 0
):
if any(len(v) == 0 for v in split[1].values()):
raise InvalidBenchmarkError("The predefined split contains empty test partitions")

train_idx_list = split[0]
full_test_idx_list = (
list(chain.from_iterable(split[1].values())) if isinstance(split[1], dict) else split[1]
)
full_test_idx_list = list(chain.from_iterable(split[1].values()))

if len(train_idx_list) == 0:
logger.info(
@@ -206,14 +207,11 @@ def _validate_split(self) -> Self:
# Check for duplicate indices within a given test set. Because a user can specify
# multiple test sets for a given benchmark and it is acceptable for indices to be shared
# across test sets, we check for duplicates in each test set independently.
if isinstance(split[1], dict):
for test_set_name, test_set_idx_list in split[1].items():
if len(test_set_idx_list) != len(set(test_set_idx_list)):
raise InvalidBenchmarkError(
f'Test set with name "{test_set_name}" contains duplicate indices'
)
elif len(full_test_idx_set) != len(full_test_idx_list):
raise InvalidBenchmarkError("The test set contains duplicate indices")
for test_set_name, test_set_idx_list in split[1].items():
if len(test_set_idx_list) != len(set(test_set_idx_list)):
raise InvalidBenchmarkError(
f'Test set with name "{test_set_name}" contains duplicate indices'
)

# All indices are valid given the dataset
dataset = self.dataset
@@ -307,18 +305,13 @@ def _compute_checksum(self):
for m in sorted(self.metrics, key=lambda k: k.name):
hash_fn.update(m.name.encode("utf-8"))

if not isinstance(self.split[1], dict):
split = self.split[0], {"test": self.split[1]}
else:
split = self.split

# Train set
s = json.dumps(sorted(split[0]))
s = json.dumps(sorted(self.split[0]))
hash_fn.update(s.encode("utf-8"))

# Test sets
for k in sorted(split[1].keys()):
s = json.dumps(sorted(split[1][k]))
for k in sorted(self.split[1].keys()):
s = json.dumps(sorted(self.split[1][k]))
hash_fn.update(k.encode("utf-8"))
hash_fn.update(s.encode("utf-8"))

@@ -335,7 +328,7 @@ def n_train_datapoints(self) -> int:
@property
def n_test_sets(self) -> int:
"""The number of test sets"""
return len(self.split[1]) if isinstance(self.split[1], dict) else 1
return len(self.split[1])

@computed_field
@property
@@ -370,6 +363,18 @@ def task_type(self) -> str:
v = TaskType.MULTI_TASK if len(self.target_cols) > 1 else TaskType.SINGLE_TASK
return v.value

@computed_field
@property
def test_set_labels(self) -> list[str]:
"""The labels of the test sets."""
return sorted(list(self.split[1].keys()))

@computed_field
@property
def test_set_sizes(self) -> dict[str, int]:
"""The sizes of the test sets."""
return {k: len(v) for k, v in self.split[1].items()}

def _get_subset(self, indices, hide_targets=True, featurization_fn=None):
"""Returns a [`Subset`][polaris.dataset.Subset] using the given indices. Used
internally to construct the train and test sets."""
@@ -393,10 +398,7 @@ def make_test_subset(vals):
return self._get_subset(vals, hide_targets=hide_targets, featurization_fn=featurization_fn)

test_split = self.split[1]
if isinstance(test_split, dict):
test = {k: make_test_subset(v) for k, v in test_split.items()}
else:
test = make_test_subset(test_split)
test = {k: make_test_subset(v) for k, v in test_split.items()}

return test

@@ -422,27 +424,23 @@ def get_train_test_split(
train = self._get_subset(self.split[0], hide_targets=False, featurization_fn=featurization_fn)
test = self._get_test_set(hide_targets=True, featurization_fn=featurization_fn)

# For improved UX, we return the object instead of the dictionary if there is only one test set.
# Internally, however, assuming that the test set is always a dictionary simplifies the code.
if len(test) == 1:
test = test["test"]
return train, test

def evaluate(
self, y_pred: Optional[PredictionsType] = None, y_prob: Optional[PredictionsType] = None
self,
y_pred: IncomingPredictionsType | None = None,
y_prob: IncomingPredictionsType | None = None,
) -> BenchmarkResults:
"""Execute the evaluation protocol for the benchmark, given a set of predictions.

info: What about `y_true`?
Contrary to other frameworks that you might be familiar with, we opted for a signature that includes just
the predictions. This reduces the chance of accidentally using the test targets during training.

info: Expected structure for `y_pred` and `y_prob` arguments
The supplied `y_pred` and `y_prob` arguments must adhere to a certain structure depending on the number of
tasks and test sets included in the benchmark. Refer to the following for guidance on the correct structure when
creating your `y_pred` and `y_prob` objects:

- Single task, single set: `[values...]`
- Multi-task, single set: `{task_name_1: [values...], task_name_2: [values...]}`
- Single task, multi-set: `{test_set_1: {task_name: [values...]}, test_set_2: {task_name: [values...]}}`
- Multi-task, multi-set: `{test_set_1: {task_name_1: [values...], task_name_2: [values...]}, test_set_2: {task_name_1: [values...], task_name_2: [values...]}}`

For this method, we make the following assumptions:

1. There can be one or multiple test set(s);
@@ -456,7 +454,8 @@ def evaluate(
If there are multiple targets, the predictions should be wrapped in a dictionary with the target labels as keys.
If there are multiple test sets, the predictions should be further wrapped in a dictionary
with the test subset labels as keys.
y_prob: The predicted probabilities for the test set, as NumPy arrays.
y_prob: The predicted probabilities for the test set, formatted similarly to predictions, based on the
number of tasks and test sets.

Returns:
A `BenchmarkResults` object. This object can be directly submitted to the Polaris Hub.
@@ -475,31 +474,22 @@
"""

# Instead of having the user pass the ground truth, we extract it from the benchmark spec ourselves.
# The `evaluate_benchmark` function expects the benchmark labels to be of a certain structure which
# depends on the number of tasks and test sets defined for the benchmark. Below, we build the structure
# of the benchmark labels based on the aforementioned factors.
test = self._get_test_set(hide_targets=False)
if isinstance(test, dict):
#
# For multi-set benchmarks
y_true = {}
for test_set_name, values in test.items():
y_true[test_set_name] = {}
if isinstance(values.targets, dict):
#
# For multi-task, multi-set benchmarks
for task_name, values in values.targets.items():
y_true[test_set_name][task_name] = values
else:
#
# For single task, multi-set benchmarks
y_true[test_set_name][self.target_cols[0]] = values.targets
else:
#
# For single set benchmarks (single and multiple task)
y_true = test.targets
y_true_subset = self._get_test_set(hide_targets=False)
y_true_values = {k: v.targets for k, v in y_true_subset.items()}

# Simplify the case where there is only one test set
if len(y_true_values) == 1:
y_true_values = y_true_values["test"]

scores = evaluate_benchmark(self.target_cols, self.metrics, y_true, y_pred=y_pred, y_prob=y_prob)
scores = evaluate_benchmark(
target_cols=self.target_cols,
test_set_labels=self.test_set_labels,
test_set_sizes=self.test_set_sizes,
metrics=self.metrics,
y_true=y_true_values,
y_pred=y_pred,
y_prob=y_prob,
)

return BenchmarkResults(results=scores, benchmark_name=self.name, benchmark_owner=self.owner)

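To make the refactored `evaluate` signature concrete, here is a small sketch of the prediction structures it accepts and of the new `test_set_labels` / `test_set_sizes` properties. The benchmark name, column names, and values are illustrative placeholders, not taken from a real benchmark; the call pattern follows the docstring in the diff above:

```python
import numpy as np
import polaris as po

# Assumed: a benchmark loaded from the Hub (the name below is a placeholder).
benchmark = po.load_benchmark("my-org/my-single-task-benchmark")

# The benchmark now reports its test-set layout directly:
#   benchmark.test_set_labels -> e.g. ["test"] or ["test", "test_scaffold"]
#   benchmark.test_set_sizes  -> e.g. {"test": 200, "test_scaffold": 150}

# Single task, single test set: a flat array-like is enough.
y_pred = np.array([0.1, 0.2, 0.3])

# Multi-task, single test set: one array per target column (column names are illustrative).
y_pred_multitask = {
    "LOG_SOLUBILITY": np.array([0.1, 0.2, 0.3]),
    "LOG_HLM_CLint": np.array([1.4, 0.9, 1.1]),
}

# Multiple test sets: wrap the above in a dict keyed by test-set label.
y_pred_multiset = {
    "test": {"LOG_SOLUBILITY": np.array([0.1, 0.2, 0.3])},
    "test_scaffold": {"LOG_SOLUBILITY": np.array([0.4, 0.5, 0.6])},
}

# Ground truth is never passed in; the benchmark extracts it from its own split.
# `y_prob` follows the same nesting rules when a metric needs probabilities.
results = benchmark.evaluate(y_pred=y_pred)
# `results` is a BenchmarkResults object that can be submitted to the Polaris Hub.
```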
8 changes: 5 additions & 3 deletions polaris/evaluate/__init__.py
@@ -1,11 +1,12 @@
from polaris.evaluate._metric import Metric, MetricInfo
from polaris.evaluate._predictions import BenchmarkPredictions
from polaris.evaluate._results import (
BenchmarkResults,
ResultsType,
CompetitionResults,
CompetitionPredictions,
ResultsMetadata,
CompetitionResults,
EvaluationResult,
ResultsMetadata,
ResultsType,
)
from polaris.evaluate.utils import evaluate_benchmark

@@ -19,4 +20,5 @@
"ResultsType",
"evaluate_benchmark",
"CompetitionPredictions",
"BenchmarkPredictions",
]
14 changes: 6 additions & 8 deletions polaris/evaluate/_metric.py
@@ -1,30 +1,28 @@
from enum import Enum
from typing import Callable, Literal, Optional
from typing import Callable, Literal

import numpy as np
from pydantic import BaseModel, Field

from sklearn.metrics import (
accuracy_score,
average_precision_score,
balanced_accuracy_score,
explained_variance_score,
f1_score,
matthews_corrcoef,
mean_absolute_error,
mean_squared_error,
r2_score,
roc_auc_score,
balanced_accuracy_score,
)

from polaris.evaluate.metrics import (
cohen_kappa_score,
absolute_average_fold_error,
spearman,
cohen_kappa_score,
pearsonr,
spearman,
)
from polaris.evaluate.metrics.docking_metrics import rmsd_coverage

from polaris.utils.types import DirectionType


@@ -107,7 +105,7 @@ def y_type(self) -> bool:
return self.value.y_type

def score(
self, y_true: np.ndarray, y_pred: Optional[np.ndarray] = None, y_prob: Optional[np.ndarray] = None
self, y_true: np.ndarray, y_pred: np.ndarray | None = None, y_prob: np.ndarray | None = None
) -> float:
"""Endpoint for computing the metric.

@@ -134,7 +132,7 @@ def score(
return self.fn(**kwargs, **self.value.kwargs)

def __call__(
self, y_true: np.ndarray, y_pred: Optional[np.ndarray] = None, y_prob: Optional[np.ndarray] = None
self, y_true: np.ndarray, y_pred: np.ndarray | None = None, y_prob: np.ndarray | None = None
) -> float:
"""For convenience, make metrics callable"""
return self.score(y_true, y_pred, y_prob)
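Since `score` and `__call__` now default `y_pred` and `y_prob` to `None`, a metric is invoked with whichever of the two it needs. A small usage sketch, assuming the enum exposes a member for the wrapped sklearn `mean_absolute_error` (the member name is an assumption based on the imports above):

```python
import numpy as np
from polaris.evaluate import Metric

y_true = np.array([0.0, 1.0, 2.0])
y_pred = np.array([0.1, 0.9, 2.2])

# Regression-style metric: only y_pred is supplied, y_prob stays None.
mae = Metric.mean_absolute_error(y_true=y_true, y_pred=y_pred)

# Probability-based metrics (e.g. ROC AUC) would instead be called with y_prob=...
print(mae)
```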