diff --git a/CHANGELOG.md b/CHANGELOG.md index c63e65bea..f52704595 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ - BREAKING(pipeline): remove `logging_hook` (use `ArtifactHook` instead) - fix(pipeline): add missing "embedding" hook call in `SpeakerDiarization` - feat(utils): add `"soft"` option to `Powerset.to_multilabel` +- improve(pipeline): compute `fbank` on GPU when requested ## Version 3.0.1 (2023-09-28) diff --git a/pyannote/audio/pipelines/speaker_verification.py b/pyannote/audio/pipelines/speaker_verification.py index 6b39679dd..8a36fc70c 100644 --- a/pyannote/audio/pipelines/speaker_verification.py +++ b/pyannote/audio/pipelines/speaker_verification.py @@ -556,6 +556,7 @@ def compute_fbank( for waveform in waveforms ] ) + return features - torch.mean(features, dim=1, keepdim=True) def __call__( @@ -578,12 +579,12 @@ def __call__( batch_size, num_channels, num_samples = waveforms.shape assert num_channels == 1 - features = self.compute_fbank(waveforms) + features = self.compute_fbank(waveforms.to(self.device)) _, num_frames, _ = features.shape if masks is None: embeddings = self.session_.run( - output_names=["embs"], input_feed={"feats": features.numpy()} + output_names=["embs"], input_feed={"feats": features.numpy(force=True)} )[0] return embeddings @@ -606,7 +607,7 @@ def __call__( embeddings[f] = self.session_.run( output_names=["embs"], - input_feed={"feats": masked_feature.numpy()[None]}, + input_feed={"feats": masked_feature.numpy(force=True)[None]}, )[0][0] return embeddings