Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metrics docs #2184

Merged
merged 81 commits into from
Jun 16, 2020
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
acbc1be
fixes
williamFalcon Jun 13, 2020
90e3c0d
fixes
williamFalcon Jun 13, 2020
a845179
fixes
williamFalcon Jun 13, 2020
70406bd
fixes
williamFalcon Jun 13, 2020
d8194f1
fixes
williamFalcon Jun 13, 2020
1c75220
fixes
williamFalcon Jun 13, 2020
5a9ddb2
fixes
williamFalcon Jun 13, 2020
8386615
fixes
williamFalcon Jun 14, 2020
da1f5ab
fixes
williamFalcon Jun 14, 2020
b7480e4
fixes
williamFalcon Jun 14, 2020
38777d3
fixes
williamFalcon Jun 14, 2020
6d51f55
fixes
williamFalcon Jun 14, 2020
a501d8b
fixes
williamFalcon Jun 14, 2020
dce2473
fixes
williamFalcon Jun 14, 2020
0f6eafa
fixes
williamFalcon Jun 14, 2020
5aa02eb
fixes
williamFalcon Jun 14, 2020
3f0ee25
fixes
williamFalcon Jun 14, 2020
b4d648f
fixes
williamFalcon Jun 14, 2020
1e5d7c0
fixes
williamFalcon Jun 14, 2020
16d8f0d
fixes
williamFalcon Jun 15, 2020
bcaa743
Apply suggestions from code review
Borda Jun 15, 2020
69bd7cb
add workers fix
williamFalcon Jun 15, 2020
ec277e9
add workers fix
williamFalcon Jun 15, 2020
0519206
add workers fix
williamFalcon Jun 15, 2020
d8a51b0
add workers fix
williamFalcon Jun 15, 2020
5be3a8b
add workers fix
williamFalcon Jun 15, 2020
faf362b
add workers fix
williamFalcon Jun 15, 2020
eef0c54
add workers fix
williamFalcon Jun 15, 2020
43876c0
add workers fix
williamFalcon Jun 15, 2020
992409a
add workers fix
williamFalcon Jun 15, 2020
94b1382
add workers fix
williamFalcon Jun 15, 2020
f946650
add workers fix
williamFalcon Jun 15, 2020
c0aea3f
add workers fix
williamFalcon Jun 15, 2020
bc9007d
add workers fix
williamFalcon Jun 15, 2020
55409fe
Update docs/source/metrics.rst
williamFalcon Jun 15, 2020
30c8148
Update docs/source/metrics.rst
williamFalcon Jun 15, 2020
4d70dcf
Update docs/source/metrics.rst
williamFalcon Jun 15, 2020
46c8c6f
Update docs/source/metrics.rst
williamFalcon Jun 15, 2020
e7f9b85
doctests
Borda Jun 15, 2020
d3e792c
add workers fix
williamFalcon Jun 15, 2020
596563a
add workers fix
williamFalcon Jun 15, 2020
3b0c296
fixes
williamFalcon Jun 15, 2020
41be1a7
fix docs
Jun 15, 2020
c7c88cb
fixes
williamFalcon Jun 15, 2020
fef6572
fixes
williamFalcon Jun 15, 2020
414dd81
fixes
williamFalcon Jun 15, 2020
8d2f9c5
fixes
williamFalcon Jun 15, 2020
de713fe
fixes
williamFalcon Jun 15, 2020
1f586bf
fixes
williamFalcon Jun 16, 2020
2c6ed8b
fixes
williamFalcon Jun 16, 2020
2422461
fixes
williamFalcon Jun 16, 2020
bb9c258
fixes
williamFalcon Jun 16, 2020
ae601a6
fixes
williamFalcon Jun 16, 2020
8ecc4a0
fixes
williamFalcon Jun 16, 2020
705cf63
fixes
williamFalcon Jun 16, 2020
022625e
fixes
williamFalcon Jun 16, 2020
babb117
fixes
williamFalcon Jun 16, 2020
6204316
fixes
williamFalcon Jun 16, 2020
2d9b90b
fixes
williamFalcon Jun 16, 2020
fb26791
fixes
williamFalcon Jun 13, 2020
14cf60a
Apply suggestions from code review
Borda Jun 15, 2020
d2d963d
add workers fix
williamFalcon Jun 15, 2020
464f367
Update docs/source/metrics.rst
williamFalcon Jun 15, 2020
95446e4
doctests
Borda Jun 15, 2020
a32fcf4
add workers fix
williamFalcon Jun 15, 2020
5fbda6e
fixes
williamFalcon Jun 15, 2020
339814f
fix docs
Jun 15, 2020
8d9f53a
fixes
williamFalcon Jun 15, 2020
8ce3687
fix doctests
Borda Jun 16, 2020
2a11c3e
Apply suggestions from code review
Borda Jun 16, 2020
d71c33a
fix doctests
Borda Jun 16, 2020
7472031
fix examples
Borda Jun 16, 2020
40bd93d
bug
Borda Jun 16, 2020
78e1198
Update docs/source/metrics.rst
williamFalcon Jun 16, 2020
ca4392a
Update docs/source/metrics.rst
williamFalcon Jun 16, 2020
be8d587
Update docs/source/metrics.rst
williamFalcon Jun 16, 2020
e9a90a0
fixes
williamFalcon Jun 16, 2020
2e832af
fixes
williamFalcon Jun 16, 2020
db0dc42
fixes
williamFalcon Jun 16, 2020
e7d8e50
fixes
williamFalcon Jun 16, 2020
a85818c
fixes
williamFalcon Jun 16, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
323 changes: 319 additions & 4 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,319 @@
.. automodule:: pytorch_lightning.metrics
:members:
:noindex:
:exclude-members:
.. testsetup:: *

