From 368357f7e3a5638ac606ded6c2e7604faab1d0cb Mon Sep 17 00:00:00 2001
From: Cristian Garcia
Date: Mon, 18 Oct 2021 12:40:00 -0500
Subject: [PATCH 1/3] new Accuracy

---
 examples/cnn.py                                    |  11 +-
 tests/metrics/test_accuracy.py                     |  11 +-
 treex/metrics/__init__.py                          |   8 +-
 treex/metrics/accuracy.py                          | 288 ++++++
 treex/metrics/tm_port/__init__.py                  |   5 +
 .../{ => tm_port}/classification/__init__.py       |   2 +-
 .../{ => tm_port}/classification/accuracy.py       |   6 +-
 .../classification/stat_scores.py                  |   6 +-
 .../{ => tm_port}/functional/__init__.py           |   0
 .../functional/classification/__init__.py          |   0
 .../functional/classification/accuracy.py          |  10 +-
 .../functional/classification/stat_scores.py       |   6 +-
 treex/metrics/tm_port/utilities/__init__.py        |   7 +
 .../metrics/{ => tm_port}/utilities/checks.py      |   9 +-
 treex/metrics/{ => tm_port}/utilities/data.py      |   2 +-
 .../{ => tm_port}/utilities/distributed.py         |   0
 .../metrics/{ => tm_port}/utilities/enums.py       |   0
 .../{ => tm_port}/utilities/exceptions.py          |   0
 .../{ => tm_port}/utilities/imports.py             |   0
 .../metrics/{ => tm_port}/utilities/prints.py      |   0
 treex/metrics/utilities/__init__.py                |   7 -
 treex/metrics/utils.py                             | 818 ++++++++++++++++++
 22 files changed, 1149 insertions(+), 47 deletions(-)
 create mode 100644 treex/metrics/accuracy.py
 create mode 100644 treex/metrics/tm_port/__init__.py
 rename treex/metrics/{ => tm_port}/classification/__init__.py (95%)
 rename treex/metrics/{ => tm_port}/classification/accuracy.py (98%)
 rename treex/metrics/{ => tm_port}/classification/stat_scores.py (98%)
 rename treex/metrics/{ => tm_port}/functional/__init__.py (100%)
 rename treex/metrics/{ => tm_port}/functional/classification/__init__.py (100%)
 rename treex/metrics/{ => tm_port}/functional/classification/accuracy.py (98%)
 rename treex/metrics/{ => tm_port}/functional/classification/stat_scores.py (98%)
 create mode 100644 treex/metrics/tm_port/utilities/__init__.py
 rename treex/metrics/{ => tm_port}/utilities/checks.py (98%)
 rename treex/metrics/{ => tm_port}/utilities/data.py (99%)
 rename treex/metrics/{ => tm_port}/utilities/distributed.py (100%)
 rename treex/metrics/{ => tm_port}/utilities/enums.py (100%)
 rename treex/metrics/{ => tm_port}/utilities/exceptions.py (100%)
 rename treex/metrics/{ => tm_port}/utilities/imports.py (100%)
 rename treex/metrics/{ => tm_port}/utilities/prints.py (100%)
 delete mode 100644 treex/metrics/utilities/__init__.py
 create mode 100644 treex/metrics/utils.py

diff --git a/examples/cnn.py b/examples/cnn.py
index e16d2956..6fb176a5 100644
--- a/examples/cnn.py
+++ b/examples/cnn.py
@@ -16,7 +16,6 @@
 Batch = tp.Mapping[str, np.ndarray]
 Model = tx.Sequential
-Metric = tx.metrics.Accuracy
 
 np.random.seed(420)
@@ -31,7 +30,7 @@ def init_step(
 
 @jax.jit
-def reset_step(metric: Metric) -> Metric:
+def reset_step(metric: tx.Metric) -> tx.Metric:
     metric.reset()
     return metric
@@ -59,10 +58,10 @@ def loss_fn(
 def train_step(
     model: Model,
     optimizer: tx.Optimizer,
-    metric: Metric,
+    metric: tx.Metric,
     x: jnp.ndarray,
     y: jnp.ndarray,
-) -> tp.Tuple[jnp.ndarray, Model, tx.Optimizer, Metric]:
+) -> tp.Tuple[jnp.ndarray, Model, tx.Optimizer, tx.Metric]:
     print("JITTTTING")
     params = model.parameters()
@@ -79,8 +78,8 @@
 
 @jax.jit
 def test_step(
-    model: Model, metric: Metric, x: jnp.ndarray, y: jnp.ndarray
-) -> tp.Tuple[jnp.ndarray, Metric]:
+    model: Model, metric: tx.Metric, x: jnp.ndarray, y: jnp.ndarray
+) -> tp.Tuple[jnp.ndarray, tx.Metric]:
 
     loss, (model, y_pred) = loss_fn(model, model, x, y)
diff --git a/tests/metrics/test_accuracy.py b/tests/metrics/test_accuracy.py
index a7dc56a8..b5dee6c9 100644
--- a/tests/metrics/test_accuracy.py
+++ b/tests/metrics/test_accuracy.py
@@ -5,15 +5,15 @@
 import pytest
 
 import treex as tx
-from treex.metrics.classification.accuracy import Accuracy
-from treex.metrics.utilities.enums import DataType
+from treex.metrics.accuracy import Accuracy
+from treex.metrics.utils import DataType
 
 
 class TestAccuracy:
     def test_jit(self):
         N = 0
 
-        @jax.jit
+        # @jax.jit
         def f(m, y_true, y_pred):
             nonlocal N
             N += 1
@@ -21,8 +21,8 @@
             return m
 
         metric = Accuracy(num_classes=10)
-        y_true = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
-        y_pred = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])
+        y_true = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
+        y_pred = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]
 
         metric = f(metric, y_true, y_pred)
         assert N == 1
@@ -54,7 +54,6 @@
             ]
         )
 
-        metric0 = metric
         metric = f(metric, y_true, y_pred)
         assert N == 1
         assert metric.compute() == 0.8

diff --git a/treex/metrics/__init__.py b/treex/metrics/__init__.py
index 53e7e28d..54a32a01 100644
--- a/treex/metrics/__init__.py
+++ b/treex/metrics/__init__.py
@@ -1,8 +1,2 @@
-import logging as __logging
-
-_logger = __logging.getLogger("treex")
-_logger.addHandler(__logging.StreamHandler())
-_logger.setLevel(__logging.INFO)
-
-from .classification.accuracy import Accuracy
+from .accuracy import Accuracy
 from .metric import Metric

diff --git a/treex/metrics/accuracy.py b/treex/metrics/accuracy.py
new file mode 100644
index 00000000..9d558233
--- /dev/null
+++ b/treex/metrics/accuracy.py
@@ -0,0 +1,288 @@
+import typing
+import typing as tp
+
+import jax.numpy as jnp
+import treeo as to
+
+from treex import types
+from treex.metrics import utils as metric_utils
+from treex.metrics.metric import Metric
+from treex.metrics.utils import AverageMethod, DataType, MDMCAverageMethod
+
+
+class Accuracy(Metric):
+    r"""
+    Computes Accuracy_:
+
+    .. math::
+        \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)
+
+    Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a
+    tensor of predictions.
+
+    For multi-class and multi-dimensional multi-class data with probability or logits predictions, the
+    parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the
+    top-K highest probability or logit score items are considered to find the correct label.
+
+    For multi-label and multi-dimensional multi-class inputs, this metric computes the "global"
+    accuracy by default, which counts all labels or sub-samples separately. This can be
+    changed to subset accuracy (which requires all labels or sub-samples in the sample to
+    be correctly predicted) by setting ``subset_accuracy=True``.
+
+    Accepts all input types listed in :ref:`references/modules:input types`.
+
+    Args:
+        num_classes:
+            Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
+        threshold:
+            Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
+        average:
+            Defines the reduction that is applied. Should be one of the following:
+
+            - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
+            - ``'macro'``: Calculate the metric for each class separately, and average the
+              metrics across classes (with equal weights for each class).
+            - ``'weighted'``: Calculate the metric for each class separately, and average the
+              metrics across classes, weighting each class by its support (``tp + fn``).
+            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
+              the metric for every class.
+            - ``'samples'``: Calculate the metric for each sample, and average the metrics
+              across samples (with equal weights for each sample).
+
+            .. note:: What is considered a sample in the multi-dimensional multi-class case
+                depends on the value of ``mdmc_average``.
+
+            .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``,
+                the value for the class will be ``nan``.
+
+        mdmc_average:
+            Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
+            ``average`` parameter). Should be one of the following:
+
+            - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
+              multi-class.
+
+            - ``'samplewise'``: In this case, the statistics are computed separately for each
+              sample on the ``N`` axis, and then averaged over samples.
+              The computation for each sample is done by treating the flattened extra axes ``...``
+              (see :ref:`references/modules:input types`) as the ``N`` dimension within the sample,
+              and computing the metric for the sample based on that.
+
+            - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
+              (see :ref:`references/modules:input types`)
+              are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
+              were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
+
+        ignore_index:
+            Integer specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
+            or ``'none'``, the score for the ignored class will be returned as ``nan``.
+
+        top_k:
+            Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs. The
+            default value (``None``) will be interpreted as 1 for these inputs.
+
+            Should be left at default (``None``) for all other types of inputs.
+
+        multiclass:
+            Used only in certain special cases, where you want to treat inputs as a different type
+            than what they appear to be. See the parameter's
+            :ref:`documentation section `
+            for a more detailed explanation and examples.
+
+        subset_accuracy:
+            Whether to compute subset accuracy for multi-label and multi-dimensional
+            multi-class inputs (has no effect for other input types).
+
+            - For multi-label inputs, if the parameter is set to ``True``, then all labels for
+              each sample must be correctly predicted for the sample to count as correct. If it
+              is set to ``False``, then all labels are counted separately - this is equivalent to
+              flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).
+
+            - For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all
+              sub-samples (on the extra axis) must be correctly predicted for the sample to be counted as
+              correct. If it is set to ``False``, then all sub-samples are counted separately - this is
+              equivalent, in the case of label predictions, to flattening the inputs beforehand (i.e.
+              ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
+              still applies in both cases, if set.
+
+        compute_on_step:
+            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``.
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step.
+        process_group:
+            Specify the process group on which synchronization is called.
+            default: ``None`` (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When ``None``, DDP
+            will be used to perform the allgather.
+
+    Raises:
+        ValueError:
+            If ``top_k`` is not an ``integer`` larger than ``0``.
+        ValueError:
+            If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.
+        ValueError:
+            If two different input modes are provided, e.g. using ``multi-label`` with ``multi-class``.
+        ValueError:
+            If ``top_k`` parameter is set for ``multi-label`` inputs.
+
+    Example:
+        >>> import jax.numpy as jnp
+        >>> from treex.metrics import Accuracy
+        >>> y_true = jnp.array([0, 1, 2, 3])[None, None, None, :]
+        >>> y_pred = jnp.array([0, 2, 1, 3])[None, None, None, :]
+        >>> accuracy = Accuracy(num_classes=4)
+        >>> accuracy.update(y_pred=y_pred, y_true=y_true)
+        >>> round(float(accuracy.compute()), 4)  # 2 of 4 predictions are correct
+        0.5
+
+        >>> y_true = jnp.array([0, 1, 2])
+        >>> y_pred = jnp.array([[0.1, 0.9, 0.0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
+        >>> accuracy = Accuracy(top_k=2)
+        >>> accuracy.update(y_pred=y_pred, y_true=y_true)
+        >>> round(float(accuracy.compute()), 4)  # top-2 accuracy
+        0.6667
+
+    """
+
+    tp: jnp.ndarray = types.MetricState.node()
+    fp: jnp.ndarray = types.MetricState.node()
+    tn: jnp.ndarray = types.MetricState.node()
+    fn: jnp.ndarray = types.MetricState.node()
+
+    def __init__(
+        self,
+        threshold: float = 0.5,
+        num_classes: typing.Optional[int] = None,
+        average: typing.Union[str, AverageMethod] = AverageMethod.MICRO,
+        mdmc_average: typing.Union[str, MDMCAverageMethod] = MDMCAverageMethod.GLOBAL,
+        ignore_index: typing.Optional[int] = None,
+        top_k: typing.Optional[int] = None,
+        multiclass: typing.Optional[bool] = None,
+        subset_accuracy: bool = False,
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: typing.Optional[typing.Any] = None,
+        dist_sync_fn: typing.Optional[typing.Callable] = None,
+        mode: DataType = DataType.MULTICLASS,
+        on: typing.Optional[types.IndexLike] = None,
+        name: typing.Optional[str] = None,
+        dtype: typing.Optional[jnp.dtype] = None,
+    ):
+
+        super().__init__(on=on, name=name, dtype=dtype)
+
+        if isinstance(average, str):
+            average = AverageMethod[average.upper()]
+
+        if isinstance(mdmc_average, str):
+            mdmc_average = MDMCAverageMethod[mdmc_average.upper()]
+
+        average = (
+            AverageMethod.MACRO
+            if average in [AverageMethod.WEIGHTED, AverageMethod.NONE]
+            else average
+        )
+
+        if average not in [
+            AverageMethod.MICRO,
+            AverageMethod.MACRO,
+            # AverageMethod.SAMPLES,
+        ]:
+            raise ValueError(f"The `reduce` {average} is not valid.")
+
+        if average == AverageMethod.MACRO and (not num_classes or num_classes < 1):
+            raise ValueError(
+                "When you set `reduce` as 'macro', you have to provide the number of classes."
+            )
+
+        if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
+            raise ValueError(
+                f"The `top_k` should be an integer larger than 0, got {top_k}"
+            )
+
+        if (
+            num_classes
+            and ignore_index is not None
+            and (not 0 <= ignore_index < num_classes or num_classes == 1)
+        ):
+            raise ValueError(
+                f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes"
+            )
+
+        if average == AverageMethod.SAMPLES:
+            raise ValueError(f"The `average` method '{average}' is not yet supported.")
+
+        if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
+            raise ValueError(
+                f"The `mdmc_average` method '{mdmc_average}' is not yet supported."
+            )
+
+        self.average = average
+        self.mdmc_average = mdmc_average
+        self.num_classes = num_classes
+        self.threshold = threshold
+        self.multiclass = multiclass
+        self.ignore_index = ignore_index
+        self.top_k = top_k
+        self.subset_accuracy = subset_accuracy
+        self.mode = mode
+
+        # metric state nodes
+        if average == AverageMethod.MICRO:
+            zeros_shape = []
+        elif average == AverageMethod.MACRO:
+            zeros_shape = [num_classes]
+        else:
+            raise ValueError(f'Wrong reduce="{average}"')
+
+        initial_value = jnp.zeros(zeros_shape, dtype=jnp.uint32)
+
+        self.tp = initial_value
+        self.fp = initial_value
+        self.tn = initial_value
+        self.fn = initial_value
+
+    def update(self, y_pred: jnp.ndarray, y_true: jnp.ndarray) -> None:  # type: ignore
+        """Update state with predictions and targets. See
+        :ref:`references/modules:input types` for more information on input
+        types.
+
+        Args:
+            y_pred: Predictions from model (logits, probabilities, or labels)
+            y_true: Ground truth labels
+        """
+
+        tp, fp, tn, fn = metric_utils._stat_scores_update(
+            y_pred,
+            y_true,
+            intended_mode=self.mode,
+            average_method=self.average,
+            mdmc_average_method=self.mdmc_average,
+            threshold=self.threshold,
+            num_classes=self.num_classes,
+            top_k=self.top_k,
+            multiclass=self.multiclass,
+        )
+
+        self.tp += tp
+        self.fp += fp
+        self.tn += tn
+        self.fn += fn
+
+    def compute(self) -> jnp.ndarray:
+        """Computes accuracy based on inputs passed in to ``update`` previously."""
+        return metric_utils._accuracy_compute(
+            self.tp,
+            self.fp,
+            self.tn,
+            self.fn,
+            self.average,
+            self.mdmc_average,
+            self.mode,
+        )
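+
+
+# NOTE (illustrative sketch, not part of the original patch): because the metric
+# is a pytree, its tp/fp/tn/fn state can cross a jax.jit boundary by passing the
+# metric in and returning it, mirroring tests/metrics/test_accuracy.py:
+#
+#   @jax.jit
+#   def update_metric(metric, y_true, y_pred):
+#       metric(y_true=y_true, y_pred=y_pred)
+#       return metric  # return the updated pytree, functional-style
+#
+#   metric = Accuracy(num_classes=10)
+#   y_true = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
+#   y_pred = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]
+#   metric = update_metric(metric, y_true, y_pred)
+#   assert metric.compute() == 0.8  # 8 of 10 labels match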
diff --git a/treex/metrics/tm_port/__init__.py b/treex/metrics/tm_port/__init__.py
new file mode 100644
index 00000000..5faf889a
--- /dev/null
+++ b/treex/metrics/tm_port/__init__.py
@@ -0,0 +1,5 @@
+import logging as __logging
+
+_logger = __logging.getLogger("treex")
+_logger.addHandler(__logging.StreamHandler())
+_logger.setLevel(__logging.INFO)

diff --git a/treex/metrics/classification/__init__.py b/treex/metrics/tm_port/classification/__init__.py
similarity index 95%
rename from treex/metrics/classification/__init__.py
rename to treex/metrics/tm_port/classification/__init__.py
index 765ce6f7..c73a6670 100644
--- a/treex/metrics/classification/__init__.py
+++ b/treex/metrics/tm_port/classification/__init__.py
@@ -1,4 +1,4 @@
-from treex.metrics.classification.accuracy import Accuracy  # noqa: F401
+from treex.metrics.tm_port.classification.accuracy import Accuracy  # noqa: F401
 
 # from torchmetrics.classification.auc import AUC  # noqa: F401
 # from torchmetrics.classification.auroc import AUROC  # noqa: F401

diff --git a/treex/metrics/classification/accuracy.py b/treex/metrics/tm_port/classification/accuracy.py
similarity index 98%
rename from treex/metrics/classification/accuracy.py
rename to treex/metrics/tm_port/classification/accuracy.py
index d5ccf101..4cfc1404 100644
--- a/treex/metrics/classification/accuracy.py
+++ b/treex/metrics/tm_port/classification/accuracy.py
@@ -8,8 +8,8 @@
 Tensor = jnp.ndarray
 tensor = jnp.array
 
-from treex.metrics.classification.stat_scores import StatScores
-from treex.metrics.functional.classification.accuracy import (
+from treex.metrics.tm_port.classification.stat_scores import StatScores
+from treex.metrics.tm_port.functional.classification.accuracy import (
     _accuracy_compute,
     _accuracy_update,
     _check_subset_validity,
@@ -17,7 +17,7 @@
     _subset_accuracy_compute,
     _subset_accuracy_update,
 )
-from treex.metrics.utilities.enums import DataType
+from treex.metrics.tm_port.utilities.enums import DataType
 
 
 class Accuracy(StatScores):
diff --git a/treex/metrics/classification/stat_scores.py b/treex/metrics/tm_port/classification/stat_scores.py
similarity index 98%
rename from treex/metrics/classification/stat_scores.py
rename to treex/metrics/tm_port/classification/stat_scores.py
index 0e92000b..48d75907 100644
--- a/treex/metrics/classification/stat_scores.py
+++ b/treex/metrics/tm_port/classification/stat_scores.py
@@ -6,12 +6,12 @@
 Tensor = jnp.ndarray
 tensor = jnp.array
 
-from treex.metrics.functional.classification.stat_scores import (
+from treex.metrics.tm_port.functional.classification.stat_scores import (
     _stat_scores_compute,
     _stat_scores_update,
 )
-from treex.metrics.metric import Metric
-from treex.metrics.utilities.enums import AverageMethod, MDMCAverageMethod
+from treex.metrics.tm_port.metric import Metric
+from treex.metrics.tm_port.utilities.enums import AverageMethod, MDMCAverageMethod
 
 
 class StatScores(Metric):

diff --git a/treex/metrics/functional/__init__.py b/treex/metrics/tm_port/functional/__init__.py
similarity index 100%
rename from treex/metrics/functional/__init__.py
rename to treex/metrics/tm_port/functional/__init__.py

diff --git a/treex/metrics/functional/classification/__init__.py b/treex/metrics/tm_port/functional/classification/__init__.py
similarity index 100%
rename from treex/metrics/functional/classification/__init__.py
rename to treex/metrics/tm_port/functional/classification/__init__.py

diff --git a/treex/metrics/functional/classification/accuracy.py b/treex/metrics/tm_port/functional/classification/accuracy.py
similarity index 98%
rename from treex/metrics/functional/classification/accuracy.py
rename to treex/metrics/tm_port/functional/classification/accuracy.py
index 9414f0e1..ff758737 100644
--- a/treex/metrics/functional/classification/accuracy.py
+++ b/treex/metrics/tm_port/functional/classification/accuracy.py
@@ -5,16 +5,20 @@
 Tensor = jnp.ndarray
 tensor = jnp.array
 
-from treex.metrics.functional.classification.stat_scores import (
+from treex.metrics.tm_port.functional.classification.stat_scores import (
     _reduce_stat_scores,
     _stat_scores_update,
 )
-from treex.metrics.utilities.checks import (
+from treex.metrics.tm_port.utilities.checks import (
     _check_classification_inputs,
     _input_format_classification,
     _input_squeeze,
 )
-from treex.metrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod
+from treex.metrics.tm_port.utilities.enums import (
+    AverageMethod,
+    DataType,
+    MDMCAverageMethod,
+)
 
 
 def _check_subset_validity(mode: DataType) -> bool:

diff --git a/treex/metrics/functional/classification/stat_scores.py b/treex/metrics/tm_port/functional/classification/stat_scores.py
similarity index 98%
rename from treex/metrics/functional/classification/stat_scores.py
rename to treex/metrics/tm_port/functional/classification/stat_scores.py
index e3e72dc5..2fa712ed 100644
--- a/treex/metrics/functional/classification/stat_scores.py
+++ b/treex/metrics/tm_port/functional/classification/stat_scores.py
@@ -4,8 +4,8 @@
 Tensor = jnp.ndarray
 
-from treex.metrics.utilities.checks import _input_format_classification
-from treex.metrics.utilities.enums import AverageMethod, MDMCAverageMethod
+from treex.metrics.tm_port.utilities.checks import _input_format_classification
+from treex.metrics.tm_port.utilities.enums import AverageMethod, MDMCAverageMethod
 
 
 def _del_column(data: Tensor, idx: int) -> Tensor:
@@ -374,7 +374,7 @@
     If inputs are ``multi-dimensional multi-class`` and ``mdmc_reduce`` is not provided.
 
     Example:
-        >>> from treex.metrics.functional import stat_scores
+        >>> from treex.metrics.tm_port.functional import stat_scores
         >>> preds = torch.tensor([1, 0, 2, 1])
         >>> target = torch.tensor([1, 1, 2, 0])
         >>> stat_scores(preds, target, reduce='macro', num_classes=3)

diff --git a/treex/metrics/tm_port/utilities/__init__.py b/treex/metrics/tm_port/utilities/__init__.py
new file mode 100644
index 00000000..bb0ebf54
--- /dev/null
+++ b/treex/metrics/tm_port/utilities/__init__.py
@@ -0,0 +1,7 @@
+# from treex.metrics.tm_port.utilities.data import apply_to_collection  # noqa: F401
+# from treex.metrics.tm_port.utilities.distributed import class_reduce, reduce  # noqa: F401
+# from treex.metrics.tm_port.utilities.prints import (
+#     rank_zero_debug,
+#     rank_zero_info,
+#     rank_zero_warn,
+# )  # noqa: F401

diff --git a/treex/metrics/utilities/checks.py b/treex/metrics/tm_port/utilities/checks.py
similarity index 98%
rename from treex/metrics/utilities/checks.py
rename to treex/metrics/tm_port/utilities/checks.py
index 24d864d9..f0f38167 100644
--- a/treex/metrics/utilities/checks.py
+++ b/treex/metrics/tm_port/utilities/checks.py
@@ -7,8 +7,8 @@
 Tensor = jnp.ndarray
 
-from treex.metrics.utilities.data import select_topk, to_onehot
-from treex.metrics.utilities.enums import DataType
+from treex.metrics.tm_port.utilities.data import select_topk, to_onehot
+from treex.metrics.tm_port.utilities.enums import DataType
 
 
 def _check_same_shape(preds: Tensor, target: Tensor) -> None:
@@ -416,11 +416,6 @@
     # Remove excess dimensions
     preds, target = _input_squeeze(preds, target)
 
-    # Convert half precision tensors to full precision, as not all ops are supported
-    # for example, min() is not supported
-    if preds.dtype == jnp.float16:
-        preds = preds.float()
-
     case = _check_classification_inputs(
         preds,
         target,

diff --git a/treex/metrics/utilities/data.py b/treex/metrics/tm_port/utilities/data.py
similarity index 99%
rename from treex/metrics/utilities/data.py
rename to treex/metrics/tm_port/utilities/data.py
index 4b351ada..5af35bf0 100644
--- a/treex/metrics/utilities/data.py
+++ b/treex/metrics/tm_port/utilities/data.py
@@ -7,7 +7,7 @@
 Tensor = jnp.ndarray
 tensor = jnp.array
 
-from treex.metrics.utilities.prints import rank_zero_warn
+from treex.metrics.tm_port.utilities.prints import rank_zero_warn
 
 METRIC_EPS = 1e-6

diff --git a/treex/metrics/utilities/distributed.py b/treex/metrics/tm_port/utilities/distributed.py
similarity index 100%
rename from treex/metrics/utilities/distributed.py
rename to treex/metrics/tm_port/utilities/distributed.py

diff --git a/treex/metrics/utilities/enums.py b/treex/metrics/tm_port/utilities/enums.py
similarity index 100%
rename from treex/metrics/utilities/enums.py
rename to treex/metrics/tm_port/utilities/enums.py

diff --git a/treex/metrics/utilities/exceptions.py b/treex/metrics/tm_port/utilities/exceptions.py
similarity index 100%
rename from treex/metrics/utilities/exceptions.py
rename to treex/metrics/tm_port/utilities/exceptions.py

diff --git a/treex/metrics/utilities/imports.py b/treex/metrics/tm_port/utilities/imports.py
similarity index 100%
rename from treex/metrics/utilities/imports.py
rename to treex/metrics/tm_port/utilities/imports.py

diff --git a/treex/metrics/utilities/prints.py b/treex/metrics/tm_port/utilities/prints.py
similarity index 100%
rename from treex/metrics/utilities/prints.py
rename to treex/metrics/tm_port/utilities/prints.py
diff --git a/treex/metrics/utilities/__init__.py b/treex/metrics/utilities/__init__.py
deleted file mode 100644
index 8f787cee..00000000
--- a/treex/metrics/utilities/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# from treex.metrics.utilities.data import apply_to_collection  # noqa: F401
-# from treex.metrics.utilities.distributed import class_reduce, reduce  # noqa: F401
-# from treex.metrics.utilities.prints import (
-#     rank_zero_debug,
-#     rank_zero_info,
-#     rank_zero_warn,
-# )  # noqa: F401

diff --git a/treex/metrics/utils.py b/treex/metrics/utils.py
new file mode 100644
index 00000000..ef78d6e4
--- /dev/null
+++ b/treex/metrics/utils.py
@@ -0,0 +1,818 @@
+import enum
+import typing
+import typing as tp
+
+import einops
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+class DataType(enum.Enum):
+    """Enum to represent data type."""
+
+    BINARY = enum.auto()
+    MULTILABEL = enum.auto()
+    MULTICLASS = enum.auto()
+    # MULTIDIM_MULTICLASS = enum.auto()
+
+
+class AverageMethod(enum.Enum):
+    """Enum to represent average method."""
+
+    MICRO = enum.auto()
+    MACRO = enum.auto()
+    WEIGHTED = enum.auto()
+    NONE = enum.auto()
+    SAMPLES = enum.auto()
+
+
+class MDMCAverageMethod(enum.Enum):
+    """Enum to represent multi-dim multi-class average method."""
+
+    GLOBAL = enum.auto()
+    SAMPLEWISE = enum.auto()
+
+
+def _input_squeeze(
+    preds: jnp.ndarray,
+    target: jnp.ndarray,
+) -> tp.Tuple[jnp.ndarray, jnp.ndarray]:
+    """Remove excess dimensions."""
+    if preds.shape[0] == 1:
+        preds = jnp.expand_dims(preds.squeeze(), axis=0)
+        target = jnp.expand_dims(target.squeeze(), axis=0)
+    else:
+        preds, target = preds.squeeze(), target.squeeze()
+    return preds, target
+
+
+def _stat_scores_update(
+    preds: jnp.ndarray,
+    target: jnp.ndarray,
+    intended_mode: DataType,
+    average_method: AverageMethod = AverageMethod.MICRO,
+    mdmc_average_method: tp.Optional[MDMCAverageMethod] = None,
+    num_classes: tp.Optional[int] = None,
+    top_k: tp.Optional[int] = None,
+    threshold: float = 0.5,
+    multiclass: tp.Optional[bool] = None,
+) -> tp.Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+    """Computes and returns the number of true positives, false positives, true negatives, false negatives.
+
+    Raises ValueError if:
+
+    - The mode found in the inputs does not match `intended_mode`
+    - `top_k` is used with multi-label inputs
+    - The inputs are multi-dimensional multi-class, and the `mdmc_average_method` parameter is not set
+
+    Args:
+        preds: Predicted tensor
+        target: Ground truth tensor
+        intended_mode: The data type the inputs are expected to have (binary, multi-label, multi-class)
+        average_method: Defines the reduction that is applied
+        mdmc_average_method: Defines how multi-dimensional multi-class inputs are handled
+        num_classes: Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data.
+        top_k: Number of highest probability or logit score predictions considered to find the correct label,
+            relevant only for (multi-dimensional) multi-class inputs
+        threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
+            of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities
+        multiclass: Used only in certain special cases, where you want to treat inputs as a different type
+            than what they appear to be
+    """
+
+    preds, target, mode = _input_format_classification(
+        preds,
+        target,
+        mode=intended_mode,
+        threshold=threshold,
+        num_classes=num_classes,
+        multiclass=multiclass,
+        top_k=top_k,
+    )
+
+    if intended_mode != mode:
+        raise ValueError(
+            f"The intended mode '{intended_mode}' does not match the found mode '{mode}'."
+        )
+
+    if mode == DataType.MULTILABEL and top_k:
+        raise ValueError(
+            "You can not use the `top_k` parameter to calculate accuracy for multi-label inputs."
+        )
+
+    if preds.ndim == 3:
+        if mdmc_average_method is None:
+            raise ValueError(
+                "When your inputs are multi-dimensional multi-class, you have to set the `mdmc_average` parameter"
+            )
+        if mdmc_average_method == MDMCAverageMethod.GLOBAL:
+            preds = jnp.swapaxes(preds, 1, 2).reshape(-1, preds.shape[1])
+            target = jnp.swapaxes(target, 1, 2).reshape(-1, target.shape[1])
+
+    tp, fp, tn, fn = _stat_scores(preds, target, reduce=average_method)
+
+    return tp, fp, tn, fn
+
+
+def _stat_scores(
+    preds: jnp.ndarray,
+    target: jnp.ndarray,
+    reduce: tp.Optional[AverageMethod] = AverageMethod.MICRO,
+) -> tp.Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+    """Calculate the number of tp, fp, tn, fn.
+
+    Args:
+        preds:
+            An ``(N, C)`` or ``(N, C, X)`` tensor of predictions (0 or 1)
+        target:
+            An ``(N, C)`` or ``(N, C, X)`` tensor of true labels (0 or 1)
+        reduce:
+            One of ``'micro'``, ``'macro'``, ``'samples'``
+
+    Return:
+        Returns a list of 4 tensors: tp, fp, tn, fn.
+        The shape of the returned tensors depends on the shape of the inputs
+        and the ``reduce`` parameter:
+
+        If inputs are of the shape ``(N, C)``, then:
+
+        - If ``reduce='micro'``, the returned tensors are 1 element tensors
+        - If ``reduce='macro'``, the returned tensors are ``(C,)`` tensors
+        - If ``reduce='samples'``, the returned tensors are ``(N,)`` tensors
+
+        If inputs are of the shape ``(N, C, X)``, then:
+
+        - If ``reduce='micro'``, the returned tensors are ``(N,)`` tensors
+        - If ``reduce='macro'``, the returned tensors are ``(N, C)`` tensors
+        - If ``reduce='samples'``, the returned tensors are ``(N, X)`` tensors
+    """
+
+    dim: typing.Union[int, typing.List[int]] = 1  # for "samples"
+    if reduce == AverageMethod.MICRO:
+        dim = [0, 1] if preds.ndim == 2 else [1, 2]
+    elif reduce == AverageMethod.MACRO:
+        dim = 0 if preds.ndim == 2 else 2
+
+    true_pred, false_pred = target == preds, target != preds
+    pos_pred, neg_pred = preds == 1, preds == 0
+
+    tp = (true_pred * pos_pred).sum(axis=dim)
+    fp = (false_pred * pos_pred).sum(axis=dim)
+
+    tn = (true_pred * neg_pred).sum(axis=dim)
+    fn = (false_pred * neg_pred).sum(axis=dim)
+
+    return (
+        tp.astype(jnp.uint32),
+        fp.astype(jnp.uint32),
+        tn.astype(jnp.uint32),
+        fn.astype(jnp.uint32),
+    )
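+
+
+# NOTE (illustrative sketch, not part of the original patch): with micro averaging
+# the counts are summed over both the sample and the class axes. For one-hot
+# (N=3, C=2) inputs:
+#
+#   preds  = jnp.array([[1, 0], [0, 1], [1, 0]])
+#   target = jnp.array([[1, 0], [1, 0], [1, 0]])
+#   tp, fp, tn, fn = _stat_scores(preds, target, reduce=AverageMethod.MICRO)
+#   # tp=2, fp=1, tn=2, fn=1, and tp + fp + tn + fn == N * C == 6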
+
+
+def _input_format_classification(
+    preds: jnp.ndarray,
+    target: jnp.ndarray,
+    mode: DataType,
+    threshold: float = 0.5,
+    top_k: tp.Optional[int] = None,
+    num_classes: tp.Optional[int] = None,
+    multiclass: tp.Optional[bool] = None,
+) -> tp.Tuple[jnp.ndarray, jnp.ndarray, DataType]:
+    """Convert preds and target tensors into common format.
+
+    Preds and targets are supposed to fall into one of these categories (and are
+    validated to make sure this is the case):
+
+    * Both preds and target are of shape ``(N,)``, and both are integers (multi-class)
+    * Both preds and target are of shape ``(N,)``, and target is binary, while preds
+      are a float (binary)
+    * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and
+      is integer (multi-class)
+    * preds and target are of shape ``(N, ...)``, target is binary and preds is a float
+      (multi-label)
+    * preds are of shape ``(N, C, ...)`` and are floats, target is of shape ``(N, ...)``
+      and is integer (multi-dimensional multi-class)
+    * preds and target are of shape ``(N, ...)``, both are integers (multi-dimensional
+      multi-class)
+
+    To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out.
+
+    The returned output tensors will be binary tensors of the same shape, either ``(N, C)``
+    or ``(N, C, X)``; the details for each case are described below. The function also returns
+    a ``case`` string, which describes which of the above cases the inputs belonged to - regardless
+    of whether this was "overridden" by other settings (like ``multiclass``).
+
+    In the binary case, targets are normally returned as an ``(N, 1)`` tensor, while preds are transformed
+    into a binary tensor (elements become 1 if the probability is greater than or equal to
+    ``threshold``, or 0 otherwise). If ``multiclass=True``, then both targets and preds
+    become ``(N, 2)`` tensors by a one-hot transformation, with the thresholding being applied to
+    preds first.
+
+    In the multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets
+    by a one-hot transformation and preds by selecting the ``top_k`` largest entries (if their original
+    shape was ``(N, C)``). However, if ``multiclass=False``, then targets and preds will be
+    returned as an ``(N, 1)`` tensor.
+
+    In the multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with
+    preds being binarized as in the binary case. Here the ``C`` dimension is obtained by flattening
+    all dimensions after the first one. However, if ``multiclass=True``, then both are returned as
+    ``(N, 2, C)``, by an equivalent transformation as in the binary case.
+
+    In the multi-dimensional multi-class case, normally both target and preds are returned as
+    ``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and
+    ``C``. The transformations performed here are equivalent to the multi-class case. However, if
+    ``multiclass=False`` (and there are up to two classes), then the data is returned as
+    ``(N, X)`` binary tensors (multi-label).
+
+    Note:
+        Where a one-hot transformation needs to be performed and the number of classes
+        is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be
+        equal to ``num_classes``, if it is given, or the maximum label value in preds and
+        target.
+
+    Args:
+        preds: jnp.ndarray with predictions (labels or probabilities)
+        target: jnp.ndarray with ground truth labels, always integers (labels)
+        threshold:
+            Threshold value for transforming probability/logit predictions to binary
+            (0 or 1) predictions, in the case of binary or multi-label inputs.
+        num_classes:
+            Number of classes. If not explicitly set, the number of classes will be inferred
+            either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
+            tensor, where applicable.
+        top_k:
+            Number of highest probability entries for each sample to convert to 1s - relevant
+            only for (multi-dimensional) multi-class inputs with probability predictions. The
+            default value (``None``) will be interpreted as 1 for these inputs.
+
+            Should be left unset (``None``) for all other types of inputs.
+        multiclass:
+            Used only in certain special cases, where you want to treat inputs as a different type
+            than what they appear to be. See the parameter's
+            :ref:`documentation section `
+            for a more detailed explanation and examples.
+
+    Returns:
+        preds: binary tensor of shape ``(N, C)`` or ``(N, C, X)``
+        target: binary tensor of shape ``(N, C)`` or ``(N, C, X)``
+        case: The case the inputs fall in, one of ``'binary'``, ``'multi-class'``, ``'multi-label'`` or
+            ``'multi-dim multi-class'``
+    """
+    # Remove excess dimensions
+    preds, target = _input_squeeze(preds, target)
+
+    case = _check_classification_inputs(
+        preds,
+        target,
+        mode=mode,
+        threshold=threshold,
+        num_classes=num_classes,
+        multiclass=multiclass,
+        top_k=top_k,
+    )
+
+    if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k:
+        preds = (preds >= threshold).astype(jnp.int32)
+        num_classes = num_classes if not multiclass else 2
+
+    if case == DataType.MULTILABEL and top_k:
+        preds = select_topk(preds, top_k)
+
+    if case == DataType.MULTICLASS or multiclass:
+        if _is_floating_point(preds):
+            num_classes = preds.shape[1]
+            preds = select_topk(preds, top_k or 1)
+        else:
+            if num_classes is None:
+                raise ValueError(
+                    "Cannot infer the number of classes when `preds` are integers with no class"
+                    f" dimension, please specify `num_classes`, got shape {preds.shape}"
+                )
+            preds = jax.nn.one_hot(preds, max(2, num_classes))
+
+        target = jax.nn.one_hot(target, max(2, num_classes))  # type: ignore
+
+        if multiclass is False:
+            preds, target = preds[..., 1], target[..., 1]
+
+    if multiclass:  # keep the class axis, flatten everything else into the sample axis
+        target = einops.rearrange(target, "... N C -> (...) N C")
+        preds = einops.rearrange(preds, "... N C -> (...) N C")
+    else:
+        target = einops.rearrange(target, "... N -> (...) N")
+        preds = einops.rearrange(preds, "... N -> (...) N")
+
+    # Some operations above create an extra dimension for MC/binary case - this removes it
+    if preds.ndim > 2:
+        preds, target = preds.squeeze(0), target.squeeze(0)
+
+    return preds.astype(jnp.int32), target.astype(jnp.int32), case
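+
+
+# NOTE (illustrative sketch, not part of the original patch): integer multi-class
+# labels are one-hot encoded into matching (N, C) binary tensors, e.g.:
+#
+#   preds  = jnp.array([0, 2, 1])
+#   target = jnp.array([0, 1, 1])
+#   p, t, case = _input_format_classification(
+#       preds, target, mode=DataType.MULTICLASS, num_classes=3
+#   )
+#   # p.shape == t.shape == (3, 3), case is DataType.MULTICLASS,
+#   # and (p * t).sum() == 2 correct predictions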
+
+
+def _check_classification_inputs(
+    preds: jnp.ndarray,
+    target: jnp.ndarray,
+    threshold: float,
+    num_classes: tp.Optional[int],
+    multiclass: tp.Optional[bool],
+    top_k: tp.Optional[int],
+    mode: DataType,
+) -> DataType:
+    """Performs error checking on inputs for classification.
+
+    This ensures that preds and target take one of the shape/type combinations that are
+    specified in the ``_input_format_classification`` docstring. It also checks the cases of
+    overrides with ``multiclass`` by checking (for multi-class and multi-dim multi-class
+    cases) that there are only up to 2 distinct labels.
+
+    In the case where preds are floats (probabilities), it is checked whether they are in the [0,1] interval.
+
+    When ``num_classes`` is given, it is checked that it is consistent with input cases (binary,
+    multi-label, ...), and that, if available, the implied number of classes in the ``C``
+    dimension is consistent with it (as well as that the max label in target is smaller than it).
+
+    When ``num_classes`` is not specified in these cases, consistency of the highest target
+    value against the ``C`` dimension is checked for (multi-dimensional) multi-class cases.
+
+    If ``top_k`` is set (not None) for inputs that do not have probability predictions (and
+    are not binary), an error is raised. Similarly, if ``top_k`` is set to a number that
+    is higher than or equal to the ``C`` dimension of ``preds``, an error is raised.
+
+    Preds and target tensors are expected to be squeezed already - all dimensions should be
+    greater than 1, except perhaps the first one (``N``).
+
+    Args:
+        preds: jnp.ndarray with predictions (labels or probabilities)
+        target: jnp.ndarray with ground truth labels, always integers (labels)
+        threshold:
+            Threshold value for transforming probability/logit predictions to binary
+            (0,1) predictions, in the case of binary or multi-label inputs.
+        num_classes:
+            Number of classes. If not explicitly set, the number of classes will be inferred
+            either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
+            tensor, where applicable.
+        top_k:
+            Number of highest probability entries for each sample to convert to 1s - relevant
+            only for inputs with probability predictions. The default value (``None``) will be
+            interpreted as 1 for these inputs. If this parameter is set for multi-label inputs,
+            it will take precedence over threshold.
+
+            Should be left unset (``None``) for inputs with label predictions.
+        multiclass:
+            Used only in certain special cases, where you want to treat inputs as a different type
+            than what they appear to be. See the parameter's
+            :ref:`documentation section `
+            for a more detailed explanation and examples.
+
+    Return:
+        case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or
+            'multi-dim multi-class'
+    """
+
+    # Basic validation (that does not need case/type information)
+    _basic_input_validation(preds, target, threshold, multiclass)
+
+    # Check that shape/types fall into one of the cases
+    _check_shape_and_type_consistency(preds, target, mode)
+    implied_classes = preds.shape[-1] if preds.shape != target.shape else None
+
+    if (
+        implied_classes is not None
+        and num_classes is not None
+        and implied_classes != num_classes
+    ):
+        raise ValueError(
+            f"The number of classes implied by `preds` ({implied_classes}) does not"
+            f" match `num_classes` ({num_classes})."
+        )
+
+    # Check consistency with the `C` dimension in case of multi-class data
+    if preds.shape != target.shape:
+
+        if multiclass is False and implied_classes != 2:
+            raise ValueError(
+                "You have set `multiclass=False`, but have more than 2 classes in your data,"
+                " based on the C dimension of `preds`."
+            )
+
+    # Check that num_classes is consistent
+    if num_classes:
+
+        if mode == DataType.BINARY:
+            _check_num_classes_binary(num_classes, multiclass, implied_classes)
+        elif mode == DataType.MULTICLASS:
+            _check_num_classes_mc(
+                preds, target, num_classes, multiclass, implied_classes
+            )
+        elif mode == DataType.MULTILABEL:
+            assert implied_classes is not None
+            _check_num_classes_ml(num_classes, multiclass, implied_classes)
+
+    # Check that top_k is consistent
+    if top_k is not None:
+        assert implied_classes is not None
+        _check_top_k(
+            top_k, mode, implied_classes, multiclass, _is_floating_point(preds)
+        )
+
+    return mode
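+
+
+# NOTE (illustrative sketch, not part of the original patch): the checks above
+# fail fast on malformed inputs, e.g. float targets are rejected:
+#
+#   _check_classification_inputs(
+#       preds=jnp.array([0.9, 0.1]),
+#       target=jnp.array([1.0, 0.0]),  # floats: raises
+#       threshold=0.5, num_classes=None, multiclass=None, top_k=None,
+#       mode=DataType.BINARY,
+#   )
+#   # ValueError: The `target` has to be an integer tensor.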
+
+
+def _basic_input_validation(
+    preds: jnp.ndarray,
+    target: jnp.ndarray,
+    threshold: float,
+    multiclass: tp.Optional[bool],
+) -> None:
+    """Perform basic validation of inputs that does not require deducing any information on the type of inputs."""
+
+    if _is_floating_point(target):
+        raise ValueError("The `target` has to be an integer tensor.")
+    # if target.min() < 0:
+    #     raise ValueError("The `target` has to be a non-negative tensor.")
+
+    preds_float = _is_floating_point(preds)
+    # if not preds_float and preds.min() < 0:
+    #     raise ValueError("If `preds` are integers, they have to be non-negative.")
+
+    if not preds.shape[0] == target.shape[0]:
+        raise ValueError(
+            "The `preds` and `target` should have the same first dimension."
+        )
+
+    if multiclass is False and target.max() > 1:
+        raise ValueError(
+            "If you set `multiclass=False`, then `target` should not exceed 1."
+        )
+
+    if multiclass is False and not preds_float and preds.max() > 1:
+        raise ValueError(
+            "If you set `multiclass=False` and `preds` are integers, then `preds` should not exceed 1."
+        )
+
+
+def _is_floating_point(x: jnp.ndarray) -> bool:
+    """Check if the input is a floating point tensor."""
+    return x.dtype == jnp.float16 or x.dtype == jnp.float32 or x.dtype == jnp.float64
+
+
+def _check_shape_and_type_consistency(
+    preds: jnp.ndarray, target: jnp.ndarray, mode: DataType
+) -> None:
+    """This checks that the shape and type of inputs are consistent with each other and fall into one of the
+    allowed input types (see the docstring of ``_input_format_classification``). It does not check
+    for consistency of the number of classes, other functions take care of that.
+
+    It raises a ``ValueError`` if the inputs do not fall into one of the allowed cases.
+    """
+
+    preds_float = _is_floating_point(preds)
+
+    if preds.ndim == target.ndim:
+        if preds.shape != target.shape:
+            raise ValueError(
+                "The `preds` and `target` should have the same shape,"
+                f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}."
+            )
+
+    elif preds.ndim == target.ndim + 1:
+        if mode == DataType.BINARY:
+            raise ValueError(
+                "If `preds` have one extra dimension, then `mode` should not be `binary`."
+            )
+
+        if not preds_float:
+            raise ValueError(
+                "If `preds` have one dimension more than `target`, `preds` should be a float tensor."
+            )
+        if preds.shape[:-1] != target.shape:
+            raise ValueError(
+                "If `preds` have one dimension more than `target`, the shape of `preds` should be"
+                " (..., C), and the shape of `target` should be (...)."
+            )
+
+    else:
+        raise ValueError(
+            "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
+            " and `preds` should be (N, C, ...)."
+        )
+
+
+def _check_num_classes_binary(
+    num_classes: int, multiclass: tp.Optional[bool], implied_classes: tp.Optional[int]
+) -> None:
+    """This checks the consistency of `num_classes` with the data and the `multiclass` param for binary data."""
+
+    if implied_classes is not None and implied_classes != 2:
+        raise ValueError(
+            "If `preds` have one dimension more than `target`, then `num_classes` should be 2 for binary data."
+        )
+
+    if num_classes > 2:
+        raise ValueError("Your data is binary, but `num_classes` is larger than 2.")
+    if num_classes == 2 and not multiclass:
+        raise ValueError(
+            "Your data is binary and `num_classes=2`, but `multiclass` is not True."
+            " Set it to True if you want to transform binary data to multi-class format."
+        )
+    if num_classes == 1 and multiclass:
+        raise ValueError(
+            "You have binary data and have set `multiclass=True`, but `num_classes` is 1."
+            " Either set `multiclass=None` (default) or set `num_classes=2`"
+            " to transform binary data to multi-class format."
+        )
+
+
+def select_topk(prob_tensor: jnp.ndarray, topk: int = 1, dim: int = 1) -> jnp.ndarray:
+    """Convert a probability tensor to binary by selecting the top-k highest entries.
+
+    Args:
+        prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the
+            position defined by the ``dim`` argument
+        topk: number of highest entries to turn into 1s
+        dim: dimension on which to compare entries
+
+    Returns:
+        A binary tensor of the same shape as the input tensor of type ``jnp.uint32``
+
+    Example:
+        >>> x = jnp.array([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
+        >>> select_topk(x, topk=2).tolist()
+        [[0, 1, 1], [1, 1, 0]]
+    """
+
+    if prob_tensor.ndim > 2:
+        raise NotImplementedError(
+            "Support for arrays with more than 2 dimensions is not yet supported"
+        )
+
+    zeros = jnp.zeros(prob_tensor.shape, dtype=jnp.uint32)
+    idx_axis0 = jnp.expand_dims(jnp.arange(prob_tensor.shape[0]), axis=1)
+    _, idx_axis1 = jax.lax.top_k(prob_tensor, topk)
+
+    return jax.ops.index_update(zeros, jax.ops.index[idx_axis0, idx_axis1], 1)
+
+
+def _check_num_classes_mc(
+    preds: jnp.ndarray,
+    target: jnp.ndarray,
+    num_classes: int,
+    multiclass: tp.Optional[bool],
+    implied_classes: tp.Optional[int],
+) -> None:
+    """This checks the consistency of `num_classes` with the data and the `multiclass` param for (multi-
+    dimensional) multi-class data."""
+
+    if num_classes == 1 and multiclass is not False:
+        raise ValueError(
+            "You have set `num_classes=1`, but predictions are integers."
+            " If you want to convert (multi-dimensional) multi-class data with 2 classes"
+            " to binary/multi-label, set `multiclass=False`."
+        )
+    if num_classes > 1:
+        if multiclass is False and implied_classes != num_classes:
+            raise ValueError(
+                "You have set `multiclass=False`, but the implied number of classes "
+                " (from shape of inputs) does not match `num_classes`. If you are trying to"
+                " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`"
+                " should be either None or the product of the size of extra dimensions (...)."
+                " See Input Types in Metrics documentation."
+            )
+        # if num_classes <= target.max():
+        #     raise ValueError(
+        #         "The highest label in `target` should be smaller than `num_classes`."
+        #     )
+        if preds.shape != target.shape and num_classes != implied_classes:
+            raise ValueError(
+                "The size of the C dimension of `preds` does not match `num_classes`."
+            )
+
+
+def _check_num_classes_ml(
+    num_classes: int, multiclass: tp.Optional[bool], implied_classes: int
+) -> None:
+    """This checks the consistency of `num_classes` with the data and the `multiclass` param for multi-label
+    data."""
+
+    if multiclass and num_classes != 2:
+        raise ValueError(
+            "You have set `multiclass=True`, but `num_classes` is not equal to 2."
+            " If you are trying to transform multi-label data to 2 class multi-dimensional"
+            " multi-class, you should set `num_classes` to either 2 or None."
+        )
+    if not multiclass and num_classes != implied_classes:
+        raise ValueError(
+            "The implied number of classes (from shape of inputs) does not match num_classes."
+        )
+
+
+def _check_top_k(
+    top_k: int,
+    case: DataType,
+    implied_classes: int,
+    multiclass: tp.Optional[bool],
+    preds_float: bool,
+) -> None:
+    if case == DataType.BINARY:
+        raise ValueError("You can not use `top_k` parameter with binary data.")
+    if not isinstance(top_k, int) or top_k <= 0:
+        raise ValueError("The `top_k` has to be an integer larger than 0.")
+    if not preds_float:
+        raise ValueError(
+            "You have set `top_k`, but you do not have probability predictions."
+        )
+    if multiclass is False:
+        raise ValueError("If you set `multiclass=False`, you can not set `top_k`.")
+    if case == DataType.MULTILABEL and multiclass:
+        raise ValueError(
+            "If you want to transform multi-label data to 2 class multi-dimensional"
+            " multi-class data using `multiclass=True`, you can not use `top_k`."
+        )
+    if top_k >= implied_classes:
+        raise ValueError(
+            "The `top_k` has to be strictly smaller than the `C` dimension of `preds`."
+        )
+
+
+def _accuracy_compute(
+    tp: jnp.ndarray,
+    fp: jnp.ndarray,
+    tn: jnp.ndarray,
+    fn: jnp.ndarray,
+    average: tp.Optional[AverageMethod],
+    mdmc_average: tp.Optional[MDMCAverageMethod],
+    mode: DataType,
+) -> jnp.ndarray:
+    """Computes accuracy from stat scores: true positives, false positives, true negatives, false negatives.
+
+    Args:
+        tp: True positives
+        fp: False positives
+        tn: True negatives
+        fn: False negatives
+        average: Defines the reduction that is applied.
+        mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
+            ``average`` parameter).
+        mode: Mode of the input tensors
+
+    Example:
+        >>> preds = jnp.array([0, 2, 1, 3])
+        >>> target = jnp.array([0, 1, 2, 3])
+        >>> tp, fp, tn, fn = _stat_scores_update(
+        ...     preds,
+        ...     target,
+        ...     intended_mode=DataType.MULTICLASS,
+        ...     average_method=AverageMethod.MICRO,
+        ...     mdmc_average_method=MDMCAverageMethod.GLOBAL,
+        ...     num_classes=4,
+        ... )
+        >>> acc = _accuracy_compute(
+        ...     tp, fp, tn, fn, AverageMethod.MICRO, MDMCAverageMethod.GLOBAL, DataType.MULTICLASS
+        ... )
+        >>> round(float(acc), 4)  # 2 of 4 predictions are correct
+        0.5
+
+        >>> target = jnp.array([0, 1, 2])
+        >>> preds = jnp.array([[0.1, 0.9, 0.0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
+        >>> tp, fp, tn, fn = _stat_scores_update(
+        ...     preds,
+        ...     target,
+        ...     intended_mode=DataType.MULTICLASS,
+        ...     average_method=AverageMethod.MICRO,
+        ...     mdmc_average_method=MDMCAverageMethod.GLOBAL,
+        ...     top_k=2,
+        ... )
+        >>> acc = _accuracy_compute(
+        ...     tp, fp, tn, fn, AverageMethod.MICRO, MDMCAverageMethod.GLOBAL, DataType.MULTICLASS
+        ... )
+        >>> round(float(acc), 4)  # top-2 accuracy
+        0.6667
+    """
+
+    if (
+        mode == DataType.BINARY
+        and average in [AverageMethod.MICRO, AverageMethod.SAMPLES]
+    ) or mode == DataType.MULTILABEL:
+        numerator = tp + tn
+        denominator = tp + tn + fp + fn
+    else:
+        numerator = tp
+        denominator = tp + fn
+
+    if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
+        cond = tp + fp + fn == 0
+        numerator = numerator[~cond]
+        denominator = denominator[~cond]
+
+    if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
+        # a class is not present if there exist no TPs, no FPs, and no FNs;
+        # mark it with -1 functionally (jnp arrays do not support in-place assignment)
+        meaningless_mask = (tp | fn | fp) == 0
+        numerator = jnp.where(meaningless_mask, -1, numerator.astype(jnp.int32))
+        denominator = jnp.where(meaningless_mask, -1, denominator.astype(jnp.int32))
+
+    return _reduce_stat_scores(
+        numerator=numerator,
+        denominator=denominator,
+        weights=None if average != AverageMethod.WEIGHTED else tp + fn,
+        average=average,
+        mdmc_average=mdmc_average,
+    )
+
+
+def _reduce_stat_scores(
+    numerator: jnp.ndarray,
+    denominator: jnp.ndarray,
+    weights: tp.Optional[jnp.ndarray],
+    average: tp.Optional[AverageMethod],
+    mdmc_average: tp.Optional[MDMCAverageMethod],
+    zero_division: int = 0,
+) -> jnp.ndarray:
+    """Reduces scores of type ``numerator/denominator``, or
+    ``weights * (numerator/denominator)`` if ``average='weighted'``.
+
+    Args:
+        numerator: A tensor with numerator numbers.
+        denominator: A tensor with denominator numbers. If a denominator is
+            negative, the class will be ignored (if averaging), or its score
+            will be returned as ``nan`` (if ``average=None``).
+            If the denominator is zero, then the ``zero_division`` score will be
+            used for those elements.
+        weights: A tensor of weights to be used if ``average='weighted'``.
+        average: The method to average the scores
+        mdmc_average: The method to average the scores if inputs were multi-dimensional multi-class (MDMC)
+        zero_division: The value to use for the score if the denominator equals zero.
+    """
+    numerator, denominator = numerator.astype(jnp.float32), denominator.astype(
+        jnp.float32
+    )
+    zero_div_mask = denominator == 0
+    ignore_mask = denominator < 0
+
+    if weights is None:
+        weights_ = jnp.ones_like(denominator)
+    else:
+        weights_ = weights.astype(jnp.float32)
+
+    numerator = jnp.where(
+        zero_div_mask,
+        jnp.array(float(zero_division)),
+        numerator,
+    )
+    denominator = jnp.where(
+        zero_div_mask | ignore_mask,
+        jnp.array(1.0, dtype=denominator.dtype),
+        denominator,
+    )
+    weights_ = jnp.where(ignore_mask, jnp.array(0.0, dtype=weights_.dtype), weights_)
+
+    if average not in (AverageMethod.MICRO, AverageMethod.NONE, None):
+        weights_ = weights_ / weights_.sum(axis=-1, keepdims=True)
+
+    scores = weights_ * (numerator / denominator)
+
+    # Handles the case where the sum of weights is 0, which happens when we ignore
+    # the only present class with average='weighted'
+    scores = jnp.where(
+        jnp.isnan(scores), jnp.array(float(zero_division), dtype=scores.dtype), scores
+    )
+
+    if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
+        scores = scores.mean(axis=0)
+        ignore_mask = ignore_mask.sum(axis=0).astype(jnp.bool_)
+
+    if average in (AverageMethod.NONE, None):
+        scores = jnp.where(
+            ignore_mask, jnp.array(float("nan"), dtype=scores.dtype), scores
+        )
+    else:
+        scores = scores.sum()
+
+    return scores
+ """ + numerator, denominator = numerator.astype(jnp.float32), denominator.astype( + jnp.float32 + ) + zero_div_mask = denominator == 0 + ignore_mask = denominator < 0 + + if weights is None: + weights_ = jnp.ones_like(denominator) + else: + weights_ = weights.astype(jnp.float32) + + numerator = jnp.where( + zero_div_mask, + jnp.array(float(zero_division)), + numerator, + ) + denominator = jnp.where( + zero_div_mask | ignore_mask, + jnp.array(1.0, dtype=denominator.dtype), + denominator, + ) + weights_ = jnp.where(ignore_mask, jnp.array(0.0, dtype=weights_.dtype), weights_) + + if average not in (AverageMethod.MICRO, AverageMethod.NONE, None): + weights_ = weights_ / weights_.sum(axis=-1, keepdims=True) + + scores = weights_ * (numerator / denominator) + + # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted' + scores = jnp.where( + jnp.isnan(scores), jnp.array(float(zero_division), dtype=scores.dtype), scores + ) + + if mdmc_average == MDMCAverageMethod.SAMPLEWISE: + scores = scores.mean(axis=0) + ignore_mask = ignore_mask.sum(axis=0).astype(jnp.bool_) + + if average in (AverageMethod.NONE, None): + scores = jnp.where( + ignore_mask, jnp.array(float("nan"), dtype=scores.dtype), scores + ) + else: + scores = scores.sum() + + return scores From df501680f51f63340e6fc33e82a386ec67ab0052 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 18 Oct 2021 16:15:59 -0500 Subject: [PATCH 2/3] add Metrics composer --- tests/metrics/test_accuracy.py | 4 +- tests/metrics/test_metrics.py | 64 ++++++++++++++++++++++++++++++++ treex/metrics/__init__.py | 1 + treex/metrics/metric.py | 2 +- treex/metrics/metrics.py | 64 ++++++++++++++++++++++++++++++++ treex/types.py | 1 + treex/utils.py | 68 ++++++++++++++-------------------- 7 files changed, 161 insertions(+), 43 deletions(-) create mode 100644 tests/metrics/test_metrics.py create mode 100644 treex/metrics/metrics.py diff --git a/tests/metrics/test_accuracy.py b/tests/metrics/test_accuracy.py index b5dee6c9..58e55747 100644 --- a/tests/metrics/test_accuracy.py +++ b/tests/metrics/test_accuracy.py @@ -13,7 +13,7 @@ class TestAccuracy: def test_jit(self): N = 0 - # @jax.jit + @jax.jit def f(m, y_true, y_pred): nonlocal N N += 1 @@ -29,7 +29,7 @@ def f(m, y_true, y_pred): assert metric.compute() == 0.8 metric = f(metric, y_true, y_pred) - # assert N == 1 + assert N == 1 assert metric.compute() == 0.8 def test_logits_preds(self): diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py new file mode 100644 index 00000000..bfbb1944 --- /dev/null +++ b/tests/metrics/test_metrics.py @@ -0,0 +1,64 @@ +import jax +import jax.numpy as jnp +import pytest + +import treex as tx +from treex import metrics + + +class TestAccuracy: + def test_list(self): + + N = 0 + + @jax.jit + def f(m, y_true, y_pred): + nonlocal N + N += 1 + m(y_true=y_true, y_pred=y_pred) + return m + + metrics = tx.metrics.Metrics( + [ + tx.metrics.Accuracy(num_classes=10), + tx.metrics.Accuracy(num_classes=10), + ] + ) + y_true = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :] + y_pred = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :] + + metrics = f(metrics, y_true, y_pred) + assert N == 1 + assert metrics.compute() == {"accuracy": 0.8, "accuracy2": 0.8} + + metrics = f(metrics, y_true, y_pred) + assert N == 1 + assert metrics.compute() == {"accuracy": 0.8, "accuracy2": 0.8} + + def test_dict(self): + + N = 0 + + @jax.jit + def f(m, y_true, y_pred): + nonlocal N + N += 1 + 
m(y_true=y_true, y_pred=y_pred) + return m + + metrics = tx.metrics.Metrics( + dict( + a=tx.metrics.Accuracy(num_classes=10), + b=tx.metrics.Accuracy(num_classes=10), + ) + ) + y_true = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :] + y_pred = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :] + + metrics = f(metrics, y_true, y_pred) + assert N == 1 + assert metrics.compute() == {"a/accuracy": 0.8, "b/accuracy": 0.8} + + metrics = f(metrics, y_true, y_pred) + assert N == 1 + assert metrics.compute() == {"a/accuracy": 0.8, "b/accuracy": 0.8} diff --git a/treex/metrics/__init__.py b/treex/metrics/__init__.py index 54a32a01..aa2ccc5c 100644 --- a/treex/metrics/__init__.py +++ b/treex/metrics/__init__.py @@ -1,2 +1,3 @@ from .accuracy import Accuracy from .metric import Metric +from .metrics import Metrics diff --git a/treex/metrics/metric.py b/treex/metrics/metric.py index a868f85a..dd24f168 100644 --- a/treex/metrics/metric.py +++ b/treex/metrics/metric.py @@ -80,7 +80,7 @@ def reset(self): self.__dict__.update(self._initial_state) @abstractmethod - def update(self, **kwargs): + def update(self, **kwargs) -> None: ... @abstractmethod diff --git a/treex/metrics/metrics.py b/treex/metrics/metrics.py new file mode 100644 index 00000000..023fb8e3 --- /dev/null +++ b/treex/metrics/metrics.py @@ -0,0 +1,64 @@ +import typing as tp + +import jax.numpy as jnp + +from treex import types, utils +from treex.metrics.metric import Metric + + +class Metrics(Metric): + metrics: tp.Dict[str, Metric] + + def __init__( + self, + modules: tp.Any, + on: tp.Optional[types.IndexLike] = None, + name: tp.Optional[str] = None, + dtype: tp.Optional[jnp.dtype] = None, + ): + super().__init__(on=on, name=name, dtype=dtype) + + names: tp.Set[str] = set() + + def get_name(path, metric): + name = utils._get_name(metric) + return f"{path}/{name}" if path else name + + self.metrics = { + utils._unique_name(names, get_name(path, metric)): metric + for path, metric in utils._flatten_names(modules) + } + + def update(self, **kwargs) -> None: + for name, metric in self.metrics.items(): + update_kwargs = utils._function_argument_names(metric.update) + + if update_kwargs is None: + metric_kwargs = kwargs + + else: + metric_kwargs = {} + + for arg in update_kwargs: + if arg not in kwargs: + raise ValueError(f"Missing argument {arg} for metric {name}") + + metric_kwargs[arg] = kwargs[arg] + + metric.update(**metric_kwargs) + + def compute(self) -> tp.Any: + outputs = {} + names = set() + + for name, metric in self.metrics.items(): + + value = metric.compute() + + # use fresh loop variables so `name` and `value` are not clobbered + # when a child metric returns multiple outputs + for path, inner_value in utils._flatten_names(value): + inner_name = f"{name}/{path}" if path else name + inner_name = utils._unique_name(names, inner_name) + + outputs[inner_name] = inner_value + + return outputs diff --git a/treex/types.py b/treex/types.py index ccd74460..e68dc974 100644 --- a/treex/types.py +++ b/treex/types.py @@ -13,6 +13,7 @@ InputLike = tp.Union[tp.Any, tp.Tuple[tp.Any, ...], tp.Dict[str, tp.Any], "Inputs"] IndexLike = tp.Union[str, int, tp.Sequence[tp.Union[str, int]]] +PathLike = tp.Tuple[IndexLike, ...]
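+# a PathLike is a tuple of str/int keys that locates a leaf inside a nested +# structure (consumed by utils._flatten_names in this patch series)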
# ----------------------------------------- # TreeParts diff --git a/treex/utils.py b/treex/utils.py index a10b5b73..38579ed2 100644 --- a/treex/utils.py +++ b/treex/utils.py @@ -12,6 +12,7 @@ import yaml from jax._src.numpy.lax_numpy import split from rich.console import Console +from treeo.utils import _get_name, _lower_snake_case, _unique_name, _unique_names from treex import types @@ -456,51 +457,38 @@ def wrapper(*args, **kwargs): return wrapper -def _unique_name( - names: tp.Set[str], - name: str, -): +def _flatten_names(inputs: tp.Any) -> tp.List[tp.Tuple[str, tp.Any]]: + return [ + ("/".join(map(str, path)), value) + for path, value in _flatten_names_helper((), inputs) + ] - if name in names: - i = 1 - while f"{name}_{i}" in names: - i += 1 - name = f"{name}_{i}" - - names.add(name) - return name - - -def _unique_names( - names: tp.Iterable[str], -) -> tp.Iterable[str]: - new_names: tp.Set[str] = set() - - for name in names: - yield _unique_name(new_names, name) +def _flatten_names_helper( + path: types.PathLike, inputs: tp.Any +) -> tp.Iterable[tp.Tuple[types.PathLike, tp.Any]]: + if isinstance(inputs, (tp.Tuple, tp.List)): + for value in inputs: + # list/tuple entries share the parent path; they do not add an index + yield from _flatten_names_helper(path, value) + elif isinstance(inputs, tp.Dict): + for name, value in inputs.items(): + yield from _flatten_names_helper(path + (name,), value) + else: + yield (path, inputs) -def _lower_snake_case(s: str) -> str: - s = re.sub(r"(?<!^)(?=[A-Z])", "_", s).lower() - parts = s.split("_") - output_parts = [parts[0]] - for i in range(1, len(parts)): - if len(parts[i]) > 1: - output_parts.append(parts[i]) - else: - output_parts[-1] += parts[i] +def _function_argument_names(f) -> tp.Optional[tp.List[str]]: + """ + Returns: + A list of argument names, or None if the function accepts variable keyword arguments (``**kwargs``). + """ + kwarg_names = [] - return "_".join(output_parts) + for k, v in inspect.signature(f).parameters.items(): + if v.kind == inspect.Parameter.VAR_KEYWORD: + return None + kwarg_names.append(k) -def _get_name(obj) -> str: - if hasattr(obj, "name") and obj.name: - return obj.name - elif hasattr(obj, "__name__") and obj.__name__: - return obj.__name__ - elif hasattr(obj, "__class__") and obj.__class__.__name__: - return _lower_snake_case(obj.__class__.__name__) - else: - raise ValueError(f"Could not get name for: {obj}") + return kwarg_names From 3e7762aa26b5f022df33f7b734cba86be23a0717 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 18 Oct 2021 16:56:05 -0500 Subject: [PATCH 3/3] export Mean and Reduce --- treex/metrics/__init__.py | 2 ++ treex/metrics/{old => }/mean.py | 0 treex/metrics/{old => }/reduce.py | 2 +- treex/metrics/tm_port/classification/accuracy.py | 2 +- treex/metrics/tm_port/classification/stat_scores.py | 2 +- .../tm_port/functional/classification/stat_scores.py | 8 ++++---- treex/metrics/tm_port/utilities/checks.py | 2 +- treex/metrics/tm_port/utilities/data.py | 2 +- 8 files changed, 11 insertions(+), 9 deletions(-) rename treex/metrics/{old => }/mean.py (100%) rename treex/metrics/{old => }/reduce.py (98%) diff --git a/treex/metrics/__init__.py b/treex/metrics/__init__.py index aa2ccc5c..d81a6331 100644 --- a/treex/metrics/__init__.py +++ b/treex/metrics/__init__.py @@ -1,3 +1,5 @@ from .accuracy import Accuracy +from .mean import Mean from .metric import Metric from .metrics import Metrics +from .reduce import Reduce, Reduction diff --git a/treex/metrics/old/mean.py b/treex/metrics/mean.py similarity index 100% rename from treex/metrics/old/mean.py rename to treex/metrics/mean.py diff --git a/treex/metrics/old/reduce.py b/treex/metrics/reduce.py similarity index 98% rename from
treex/metrics/old/reduce.py rename to treex/metrics/reduce.py index 0722624f..fc0a71b4 100644 --- a/treex/metrics/old/reduce.py +++ b/treex/metrics/reduce.py @@ -40,7 +40,7 @@ def __init__( Reduction.sum_over_batch_size, Reduction.weighted_mean, ): - self.count = jnp.array(0, dtype=jnp.uint64) + self.count = jnp.array(0, dtype=jnp.uint32) else: self.count = None diff --git a/treex/metrics/tm_port/classification/accuracy.py b/treex/metrics/tm_port/classification/accuracy.py index 4cfc1404..0c17e784 100644 --- a/treex/metrics/tm_port/classification/accuracy.py +++ b/treex/metrics/tm_port/classification/accuracy.py @@ -203,7 +203,7 @@ def __init__( dtype=dtype, ) - self.correct = tensor(0, dtype=jnp.uint64) # , dist_reduce_fx="sum") + self.correct = tensor(0, dtype=jnp.uint32) # , dist_reduce_fx="sum") self.total = tensor(0, dtype=jnp.float32) # , dist_reduce_fx="sum") if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): diff --git a/treex/metrics/tm_port/classification/stat_scores.py b/treex/metrics/tm_port/classification/stat_scores.py index 48d75907..570ff187 100644 --- a/treex/metrics/tm_port/classification/stat_scores.py +++ b/treex/metrics/tm_port/classification/stat_scores.py @@ -183,7 +183,7 @@ def __init__( zeros_shape = [num_classes] else: raise ValueError(f'Wrong reduce="{reduce}"') - default = lambda: jnp.zeros(zeros_shape, dtype=jnp.uint64) + default = lambda: jnp.zeros(zeros_shape, dtype=jnp.uint32) reduce_fn = "sum" for s in ("tp", "fp", "tn", "fn"): diff --git a/treex/metrics/tm_port/functional/classification/stat_scores.py b/treex/metrics/tm_port/functional/classification/stat_scores.py index 2fa712ed..ac473bf1 100644 --- a/treex/metrics/tm_port/functional/classification/stat_scores.py +++ b/treex/metrics/tm_port/functional/classification/stat_scores.py @@ -59,10 +59,10 @@ def _stat_scores( fn = (false_pred * neg_pred).sum(axis=dim) return ( - tp.astype(jnp.uint64), - fp.astype(jnp.uint64), - tn.astype(jnp.uint64), - fn.astype(jnp.uint64), + tp.astype(jnp.uint32), + fp.astype(jnp.uint32), + tn.astype(jnp.uint32), + fn.astype(jnp.uint32), ) diff --git a/treex/metrics/tm_port/utilities/checks.py b/treex/metrics/tm_port/utilities/checks.py index f0f38167..2158dda5 100644 --- a/treex/metrics/tm_port/utilities/checks.py +++ b/treex/metrics/tm_port/utilities/checks.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp import numpy as np -from numpy import uint64 +from numpy import uint32 Tensor = jnp.ndarray diff --git a/treex/metrics/tm_port/utilities/data.py b/treex/metrics/tm_port/utilities/data.py index 5af35bf0..821f1799 100644 --- a/treex/metrics/tm_port/utilities/data.py +++ b/treex/metrics/tm_port/utilities/data.py @@ -64,7 +64,7 @@ def _flatten(x: Sequence) -> list: # # device=label_tensor.device, # ) # index = jnp.broadcast_to( -# label_tensor.astype(jnp.uint64)[:None], tensor_onehot.shape +# label_tensor.astype(jnp.uint32)[:None], tensor_onehot.shape # ) # return tensor_onehot.scatter_(1, index, 1.0)
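
Usage sketch (not part of the patches above): a minimal eager example of the Metrics composer introduced in PATCH 2/3, mirroring the dict case exercised in tests/metrics/test_metrics.py; the input shapes and values are taken from those tests, here called without jax.jit.

    import jax.numpy as jnp
    import treex as tx

    # compose child metrics under named prefixes; compute() returns a flat dict
    metrics = tx.metrics.Metrics(
        dict(
            a=tx.metrics.Accuracy(num_classes=10),
            b=tx.metrics.Accuracy(num_classes=10),
        )
    )

    y_true = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
    y_pred = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]

    # update() forwards to each child only the keyword arguments its own
    # update() declares (see _function_argument_names in PATCH 2/3)
    metrics(y_true=y_true, y_pred=y_pred)
    metrics.compute()  # {"a/accuracy": 0.8, "b/accuracy": 0.8}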