
Commit

Merge pull request ASUS-AICS#343 from ntumlgroup/update_pkgs
Update packages: torch, torchmetrics, lightning
Eleven1Liu authored Feb 10, 2024
2 parents 10819d0 + 5947a68 commit 1b7779b
Showing 11 changed files with 205 additions and 191 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -10,8 +10,8 @@ This is an on-going development so many improvements are still being made. Comme

## Environments
- Python: 3.8+
-- CUDA: 11.6 (if training neural networks by GPU)
-- Pytorch 1.13.1
+- CUDA: 11.8, 12.1 (if training neural networks by GPU)
+- Pytorch: 2.0.1+

If you have a different version of CUDA, follow the installation instructions for PyTorch LTS at their [website](https://pytorch.org/).

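Not part of the diff: a minimal sketch for checking that a local setup matches the bumped requirements, assuming PyTorch is already installed.

```python
import torch

# The updated README expects PyTorch 2.0.1+ built against CUDA 11.8 or 12.1.
print(torch.__version__)          # e.g. "2.0.1+cu118"
print(torch.version.cuda)         # e.g. "11.8" (None on CPU-only builds)
print(torch.cuda.is_available())  # True if a usable GPU is visible
```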
4 changes: 2 additions & 2 deletions docs/cli/ov_data_format.rst
@@ -25,8 +25,8 @@ Install LibMultiLabel from Source
* Environment

* Python: 3.8+
-* CUDA: 11.6 (if training neural networks by GPU)
-* Pytorch 1.13.1
+* CUDA: 11.8, 12.1 (if training neural networks by GPU)
+* Pytorch 2.0.1+

It is optional but highly recommended to
create a virtual environment.
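Not part of the diff: the docs recommend a virtual environment; a minimal sketch using only the standard library (the directory name `venv` is an arbitrary choice).

```python
import venv

# Create an isolated environment with pip included; activate it afterwards
# via venv/bin/activate (POSIX) or venv\Scripts\activate (Windows).
venv.create("venv", with_pip=True)
```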
69 changes: 37 additions & 32 deletions libmultilabel/linear/metrics.py
@@ -17,7 +17,7 @@ def _argsort_top_k(preds: np.ndarray, top_k: int) -> np.ndarray:
     return np.take_along_axis(top_k_idx, argsort_top_k, axis=-1)
 
 
-def _DCG_argsort(argsort_preds: np.ndarray, target: np.ndarray, top_k: int) -> np.ndarray:
+def _dcg_argsort(argsort_preds: np.ndarray, target: np.ndarray, top_k: int) -> np.ndarray:
     """Computes DCG@k with a sorted preds array and a target array."""
     top_k_idx = argsort_preds[:, -top_k:][:, ::-1]
     gains = np.take_along_axis(target, top_k_idx, axis=-1)
@@ -27,7 +27,7 @@ def _DCG_argsort(argsort_preds: np.ndarray, target: np.ndarray, top_k: int) -> np.ndarray:
     return dcgs
 
 
-def _IDCG(target: np.ndarray, top_k: int) -> np.ndarray:
+def _idcg(target: np.ndarray, top_k: int) -> np.ndarray:
     """Computes IDCG@k for a 0/1 target array. A 0/1 target is a special case that
     doesn't require sorting. If IDCG is computed with DCG,
     then target will need to be sorted, which incurs a large overhead.
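The docstring's claim, that a 0/1 target lets IDCG@k skip sorting, can be illustrated with a standalone sketch (the function name here is hypothetical; the real implementation is the `_idcg` shown in these hunks):

```python
import numpy as np

def idcg_binary(target: np.ndarray, top_k: int) -> np.ndarray:
    # Discount for rank i (1-based) is 1 / log2(i + 1); precompute cumsums.
    cum_discount = np.cumsum(1.0 / np.log2(np.arange(2, top_k + 2)))
    # The ideal ranking puts all relevant labels first, so IDCG@k is just
    # the cumulative discount at min(#relevant, top_k); no sort required.
    # (This sketch assumes each row has at least one relevant label.)
    idx = np.minimum(target.sum(axis=-1), top_k).astype(int) - 1
    return cum_discount[idx]

# Two relevant labels, k=2: 1/log2(2) + 1/log2(3) ≈ 1.63
print(idcg_binary(np.array([[1, 0, 1, 0]]), top_k=2))
```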
@@ -43,10 +43,13 @@ def _IDCG(target: np.ndarray, top_k: int) -> np.ndarray:
     return cum_discount[idx]
 
 
-class NDCG:
-    def __init__(self, top_k: int):
-        """Compute the normalized DCG@k (nDCG@k).
+class NDCGAtK:
+    """Compute the normalized DCG@k (nDCG@k). Please refer to the `implementation document`
+    (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
+    """
+
+    def __init__(self, top_k: int):
+        """
         Args:
             top_k: Consider only the top k elements for each query.
         """
@@ -61,8 +64,8 @@ def update(self, preds: np.ndarray, target: np.ndarray):
         return self.update_argsort(_argsort_top_k(preds, self.top_k), target)
 
     def update_argsort(self, argsort_preds: np.ndarray, target: np.ndarray):
-        dcg = _DCG_argsort(argsort_preds, target, self.top_k)
-        idcg = _IDCG(target, self.top_k)
+        dcg = _dcg_argsort(argsort_preds, target, self.top_k)
+        idcg = _idcg(target, self.top_k)
         ndcg_score = dcg / idcg
         # by convention, ndcg is 0 for zero label instances
         self.score += np.nan_to_num(ndcg_score, nan=0.0).sum()
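A hedged usage sketch of the renamed class; `update` and the running `score`/`num_sample` fields appear in these hunks, while the final averaging step is assumed to live in code not shown here:

```python
import numpy as np
from libmultilabel.linear.metrics import NDCGAtK

ndcg5 = NDCGAtK(top_k=5)
preds = np.array([[0.9, 0.2, 0.8, 0.1, 0.5, 0.3]])  # decision values
target = np.array([[1, 0, 1, 0, 0, 0]])             # 0/1 relevance
ndcg5.update(preds, target)
print(ndcg5.score / ndcg5.num_sample)  # per-sample nDCG@5, averaged
```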
@@ -76,10 +79,13 @@ def reset(self):
         self.num_sample = 0
 
 
-class RPrecision:
-    def __init__(self, top_k: int):
-        """Compute the R-Precision@K.
+class RPrecisionAtK:
+    """Compute the R-Precision@K. Please refer to the `implementation document`
+    (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
+    """
+
+    def __init__(self, top_k: int):
+        """
         Args:
             top_k: Consider only the top k elements for each query.
         """
@@ -108,18 +114,16 @@ def reset(self):
         self.num_sample = 0
 
 
-class Precision:
-    def __init__(self, num_classes: int, average: str, top_k: int):
-        """Compute the Precision@K.
+class PrecisionAtK:
+    """Compute the Precision@K. Please refer to the `implementation document`
+    (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
+    """
+
+    def __init__(self, top_k: int):
+        """
         Args:
-            num_classes: The number of classes.
-            average: Define the reduction that is applied over labels. Currently only "samples" is supported.
             top_k: Consider only the top k elements for each query.
         """
-        if average != "samples":
-            raise ValueError("unsupported average")
-
         _check_top_k(top_k)
 
         self.top_k = top_k
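The caller-visible effect of this hunk: `num_classes` and the `average` argument (which only ever accepted "samples") are dropped from the constructor, so existing call sites need updating. `RecallAtK` below changes in exactly the same way. Roughly:

```python
from libmultilabel.linear.metrics import PrecisionAtK

# before this commit:
# p5 = Precision(num_classes, average="samples", top_k=5)
# after this commit:
p5 = PrecisionAtK(top_k=5)
```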
@@ -144,18 +148,16 @@ def reset(self):
         self.num_sample = 0
 
 
-class Recall:
-    def __init__(self, num_classes: int, average: str, top_k: int):
-        """Compute the Recall@K.
+class RecallAtK:
+    """Compute the Recall@K. Please refer to the `implementation document`
+    (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
+    """
+
+    def __init__(self, top_k: int):
+        """
         Args:
-            num_classes: The number of classes.
-            average: Define the reduction that is applied over labels. Currently only "samples" is supported.
             top_k: Consider only the top k elements for each query.
         """
-        if average != "samples":
-            raise ValueError("unsupported average")
-
         _check_top_k(top_k)
 
         self.top_k = top_k
@@ -182,9 +184,12 @@ def reset(self):
 
 
 class F1:
-    def __init__(self, num_classes: int, average: str, multiclass=False):
-        """Compute the F1 score.
+    """Compute the F1 score. Please refer to the `implementation document`
+    (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
+    """
+
+    def __init__(self, num_classes: int, average: str, multiclass=False):
+        """
         Args:
             num_classes: The number of labels.
             average: Define the reduction that is applied over labels. Should be one of "macro", "micro",
@@ -296,13 +301,13 @@ def get_metrics(monitor_metrics: list[str], num_classes: int, multiclass: bool =
     metrics = {}
     for metric in monitor_metrics:
         if re.match("P@\d+", metric):
-            metrics[metric] = Precision(num_classes, average="samples", top_k=int(metric[2:]))
+            metrics[metric] = PrecisionAtK(top_k=int(metric[2:]))
         elif re.match("R@\d+", metric):
-            metrics[metric] = Recall(num_classes, average="samples", top_k=int(metric[2:]))
+            metrics[metric] = RecallAtK(top_k=int(metric[2:]))
         elif re.match("RP@\d+", metric):
-            metrics[metric] = RPrecision(top_k=int(metric[3:]))
+            metrics[metric] = RPrecisionAtK(top_k=int(metric[3:]))
         elif re.match("NDCG@\d+", metric):
-            metrics[metric] = NDCG(top_k=int(metric[5:]))
+            metrics[metric] = NDCGAtK(top_k=int(metric[5:]))
         elif metric in {"Another-Macro-F1", "Macro-F1", "Micro-F1"}:
             metrics[metric] = F1(num_classes, average=metric[:-3].lower(), multiclass=multiclass)
         else:
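A sketch of how the dispatch above resolves monitor-metric names after the rename; the class names and arguments come straight from this hunk, while the surrounding call site is hypothetical:

```python
from libmultilabel.linear.metrics import get_metrics

metrics = get_metrics(
    monitor_metrics=["P@5", "R@10", "RP@5", "NDCG@10", "Macro-F1"],
    num_classes=100,
)
# "P@5"      -> PrecisionAtK(top_k=5)
# "R@10"     -> RecallAtK(top_k=10)
# "RP@5"     -> RPrecisionAtK(top_k=5)
# "NDCG@10"  -> NDCGAtK(top_k=10)
# "Macro-F1" -> F1(100, average="macro", multiclass=False)
```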
