diff --git a/example/speech_command/sc/evaluator.py b/example/speech_command/sc/evaluator.py index 33198593ce..d8cce60d77 100644 --- a/example/speech_command/sc/evaluator.py +++ b/example/speech_command/sc/evaluator.py @@ -97,7 +97,7 @@ def _post(self, input: torch.Tensor) -> t.Tuple[t.List[int], t.List[float]]: input = input.squeeze() pred_value = input.argmax(-1).item() probability_matrix = np.exp(input.tolist()).tolist() - return pred_value, probability_matrix[0] + return pred_value, probability_matrix def _load_model(self, device): model = M5(n_input=1, n_output=len(ALL_LABELS))