Chore: First class predictions (#187)
* add base benchmark predictions class, move tests

* wip, validations working

* trying to make types work with arbitrary incoming values

* fix equality checking issue in test. closes #169.

* add serializer to predictions

* remove separate competition predictions

* update evaluation usage to work with benchmark predictions instance

* update test set generation for evaluation

* run ruff autoformatting

* Update polaris/utils/types.py

Nicer union syntax

Co-authored-by: Cas Wognum <caswognum@outlook.com>

* Update polaris/utils/types.py

Co-authored-by: Cas Wognum <caswognum@outlook.com>

* wip

* add small docstring, allow string predictions

* safely get predictions in evaluation if available

* pass test set names to predictions and check for validity

* simplify safe_mask

* fix bad docstring path

* Reintroduce the CompetitionPredictions class because it includes additional metadata

* Add back the CompetitionPredictions to the docs

* Reordered docs

* Improved documentation and changed logic to disallow some edge cases

* Fixed docs

* Remove print statement

* Reorganize code

* Simplified evaluation logic

* Address all PR feedback

* Add extra test case

* Addressed PR feedback

* Fix type hint and fix model validator definition

* Fix import

---------

Co-authored-by: Kira McLean <kiramclean@users.noreply.github.com>
Co-authored-by: Cas Wognum <caswognum@outlook.com>
3 people authored Nov 19, 2024
1 parent 225b405 commit 5eee7ea
Showing 17 changed files with 650 additions and 239 deletions.
6 changes: 5 additions & 1 deletion docs/api/evaluation.md
@@ -1,3 +1,7 @@
::: polaris.evaluate.BenchmarkPredictions

---

::: polaris.evaluate.ResultsMetadata
options:
filters: ["!^_"]
@@ -25,4 +29,4 @@
::: polaris.evaluate.metrics.generic_metrics
::: polaris.evaluate.metrics.docking_metrics

---
---
3 changes: 3 additions & 0 deletions docs/tutorials/competition.participate.ipynb
@@ -190,6 +190,9 @@
"competition_predictions = CompetitionPredictions(\n",
" name=\"hello-world-result\",\n",
" predictions=predictions,\n",
" target_labels=competition.target_cols,\n",
" test_set_labels=competition.test_set_labels,\n",
" test_set_sizes=competition.test_set_sizes,\n",
" github_url=\"https://github.com/polaris-hub/polaris-hub\",\n",
" paper_url=\"https://polarishub.io/\",\n",
" description=\"Hello, World!\",\n",
14 changes: 8 additions & 6 deletions mkdocs.yml
@@ -24,19 +24,21 @@ nav:
- PDB Datasets: tutorials/dataset_pdb.ipynb
- SDF Datasets: tutorials/dataset_sdf.ipynb
- Optimization: tutorials/optimization.ipynb
- Competitions:
- tutorials/competition.participate.ipynb
# NOTE (cwognum): Competitions are currently gated.
# - Competitions:
# - tutorials/competition.participate.ipynb
- API Reference:
- Load: api/load.md
- Core:
- Dataset: api/dataset.md
- Benchmark: api/benchmark.md
- Subset: api/subset.md
- Evaluation: api/evaluation.md
- Competitions:
- Competition Dataset: api/competition.dataset.md
- Competition: api/competition.md
- Competiton Evaluation: api/competition.evaluation.md
# NOTE (cwognum): Competitions are currently gated.
# - Competitions:
# - Competition Dataset: api/competition.dataset.md
# - Competition: api/competition.md
# - Competiton Evaluation: api/competition.evaluation.md
- Hub:
- Client: api/hub.client.md
- External Auth Client: api/hub.external_client.md
6 changes: 5 additions & 1 deletion polaris/benchmark/__init__.py
@@ -4,4 +4,8 @@
SingleTaskBenchmarkSpecification,
)

__all__ = ["BenchmarkSpecification", "SingleTaskBenchmarkSpecification", "MultiTaskBenchmarkSpecification"]
__all__ = [
"BenchmarkSpecification",
"SingleTaskBenchmarkSpecification",
"MultiTaskBenchmarkSpecification",
]
120 changes: 55 additions & 65 deletions polaris/benchmark/_base.py
@@ -30,7 +30,7 @@
from polaris.utils.types import (
AccessType,
HubOwner,
PredictionsType,
IncomingPredictionsType,
SplitType,
TargetType,
TaskType,
@@ -131,7 +131,9 @@ def _validate_cols(cls, v, info: ValidationInfo):
if info.data.get("dataset") is not None and not all(
c in info.data["dataset"].table.columns for c in v
):
raise InvalidBenchmarkError("Not all specified target columns were found in the dataset.")
raise InvalidBenchmarkError("Not all specified columns were found in the dataset.")
if len(set(v)) != len(v):
raise InvalidBenchmarkError("The task specifies duplicate columns")
return v

@field_validator("metrics")
@@ -173,19 +175,18 @@ def _validate_split(self) -> Self:
4) There is no overlap between the train and test set
5) No row exists in the test set where all labels are missing/empty
"""

if not isinstance(self.split[1], dict):
self.split = self.split[0], {"test": self.split[1]}
split = self.split

# Train partition can be empty (zero-shot)
# Test partitions cannot be empty
if (isinstance(split[1], dict) and any(len(v) == 0 for v in split[1].values())) or (
not isinstance(split[1], dict) and len(split[1]) == 0
):
if any(len(v) == 0 for v in split[1].values()):
raise InvalidBenchmarkError("The predefined split contains empty test partitions")

train_idx_list = split[0]
full_test_idx_list = (
list(chain.from_iterable(split[1].values())) if isinstance(split[1], dict) else split[1]
)
full_test_idx_list = list(chain.from_iterable(split[1].values()))

if len(train_idx_list) == 0:
logger.info(
@@ -206,14 +207,11 @@
# Check for duplicate indices within a given test set. Because a user can specify
# multiple test sets for a given benchmark and it is acceptable for indices to be shared
# across test sets, we check for duplicates in each test set independently.
if isinstance(split[1], dict):
for test_set_name, test_set_idx_list in split[1].items():
if len(test_set_idx_list) != len(set(test_set_idx_list)):
raise InvalidBenchmarkError(
f'Test set with name "{test_set_name}" contains duplicate indices'
)
elif len(full_test_idx_set) != len(full_test_idx_list):
raise InvalidBenchmarkError("The test set contains duplicate indices")
for test_set_name, test_set_idx_list in split[1].items():
if len(test_set_idx_list) != len(set(test_set_idx_list)):
raise InvalidBenchmarkError(
f'Test set with name "{test_set_name}" contains duplicate indices'
)

# All indices are valid given the dataset
dataset = self.dataset
@@ -307,18 +305,13 @@ def _compute_checksum(self):
for m in sorted(self.metrics, key=lambda k: k.name):
hash_fn.update(m.name.encode("utf-8"))

if not isinstance(self.split[1], dict):
split = self.split[0], {"test": self.split[1]}
else:
split = self.split

# Train set
s = json.dumps(sorted(split[0]))
s = json.dumps(sorted(self.split[0]))
hash_fn.update(s.encode("utf-8"))

# Test sets
for k in sorted(split[1].keys()):
s = json.dumps(sorted(split[1][k]))
for k in sorted(self.split[1].keys()):
s = json.dumps(sorted(self.split[1][k]))
hash_fn.update(k.encode("utf-8"))
hash_fn.update(s.encode("utf-8"))

@@ -335,7 +328,7 @@ def n_train_datapoints(self) -> int:
@property
def n_test_sets(self) -> int:
"""The number of test sets"""
return len(self.split[1]) if isinstance(self.split[1], dict) else 1
return len(self.split[1])

@computed_field
@property
Expand Down Expand Up @@ -370,6 +363,18 @@ def task_type(self) -> str:
v = TaskType.MULTI_TASK if len(self.target_cols) > 1 else TaskType.SINGLE_TASK
return v.value

@computed_field
@property
def test_set_labels(self) -> list[str]:
"""The labels of the test sets."""
return sorted(list(self.split[1].keys()))

@computed_field
@property
def test_set_sizes(self) -> list[str]:
"""The sizes of the test sets."""
return {k: len(v) for k, v in self.split[1].items()}

def _get_subset(self, indices, hide_targets=True, featurization_fn=None):
"""Returns a [`Subset`][polaris.dataset.Subset] using the given indices. Used
internally to construct the train and test sets."""
@@ -393,10 +398,7 @@ def make_test_subset(vals):
return self._get_subset(vals, hide_targets=hide_targets, featurization_fn=featurization_fn)

test_split = self.split[1]
if isinstance(test_split, dict):
test = {k: make_test_subset(v) for k, v in test_split.items()}
else:
test = make_test_subset(test_split)
test = {k: make_test_subset(v) for k, v in test_split.items()}

return test

@@ -422,27 +424,23 @@ def get_train_test_split(
train = self._get_subset(self.split[0], hide_targets=False, featurization_fn=featurization_fn)
test = self._get_test_set(hide_targets=True, featurization_fn=featurization_fn)

# For improved UX, we return the object instead of the dictionary if there is only one test set.
# Internally, however, assume that the test set is always a dictionary simplifies the code.
if len(test) == 1:
test = test["test"]
return train, test

def evaluate(
self, y_pred: Optional[PredictionsType] = None, y_prob: Optional[PredictionsType] = None
self,
y_pred: IncomingPredictionsType | None = None,
y_prob: IncomingPredictionsType | None = None,
) -> BenchmarkResults:
"""Execute the evaluation protocol for the benchmark, given a set of predictions.
info: What about `y_true`?
Contrary to other frameworks that you might be familiar with, we opted for a signature that includes just
the predictions. This reduces the chance of accidentally using the test targets during training.
info: Expected structure for `y_pred` and `y_prob` arguments
The supplied `y_pred` and `y_prob` arguments must adhere to a certain structure depending on the number of
tasks and test sets included in the benchmark. Refer to the following for guidance on the correct structure when
creating your `y_pred` and `y_prod` objects:
- Single task, single set: `[values...]`
- Multi-task, single set: `{task_name_1: [values...], task_name_2: [values...]}`
- Single task, multi-set: `{test_set_1: {task_name: [values...]}, test_set_2: {task_name: [values...]}}`
- Multi-task, multi-set: `{test_set_1: {task_name_1: [values...], task_name_2: [values...]}, test_set_2: {task_name_1: [values...], task_name_2: [values...]}}`
For this method, we make the following assumptions:
1. There can be one or multiple test set(s);
Expand All @@ -456,7 +454,8 @@ def evaluate(
If there are multiple targets, the predictions should be wrapped in a dictionary with the target labels as keys.
If there are multiple test sets, the predictions should be further wrapped in a dictionary
with the test subset labels as keys.
y_prob: The predicted probabilities for the test set, as NumPy arrays.
y_prob: The predicted probabilities for the test set, formatted similarly to predictions, based on the
number of tasks and test sets.
Returns:
A `BenchmarkResults` object. This object can be directly submitted to the Polaris Hub.
Expand All @@ -475,31 +474,22 @@ def evaluate(
"""

# Instead of having the user pass the ground truth, we extract it from the benchmark spec ourselves.
# The `evaluate_benchmark` function expects the benchmark labels to be of a certain structure which
# depends on the number of tasks and test sets defined for the benchmark. Below, we build the structure
# of the benchmark labels based on the aforementioned factors.
test = self._get_test_set(hide_targets=False)
if isinstance(test, dict):
#
# For multi-set benchmarks
y_true = {}
for test_set_name, values in test.items():
y_true[test_set_name] = {}
if isinstance(values.targets, dict):
#
# For multi-task, multi-set benchmarks
for task_name, values in values.targets.items():
y_true[test_set_name][task_name] = values
else:
#
# For single task, multi-set benchmarks
y_true[test_set_name][self.target_cols[0]] = values.targets
else:
#
# For single set benchmarks (single and multiple task)
y_true = test.targets
y_true_subset = self._get_test_set(hide_targets=False)
y_true_values = {k: v.targets for k, v in y_true_subset.items()}

# Simplify the case where there is only one test set
if len(y_true_values) == 1:
y_true_values = y_true_values["test"]

scores = evaluate_benchmark(self.target_cols, self.metrics, y_true, y_pred=y_pred, y_prob=y_prob)
scores = evaluate_benchmark(
target_cols=self.target_cols,
test_set_labels=self.test_set_labels,
test_set_sizes=self.test_set_sizes,
metrics=self.metrics,
y_true=y_true_values,
y_pred=y_pred,
y_prob=y_prob,
)

return BenchmarkResults(results=scores, benchmark_name=self.name, benchmark_owner=self.owner)
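To make the reshaped evaluate() call above concrete: per the docstring, predictions are a flat array for a single task and single test set, with one extra dictionary level per additional target column or test set. A minimal sketch, assuming a hypothetical single-task benchmark whose only test set is labelled "test" (the slug and column names below are placeholders, not part of this diff):

import numpy as np
import polaris as po

benchmark = po.load_benchmark("my-org/placeholder-benchmark")  # hypothetical slug
train, test = benchmark.get_train_test_split()

# Single task, single test set: a flat array with one value per test row.
rng = np.random.default_rng(0)
y_pred = rng.random(benchmark.test_set_sizes["test"])

# Multi-task, single test set: a dict keyed by target column, e.g.
#   y_pred = {"target_a": [...], "target_b": [...]}
# Multiple test sets: wrap once more, keyed by test set label, e.g.
#   y_pred = {"test": {"target_a": [...]}, "test_other": {"target_a": [...]}}

results = benchmark.evaluate(y_pred=y_pred)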

8 changes: 5 additions & 3 deletions polaris/evaluate/__init__.py
@@ -1,11 +1,12 @@
from polaris.evaluate._metric import Metric, MetricInfo
from polaris.evaluate._predictions import BenchmarkPredictions
from polaris.evaluate._results import (
BenchmarkResults,
ResultsType,
CompetitionResults,
CompetitionPredictions,
ResultsMetadata,
CompetitionResults,
EvaluationResult,
ResultsMetadata,
ResultsType,
)
from polaris.evaluate.utils import evaluate_benchmark

Expand All @@ -19,4 +20,5 @@
"ResultsType",
"evaluate_benchmark",
"CompetitionPredictions",
"BenchmarkPredictions",
]
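With BenchmarkPredictions now part of the public exports above, the prediction, metric, and result classes all import from one namespace. A small sketch, limited to names re-exported by this __init__:

from polaris.evaluate import (
    BenchmarkPredictions,
    BenchmarkResults,
    CompetitionPredictions,
    Metric,
    evaluate_benchmark,
)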
14 changes: 6 additions & 8 deletions polaris/evaluate/_metric.py
@@ -1,30 +1,28 @@
from enum import Enum
from typing import Callable, Literal, Optional
from typing import Callable, Literal

import numpy as np
from pydantic import BaseModel, Field

from sklearn.metrics import (
accuracy_score,
average_precision_score,
balanced_accuracy_score,
explained_variance_score,
f1_score,
matthews_corrcoef,
mean_absolute_error,
mean_squared_error,
r2_score,
roc_auc_score,
balanced_accuracy_score,
)

from polaris.evaluate.metrics import (
cohen_kappa_score,
absolute_average_fold_error,
spearman,
cohen_kappa_score,
pearsonr,
spearman,
)
from polaris.evaluate.metrics.docking_metrics import rmsd_coverage

from polaris.utils.types import DirectionType


@@ -107,7 +105,7 @@ def y_type(self) -> bool:
return self.value.y_type

def score(
self, y_true: np.ndarray, y_pred: Optional[np.ndarray] = None, y_prob: Optional[np.ndarray] = None
self, y_true: np.ndarray, y_pred: np.ndarray | None = None, y_prob: np.ndarray | None = None
) -> float:
"""Endpoint for computing the metric.
@@ -134,7 +132,7 @@ def score(
return self.fn(**kwargs, **self.value.kwargs)

def __call__(
self, y_true: np.ndarray, y_pred: Optional[np.ndarray] = None, y_prob: Optional[np.ndarray] = None
self, y_true: np.ndarray, y_pred: np.ndarray | None = None, y_prob: np.ndarray | None = None
) -> float:
"""For convenience, make metrics callable"""
return self.score(y_true, y_pred, y_prob)
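The signatures above keep y_pred and y_prob as optional keyword arguments, so a metric is scored with whichever of the two it needs. A brief sketch, assuming the Metric enum exposes a mean_absolute_error member (consistent with the sklearn import at the top of this file):

import numpy as np
from polaris.evaluate import Metric

y_true = np.array([0.0, 0.5, 1.0])
y_pred = np.array([0.1, 0.4, 1.1])

# Metric instances are callable; regression metrics consume y_pred, probabilistic ones y_prob.
mae = Metric.mean_absolute_error(y_true=y_true, y_pred=y_pred)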
(The diffs for the remaining 10 changed files are not shown.)
