Skip to content

Commit

Permalink
fix(task): fix corner case with small (<9) number of validation samples
Browse files Browse the repository at this point in the history
  • Loading branch information
antoinelaurent authored Nov 21, 2024
1 parent 7d84f61 commit 385eba6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- fix(separation): fix clipping issue in speech separation pipeline ([@joonaskalda](https://github.com/joonaskalda/))
- fix(separation): fix alignment between separated sources and diarization ([@Lebourdais](https://github.com/Lebourdais/) and [@clement-pages](https://github.com/clement-pages/))
- fix(doc): fix link to pytorch ([@emmanuel-ferdman](https://github.com/emmanuel-ferdman/))
- fix(task): fix corner case with small (<9) number of validation samples ([@antoinelaurent](https://github.com/antoinelaurent/))

## Version 3.3.2 (2024-09-11)

Expand Down
3 changes: 2 additions & 1 deletion pyannote/audio/tasks/segmentation/speaker_diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,8 @@ def validation_step(self, batch, batch_idx: int):
y_pred = multilabel.cpu().numpy()

# prepare 3 x 3 grid (or smaller if batch size is smaller)
num_samples = min(self.batch_size, 9)
num_samples = min(self.batch_size, 9, y.shape[0])

nrows = math.ceil(math.sqrt(num_samples))
ncols = math.ceil(num_samples / nrows)
fig, axes = plt.subplots(
Expand Down

0 comments on commit 385eba6

Please sign in to comment.