Skip to content

Commit

Permalink
fix(pipeline): compute fbank on selected device (#1529)
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin authored Nov 7, 2023
1 parent 0b45103 commit 40fa67b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- BREAKING(pipeline): remove `logging_hook` (use `ArtifactHook` instead)
- fix(pipeline): add missing "embedding" hook call in `SpeakerDiarization`
- feat(utils): add `"soft"` option to `Powerset.to_multilabel`
- improve(pipeline): compute `fbank` on GPU when requested

## Version 3.0.1 (2023-09-28)

Expand Down
7 changes: 4 additions & 3 deletions pyannote/audio/pipelines/speaker_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ def compute_fbank(
for waveform in waveforms
]
)

return features - torch.mean(features, dim=1, keepdim=True)

def __call__(
Expand All @@ -578,12 +579,12 @@ def __call__(
batch_size, num_channels, num_samples = waveforms.shape
assert num_channels == 1

features = self.compute_fbank(waveforms)
features = self.compute_fbank(waveforms.to(self.device))
_, num_frames, _ = features.shape

if masks is None:
embeddings = self.session_.run(
output_names=["embs"], input_feed={"feats": features.numpy()}
output_names=["embs"], input_feed={"feats": features.numpy(force=True)}
)[0]

return embeddings
Expand All @@ -606,7 +607,7 @@ def __call__(

embeddings[f] = self.session_.run(
output_names=["embs"],
input_feed={"feats": masked_feature.numpy()[None]},
input_feed={"feats": masked_feature.numpy(force=True)[None]},
)[0][0]

return embeddings
Expand Down

0 comments on commit 40fa67b

Please sign in to comment.