Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

drop deprecated reorder from AUC #5004

Merged
merged 13 commits into from
Dec 9, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Removed `reorder` parameter of the `auc` metric ([#5004](https://github.com/PyTorchLightning/pytorch-lightning/pull/5004))



### Fixed
Expand Down
43 changes: 13 additions & 30 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,17 +482,13 @@ def __multiclass_roc(
def auc(
x: torch.Tensor,
y: torch.Tensor,
reorder: bool = True
) -> torch.Tensor:
"""
Computes Area Under the Curve (AUC) using the trapezoidal rule

Args:
x: x-coordinates
y: y-coordinates
reorder: reorder coordinates, so they are increasing. The unstable algorithm of torch.argsort is
used internally to sort `x` which may in some cases cause inaccuracies in the result.
WARNING: Deprecated and will be removed in v1.1.

Return:
Tensor containing AUC score (float)
Expand All @@ -504,51 +500,38 @@ def auc(
>>> auc(x, y)
tensor(4.)
"""
direction = 1.

if reorder:
rank_zero_warn("The `reorder` parameter to `auc` has been deprecated and will be removed in v1.1"
" Note that when `reorder` is True, the unstable algorithm of torch.argsort is"
" used internally to sort 'x' which may in some cases cause inaccuracies"
" in the result.",
DeprecationWarning)
# can't use lexsort here since it is not implemented for torch
order = torch.argsort(x)
x, y = x[order], y[order]
dx = x[1:] - x[:-1]
if (dx < 0).any():
if (dx <= 0).all():
direction = -1.
else:
raise ValueError(f"The 'x' array is neither increasing or decreasing: {x}. Reorder is not supported.")
else:
dx = x[1:] - x[:-1]
if (dx < 0).any():
if (dx, 0).all():
direction = -1.
else:
# TODO: Update message on removing reorder
raise ValueError("Reorder is not turned on, and the 'x' array is"
f" neither increasing or decreasing: {x}")

direction = 1.
return direction * torch.trapz(y, x)


def auc_decorator(reorder: bool = True) -> Callable:
def auc_decorator() -> Callable:
def wrapper(func_to_decorate: Callable) -> Callable:
@wraps(func_to_decorate)
def new_func(*args, **kwargs) -> torch.Tensor:
x, y = func_to_decorate(*args, **kwargs)[:2]

return auc(x, y, reorder=reorder)
return auc(x, y)

return new_func

return wrapper


def multiclass_auc_decorator(reorder: bool = True) -> Callable:
def multiclass_auc_decorator() -> Callable:
def wrapper(func_to_decorate: Callable) -> Callable:
@wraps(func_to_decorate)
def new_func(*args, **kwargs) -> torch.Tensor:
results = []
for class_result in func_to_decorate(*args, **kwargs):
x, y = class_result[:2]
results.append(auc(x, y, reorder=reorder))
results.append(auc(x, y))

return torch.stack(results)

Expand Down Expand Up @@ -587,7 +570,7 @@ def auroc(
' target tensor contains value different from 0 and 1.'
' Use `multiclass_auroc` for multi class classification.')

@auc_decorator(reorder=True)
@auc_decorator()
def _auroc(pred, target, sample_weight, pos_label):
return __roc(pred, target, sample_weight, pos_label)

Expand Down Expand Up @@ -640,7 +623,7 @@ def multiclass_auroc(
f"Number of classes deduced from 'pred' ({pred.size(1)}) does not equal"
f" the number of classes passed in 'num_classes' ({num_classes}).")

@multiclass_auc_decorator(reorder=False)
@multiclass_auc_decorator()
def _multiclass_auroc(pred, target, sample_weight, num_classes):
return __multiclass_roc(pred, target, sample_weight, num_classes)

Expand Down
5 changes: 0 additions & 5 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,3 @@ def test_dataloader(self):

def test_end(self, outputs):
return {'test_loss': torch.tensor(0.7)}


def test_reorder_remove_in_v1_1():
with pytest.deprecated_call(match='The `reorder` parameter to `auc` has been deprecated'):
_ = auc(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 2]), reorder=True)