from torch.nn import Module
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.metrics import TensorMetric, NumpyMetric

Metrics
=======
This is a general package for PyTorch Metrics. These can also be used with regular non-lightning PyTorch code.
Metrics are used to monitor model performance.

In this package we provide two major pieces of functionality.

1. A Metric class you can use to implement metrics with built-in distributed (ddp) support which are device agnostic.
2. A collection of popular metrics already implemented for you.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we mention that the package also comes with a interface to sklearn metrics (with a warning that this is slow due to casting back-and-forth of the tensors)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should. We should also include that sklearn has to be installed separately

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make it clear how to import the different backends. Maybe something like:

Use native backend
import pytorch_lightning.metrics.native as plm

Use sklearn backend
import pytorch_lightning.metrics.sklearn as plm

Use default (native if available else sklearn)
import pytorch_lightning.metrics as plm

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SkafteNicki can you add the sklearn details in here?

Example:

.. testcode::

from pytorch_lightning.metrics.functional import accuracy
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these imports wrong?
@SkafteNicki @justusschock is there a better way?


pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])

# calculates accuracy across all GPUs and all Nodes used in training
accuracy(pred, target)

.. testoutput::

tensor(0.7500)

--------------

Implement a metric
------------------
You can implement metrics as either a PyTorch metric or a Numpy metric. Numpy metrics
will slow down training, use PyTorch metrics when possible.

Use :class:`TensorMetric` to implement native PyTorch metrics. This class
handles automated DDP syncing and converts all inputs and outputs to tensors.

Use :class:`NumpyMetric` to implement numpy metrics. This class
handles automated DDP syncing and converts all inputs and outputs to tensors.

.. warning:: Numpy metrics might slow down your training substantially,
williamFalcon marked this conversation as resolved.
Show resolved Hide resolved
since every metric computation requires a GPU sync to convert tensors to numpy.

TensorMetric
^^^^^^^^^^^^
Here's an example showing how to implement a TensorMetric

.. testcode::

class RMSE(TensorMetric):
def forward(self, x, y):
return torch.sqrt(torch.mean(torch.pow(x-y, 2.0)))

.. autoclass:: pytorch_lightning.metrics.metric.TensorMetric
:noindex:

NumpyMetric
^^^^^^^^^^^
Here's an example showing how to implement a NumpyMetric

.. testcode::

