Skip to content

Commit

Permalink
drop deprecated reorder from AUC (#5004)
Browse files Browse the repository at this point in the history
* drop deprecated reorder from AUC

* chlog

* fix

* fix

* simple

* fix

* fix

* fix

Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
  • Loading branch information
Borda and s-rog authored Dec 9, 2020
1 parent 20b806a commit 90d1d9f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 35 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Removed `reorder` parameter of the `auc` metric ([#5004](https://github.com/PyTorchLightning/pytorch-lightning/pull/5004))



### Fixed
Expand Down
43 changes: 13 additions & 30 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,17 +482,13 @@ def __multiclass_roc(
def auc(
x: torch.Tensor,
y: torch.Tensor,
reorder: bool = True
) -> torch.Tensor:
"""
Computes Area Under the Curve (AUC) using the trapezoidal rule
Args:
x: x-coordinates
y: y-coordinates
reorder: reorder coordinates, so they are increasing. The unstable algorithm of torch.argsort is
used internally to sort `x` which may in some cases cause inaccuracies in the result.
WARNING: Deprecated and will be removed in v1.1.
Return:
Tensor containing AUC score (float)
Expand All @@ -504,51 +500,38 @@ def auc(
>>> auc(x, y)
tensor(4.)
"""
direction = 1.

if reorder:
rank_zero_warn("The `reorder` parameter to `auc` has been deprecated and will be removed in v1.1"
" Note that when `reorder` is True, the unstable algorithm of torch.argsort is"
" used internally to sort 'x' which may in some cases cause inaccuracies"
" in the result.",
DeprecationWarning)
# can't use lexsort here since it is not implemented for torch
order = torch.argsort(x)
x, y = x[order], y[order]
dx = x[1:] - x[:-1]
if (dx < 0).any():
if (dx <= 0).all():
direction = -1.
else:
raise ValueError(f"The 'x' array is neither increasing or decreasing: {x}. Reorder is not supported.")
else:
dx = x[1:] - x[:-1]
if (dx < 0).any():
if (dx, 0).all():
direction = -1.
else:
# TODO: Update message on removing reorder
raise ValueError("Reorder is not turned on, and the 'x' array is"
f" neither increasing or decreasing: {x}")

direction = 1.
return direction * torch.trapz(y, x)


def auc_decorator(reorder: bool = True) -> Callable:
def auc_decorator() -> Callable:
def wrapper(func_to_decorate: Callable) -> Callable:
@wraps(func_to_decorate)
def new_func(*args, **kwargs) -> torch.Tensor:
x, y = func_to_decorate(*args, **kwargs)[:2]

return auc(x, y, reorder=reorder)
return auc(x, y)

return new_func

return wrapper


def multiclass_auc_decorator(reorder: bool = True) -> Callable:
def multiclass_auc_decorator() -> Callable:
def wrapper(func_to_decorate: Callable) -> Callable:
@wraps(func_to_decorate)
def new_func(*args, **kwargs) -> torch.Tensor:
results = []
for class_result in func_to_decorate(*args, **kwargs):
x, y = class_result[:2]
results.append(auc(x, y, reorder=reorder))
results.append(auc(x, y))

return torch.stack(results)

Expand Down Expand Up @@ -587,7 +570,7 @@ def auroc(
' target tensor contains value different from 0 and 1.'
' Use `multiclass_auroc` for multi class classification.')

@auc_decorator(reorder=True)
@auc_decorator()
def _auroc(pred, target, sample_weight, pos_label):
return __roc(pred, target, sample_weight, pos_label)

Expand Down Expand Up @@ -640,7 +623,7 @@ def multiclass_auroc(
f"Number of classes deduced from 'pred' ({pred.size(1)}) does not equal"
f" the number of classes passed in 'num_classes' ({num_classes}).")

@multiclass_auc_decorator(reorder=False)
@multiclass_auc_decorator()
def _multiclass_auroc(pred, target, sample_weight, num_classes):
return __multiclass_roc(pred, target, sample_weight, num_classes)

Expand Down
5 changes: 0 additions & 5 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,3 @@ def test_dataloader(self):

def test_end(self, outputs):
return {'test_loss': torch.tensor(0.7)}


def test_reorder_remove_in_v1_1():
with pytest.deprecated_call(match='The `reorder` parameter to `auc` has been deprecated'):
_ = auc(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 2]), reorder=True)

0 comments on commit 90d1d9f

Please sign in to comment.