diff --git a/psifx/audio/diarization/pyannote/command.py b/psifx/audio/diarization/pyannote/command.py index 3707446..dd4cf5e 100644 --- a/psifx/audio/diarization/pyannote/command.py +++ b/psifx/audio/diarization/pyannote/command.py @@ -53,8 +53,8 @@ def setup(parser: argparse.ArgumentParser): parser.add_argument( "--model_name", type=str, - default="2.1.1", - help="version number of the pyannote/speaker-diarization model, c.f." + default="pyannote/speaker-diarization@2.1.1", + help="name of the diarization model used, c.f." " https://huggingface.co/pyannote/speaker-diarization/tree/main/reproducible_research", ) parser.add_argument( diff --git a/psifx/audio/diarization/pyannote/tool.py b/psifx/audio/diarization/pyannote/tool.py index 3dabde7..9d07a7e 100644 --- a/psifx/audio/diarization/pyannote/tool.py +++ b/psifx/audio/diarization/pyannote/tool.py @@ -18,7 +18,7 @@ class PyannoteDiarizationTool(DiarizationTool): def __init__( self, - model_name: str = "2.1.1", + model_name: str = "pyannote/speaker-diarization@2.1.1", api_token: Optional[str] = None, device: str = "cpu", overwrite: bool = False, @@ -34,7 +34,7 @@ def __init__( self.api_token = api_token self.model: Pipeline = Pipeline.from_pretrained( - checkpoint_path=f"pyannote/speaker-diarization@{model_name}", + checkpoint_path=model_name, use_auth_token=api_token, ).to(device=torch.device(device))