class RMSE(NumpyMetric):
def forward(self, x, y):
return np.sqrt(np.mean(np.power(x-y, 2.0)))


.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric
:noindex:

--------------

Class Metrics
-------------
The following are metrics which can be instantiated as part of a module definition (even with just
plain PyTorch).

.. testcode::

from pytorch_lightning.metrics import Accuracy

# Plain PyTorch
class MyModule(Module):
def __init__(self):
super().__init__()
self.metric = Accuracy()

def forward(self, x, y):
y_hat = # ...
williamFalcon marked this conversation as resolved.
Show resolved Hide resolved
acc = self.metric(y_hat, y)

# PyTorch Lightning
class MyModule(LightningModule):
def __init__(self):
super().__init__()
self.metric = Accuracy()

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = # ...
williamFalcon marked this conversation as resolved.
Show resolved Hide resolved
acc = self.metric(y_hat, y)

These metrics even work when using distributed training:

.. code-block:: python

model = MyModule()
trainer = Trainer(gpus=8, num_nodes=2)

# any metric automatically reduces across GPUs (even the ones you implement using Lightning)
trainer.fit(model)

Accuracy
^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
:noindex:

AveragePrecision
^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
:noindex:

AUROC
^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.AUROC
:noindex:

ConfusionMatrix
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
:noindex:

DiceCoefficient
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.DiceCoefficient
:noindex:

F1
^^

.. autoclass:: pytorch_lightning.metrics.classification.F1
:noindex:

FBeta
^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.FBeta
:noindex:

PrecisionRecall
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecall
:noindex:

Precision
^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Precision
:noindex:

Recall
^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Recall
:noindex:

ROC
^^^

.. autoclass:: pytorch_lightning.metrics.classification.ROC
:noindex:

MulticlassROC
^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.MulticlassROC
:noindex:

MulticlassPrecisionRecall
^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.MulticlassPrecisionRecall
:noindex:

--------------

Functional Metrics
------------------

accuracy (F)
^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.accuracy
:noindex:

auc (F)
^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.auc
:noindex:

auroc (F)
^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.auroc
:noindex:

average_precision (F)
^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.average_precision
:noindex:

confusion_matrix (F)
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix
:noindex:

dice_score (F)
^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.dice_score
:noindex:

f1_score (F)
^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.f1_score
:noindex:

fbeta_score (F)
^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.fbeta_score
:noindex:

multiclass_precision_recall_curve (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.multiclass_precision_recall_curve
:noindex:

multiclass_roc (F)
^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.multiclass_roc
:noindex:

precision (F)
^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision
:noindex:

precision_recall (F)
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision_recall
:noindex:

precision_recall_curve (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve
:noindex:

recall (F)
^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.recall
:noindex:

roc (F)
^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.roc
:noindex:

stat_scores (F)
^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.stat_scores
:noindex:

stat_scores_multiple_classes (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes
:noindex:

----------------

Metric pre-processing
---------------------
Metric

to_categorical (F)
^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.to_categorical
:noindex:

to_onehot (F)
^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.to_onehot
:noindex:
39 changes: 12 additions & 27 deletions pytorch_lightning/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,15 @@
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep a short version here...

Metrics
=======

Metrics are generally used to monitor model performance.

The following package aims to provide the most convenient ones as well
as a structure to implement your custom metrics for all the fancy research
you want to do.

For native PyTorch implementations of metrics, it is recommended to use
the :class:`TensorMetric` which handles automated DDP syncing and conversions
to tensors for all inputs and outputs.

If your metrics implementation works on numpy, just use the
:class:`NumpyMetric`, which handles the automated conversion of
inputs to and outputs from numpy as well as automated ddp syncing.

.. warning:: Employing numpy in your metric calculation might slow
down your training substantially, since every metric computation
requires a GPU sync to convert tensors to numpy.


"""

from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning.metrics.sklearn import (
SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
SklearnMetric,
Accuracy,
AveragePrecision,
AUC,
ConfusionMatrix,
F1,
FBeta,
Precision,
Recall,
PrecisionRecallCurve,
ROC,
AUROC)
Loading