diff --git a/nni/algorithms/nas/pytorch/cdarts/utils.py b/nni/algorithms/nas/pytorch/cdarts/utils.py index 780f6fdc0e..96afa94256 100644 --- a/nni/algorithms/nas/pytorch/cdarts/utils.py +++ b/nni/algorithms/nas/pytorch/cdarts/utils.py @@ -58,7 +58,7 @@ def accuracy(output, target, topk=(1,)): res = [] for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) + correct_k = correct[:k].reshape(-1).float().sum(0) res.append(correct_k.mul_(1.0 / batch_size)) return res