-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Metrics docs #2184
Merged
Merged
Metrics docs #2184
Changes from 13 commits
Commits
Show all changes
81 commits
Select commit
Hold shift + click to select a range
acbc1be
fixes
williamFalcon 90e3c0d
fixes
williamFalcon a845179
fixes
williamFalcon 70406bd
fixes
williamFalcon d8194f1
fixes
williamFalcon 1c75220
fixes
williamFalcon 5a9ddb2
fixes
williamFalcon 8386615
fixes
williamFalcon da1f5ab
fixes
williamFalcon b7480e4
fixes
williamFalcon 38777d3
fixes
williamFalcon 6d51f55
fixes
williamFalcon a501d8b
fixes
williamFalcon dce2473
fixes
williamFalcon 0f6eafa
fixes
williamFalcon 5aa02eb
fixes
williamFalcon 3f0ee25
fixes
williamFalcon b4d648f
fixes
williamFalcon 1e5d7c0
fixes
williamFalcon 16d8f0d
fixes
williamFalcon bcaa743
Apply suggestions from code review
Borda 69bd7cb
add workers fix
williamFalcon ec277e9
add workers fix
williamFalcon 0519206
add workers fix
williamFalcon d8a51b0
add workers fix
williamFalcon 5be3a8b
add workers fix
williamFalcon faf362b
add workers fix
williamFalcon eef0c54
add workers fix
williamFalcon 43876c0
add workers fix
williamFalcon 992409a
add workers fix
williamFalcon 94b1382
add workers fix
williamFalcon f946650
add workers fix
williamFalcon c0aea3f
add workers fix
williamFalcon bc9007d
add workers fix
williamFalcon 55409fe
Update docs/source/metrics.rst
williamFalcon 30c8148
Update docs/source/metrics.rst
williamFalcon 4d70dcf
Update docs/source/metrics.rst
williamFalcon 46c8c6f
Update docs/source/metrics.rst
williamFalcon e7f9b85
doctests
Borda d3e792c
add workers fix
williamFalcon 596563a
add workers fix
williamFalcon 3b0c296
fixes
williamFalcon 41be1a7
fix docs
c7c88cb
fixes
williamFalcon fef6572
fixes
williamFalcon 414dd81
fixes
williamFalcon 8d2f9c5
fixes
williamFalcon de713fe
fixes
williamFalcon 1f586bf
fixes
williamFalcon 2c6ed8b
fixes
williamFalcon 2422461
fixes
williamFalcon bb9c258
fixes
williamFalcon ae601a6
fixes
williamFalcon 8ecc4a0
fixes
williamFalcon 705cf63
fixes
williamFalcon 022625e
fixes
williamFalcon babb117
fixes
williamFalcon 6204316
fixes
williamFalcon 2d9b90b
fixes
williamFalcon fb26791
fixes
williamFalcon 14cf60a
Apply suggestions from code review
Borda d2d963d
add workers fix
williamFalcon 464f367
Update docs/source/metrics.rst
williamFalcon 95446e4
doctests
Borda a32fcf4
add workers fix
williamFalcon 5fbda6e
fixes
williamFalcon 339814f
fix docs
8d9f53a
fixes
williamFalcon 8ce3687
fix doctests
Borda 2a11c3e
Apply suggestions from code review
Borda d71c33a
fix doctests
Borda 7472031
fix examples
Borda 40bd93d
bug
Borda 78e1198
Update docs/source/metrics.rst
williamFalcon ca4392a
Update docs/source/metrics.rst
williamFalcon be8d587
Update docs/source/metrics.rst
williamFalcon e9a90a0
fixes
williamFalcon 2e832af
fixes
williamFalcon db0dc42
fixes
williamFalcon e7d8e50
fixes
williamFalcon a85818c
fixes
williamFalcon File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,319 @@ | ||
.. automodule:: pytorch_lightning.metrics | ||
:members: | ||
:noindex: | ||
:exclude-members: | ||
.. testsetup:: * | ||
|
||
from torch.nn import Module | ||
from pytorch_lightning.core.lightning import LightningModule | ||
from pytorch_lightning.metrics import TensorMetric, NumpyMetric | ||
|
||
Metrics | ||
======= | ||
This is a general package for PyTorch Metrics. These can also be used with regular non-lightning PyTorch code. | ||
Metrics are used to monitor model performance. | ||
|
||
In this package we provide two major pieces of functionality. | ||
|
||
1. A Metric class you can use to implement metrics with built-in distributed (ddp) support which are device agnostic. | ||
2. A collection of popular metrics already implemented for you. | ||
|
||
Example: | ||
|
||
.. testcode:: | ||
|
||
from pytorch_lightning.metrics.functional import accuracy | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are these imports wrong? |
||
|
||
pred = torch.tensor([0, 1, 2, 3]) | ||
target = torch.tensor([0, 1, 2, 2]) | ||
|
||
# calculates accuracy across all GPUs and all Nodes used in training | ||
accuracy(pred, target) | ||
|
||
.. testoutput:: | ||
|
||
tensor(0.7500) | ||
|
||
-------------- | ||
|
||
Implement a metric | ||
------------------ | ||
You can implement metrics as either a PyTorch metric or a Numpy metric. Numpy metrics | ||
will slow down training, use PyTorch metrics when possible. | ||
|
||
Use :class:`TensorMetric` to implement native PyTorch metrics. This class | ||
handles automated DDP syncing and converts all inputs and outputs to tensors. | ||
|
||
Use :class:`NumpyMetric` to implement numpy metrics. This class | ||
handles automated DDP syncing and converts all inputs and outputs to tensors. | ||
|
||
.. warning:: Numpy metrics might slow down your training substantially, | ||
williamFalcon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
since every metric computation requires a GPU sync to convert tensors to numpy. | ||
|
||
TensorMetric | ||
^^^^^^^^^^^^ | ||
Here's an example showing how to implement a TensorMetric | ||
|
||
.. testcode:: | ||
|
||
class RMSE(TensorMetric): | ||
def forward(self, x, y): | ||
return torch.sqrt(torch.mean(torch.pow(x-y, 2.0))) | ||
|
||
.. autoclass:: pytorch_lightning.metrics.metric.TensorMetric | ||
:noindex: | ||
|
||
NumpyMetric | ||
^^^^^^^^^^^ | ||
Here's an example showing how to implement a NumpyMetric | ||
|
||
.. testcode:: | ||
|
||
class RMSE(NumpyMetric): | ||
def forward(self, x, y): | ||
return np.sqrt(np.mean(np.power(x-y, 2.0))) | ||
|
||
|
||
.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric | ||
:noindex: | ||
|
||
-------------- | ||
|
||
Class Metrics | ||
------------- | ||
The following are metrics which can be instantiated as part of a module definition (even with just | ||
plain PyTorch). | ||
|
||
.. testcode:: | ||
|
||
from pytorch_lightning.metrics import Accuracy | ||
|
||
# Plain PyTorch | ||
class MyModule(Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.metric = Accuracy() | ||
|
||
def forward(self, x, y): | ||
y_hat = # ... | ||
williamFalcon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
acc = self.metric(y_hat, y) | ||
|
||
# PyTorch Lightning | ||
class MyModule(LightningModule): | ||
def __init__(self): | ||
super().__init__() | ||
self.metric = Accuracy() | ||
|
||
def training_step(self, batch, batch_idx): | ||
x, y = batch | ||
y_hat = # ... | ||
williamFalcon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
acc = self.metric(y_hat, y) | ||
|
||
These metrics even work when using distributed training: | ||
|
||
.. code-block:: python | ||
|
||
model = MyModule() | ||
trainer = Trainer(gpus=8, num_nodes=2) | ||
|
||
# any metric automatically reduces across GPUs (even the ones you implement using Lightning) | ||
trainer.fit(model) | ||
|
||
Accuracy | ||
^^^^^^^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.Accuracy | ||
:noindex: | ||
|
||
AveragePrecision | ||
^^^^^^^^^^^^^^^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision | ||
:noindex: | ||
|
||
AUROC | ||
^^^^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.AUROC | ||
:noindex: | ||
|
||
ConfusionMatrix | ||
^^^^^^^^^^^^^^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix | ||
:noindex: | ||
|
||
DiceCoefficient | ||
^^^^^^^^^^^^^^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.DiceCoefficient | ||
:noindex: | ||
|
||
F1 | ||
^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.F1 | ||
:noindex: | ||
|
||
FBeta | ||
^^^^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.FBeta | ||
:noindex: | ||
|
||
PrecisionRecall | ||
^^^^^^^^^^^^^^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecall | ||
:noindex: | ||
|
||
Precision | ||
^^^^^^^^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.Precision | ||
:noindex: | ||
|
||
Recall | ||
^^^^^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.Recall | ||
:noindex: | ||
|
||
ROC | ||
^^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.ROC | ||
:noindex: | ||
|
||
MulticlassROC | ||
^^^^^^^^^^^^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.MulticlassROC | ||
:noindex: | ||
|
||
MulticlassPrecisionRecall | ||
^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
.. autoclass:: pytorch_lightning.metrics.classification.MulticlassPrecisionRecall | ||
:noindex: | ||
|
||
-------------- | ||
|
||
Functional Metrics | ||
------------------ | ||
|
||
accuracy (F) | ||
^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.accuracy | ||
:noindex: | ||
|
||
auc (F) | ||
^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.auc | ||
:noindex: | ||
|
||
auroc (F) | ||
^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.auroc | ||
:noindex: | ||
|
||
average_precision (F) | ||
^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.average_precision | ||
:noindex: | ||
|
||
confusion_matrix (F) | ||
^^^^^^^^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix | ||
:noindex: | ||
|
||
dice_score (F) | ||
^^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.dice_score | ||
:noindex: | ||
|
||
f1_score (F) | ||
^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.f1_score | ||
:noindex: | ||
|
||
fbeta_score (F) | ||
^^^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.fbeta_score | ||
:noindex: | ||
|
||
multiclass_precision_recall_curve (F) | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.multiclass_precision_recall_curve | ||
:noindex: | ||
|
||
multiclass_roc (F) | ||
^^^^^^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.multiclass_roc | ||
:noindex: | ||
|
||
precision (F) | ||
^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.precision | ||
:noindex: | ||
|
||
precision_recall (F) | ||
^^^^^^^^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.precision_recall | ||
:noindex: | ||
|
||
precision_recall_curve (F) | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve | ||
:noindex: | ||
|
||
recall (F) | ||
^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.recall | ||
:noindex: | ||
|
||
roc (F) | ||
^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.roc | ||
:noindex: | ||
|
||
stat_scores (F) | ||
^^^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.stat_scores | ||
:noindex: | ||
|
||
stat_scores_multiple_classes (F) | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes | ||
:noindex: | ||
|
||
---------------- | ||
|
||
Metric pre-processing | ||
--------------------- | ||
Metric | ||
|
||
to_categorical (F) | ||
^^^^^^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.to_categorical | ||
:noindex: | ||
|
||
to_onehot (F) | ||
^^^^^^^^^^^^^ | ||
|
||
.. autofunction:: pytorch_lightning.metrics.functional.to_onehot | ||
:noindex: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,15 @@ | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would keep a short version here... |
||
Metrics | ||
======= | ||
|
||
Metrics are generally used to monitor model performance. | ||
|
||
The following package aims to provide the most convenient ones as well | ||
as a structure to implement your custom metrics for all the fancy research | ||
you want to do. | ||
|
||
For native PyTorch implementations of metrics, it is recommended to use | ||
the :class:`TensorMetric` which handles automated DDP syncing and conversions | ||
to tensors for all inputs and outputs. | ||
|
||
If your metrics implementation works on numpy, just use the | ||
:class:`NumpyMetric`, which handles the automated conversion of | ||
inputs to and outputs from numpy as well as automated ddp syncing. | ||
|
||
.. warning:: Employing numpy in your metric calculation might slow | ||
down your training substantially, since every metric computation | ||
requires a GPU sync to convert tensors to numpy. | ||
|
||
|
||
""" | ||
|
||
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric | ||
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric | ||
from pytorch_lightning.metrics.sklearn import ( | ||
SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta, | ||
Precision, Recall, PrecisionRecallCurve, ROC, AUROC) | ||
SklearnMetric, | ||
Accuracy, | ||
AveragePrecision, | ||
AUC, | ||
ConfusionMatrix, | ||
F1, | ||
FBeta, | ||
Precision, | ||
Recall, | ||
PrecisionRecallCurve, | ||
ROC, | ||
AUROC) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we mention that the package also comes with a interface to sklearn metrics (with a warning that this is slow due to casting back-and-forth of the tensors)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should. We should also include that sklearn has to be installed separately
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should make it clear how to import the different backends. Maybe something like:
Use native backend
import pytorch_lightning.metrics.native as plm
Use sklearn backend
import pytorch_lightning.metrics.sklearn as plm
Use default (native if available else sklearn)
import pytorch_lightning.metrics as plm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@SkafteNicki can you add the sklearn details in here?