Skip to content

Commit

Permalink
function descriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicki Skafte authored and Borda committed Jun 10, 2020
1 parent 4bbf9a7 commit b605b2d
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@

def to_onehot(tensor: torch.Tensor,
n_classes: Optional[int] = None) -> torch.Tensor:
""" Converts a dense label tensor to one-hot format
Args:
tensor: dense label tensor, with shape [N, d1, d2, ...]
n_classes: number of classes C
Output:
A sparse label tensor with shape [N, C, d1, d2, ...]
"""
if n_classes is None:
n_classes = int(tensor.max().detach().item() + 1)
dtype, device, shape = tensor.dtype, tensor.device, tensor.shape
Expand All @@ -19,11 +29,15 @@ def to_onehot(tensor: torch.Tensor,


def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
""" Converts a tensor of probabilities to a dense label tensor """
return torch.argmax(tensor, dim=argmax_dim)


def get_num_classes(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int]) -> int:
""" Returns the number of classes for a given prediction and
target tensor
"""
if num_classes is None:
if pred.ndim > target.ndim:
num_classes = pred.size(1)
Expand All @@ -36,6 +50,20 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
class_index: int, argmax_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
""" Calculates the number of true postive, false postive, true negative
and false negative for a specfic class
Args:
pred: prediction tensor
target: target tensor
class_index: class to calculate over
argmax_dim: if pred is a tensor of probabilities, this indicates the
axis the argmax transformation will be applied over
"""
if pred.ndim == target.ndim + 1:
pred = to_categorical(pred, argmax_dim=argmax_dim)

Expand All @@ -52,6 +80,21 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
argmax_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
""" Calls the stat_scores function iteratively for all classes, thus
calculating the number of true postive, false postive, true negative
and false negative for each class
Args:
pred: prediction tensor
target: target tensor
class_index: class to calculate over
argmax_dim: if pred is a tensor of probabilities, this indicates the
axis the argmax transformation will be applied over
"""
num_classes = get_num_classes(pred=pred, target=target,
num_classes=num_classes)

Expand Down

0 comments on commit b605b2d

Please sign in to comment.