Skip to content

Commit

Permalink
fix: incorrect typing of mic name arg
Browse files Browse the repository at this point in the history
  • Loading branch information
m-barker committed Dec 5, 2023
1 parent e931db0 commit e1d7069
Showing 1 changed file with 19 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,24 @@ from lasr_speech_recognition_whisper import load_model # type: ignore

@dataclass
class speech_model_params:
"""Class for storing speech recognition model parameters."""
"""Class for storing speech recognition model parameters.
Args:
model_name (str, optional): Name of the speech recognition model. Defaults to "medium.en".
Must be a valid Whisper model name.
device (str, optional): Device to run the model on. Defaults to "cuda" if available, otherwise "cpu".
start_timeout (float): Max number of seconds of silence when starting listening before stopping. Defaults to 5.0.
end_timeout (Optional[float]): Max number of seconds of silence after starting listening before stopping. Defaults to None.
sample_rate (int): Sample rate of the microphone. Defaults to 16000Hz.
mic_device (Optional[str]): Microphone device index or name. Defaults to None.
"""

model_name: str = "medium.en"
device: str = "cuda" if torch.cuda.is_available() else "cpu"
start_timeout: float = 5.0
end_timeout: Optional[float] = None
sample_rate: int = 16000
mic_device: Optional[Union[int, str]] = None
mic_device: Optional[str] = None


class TranscribeSpeechAction(object):
Expand Down Expand Up @@ -66,15 +76,15 @@ class TranscribeSpeechAction(object):
Returns: microphone object
"""

if (
isinstance(self._model_params.mic_device, int)
or self._model_params.mic_device.isdigit()
):
if self._model_params.mic_device is None:
# If no microphone device is specified, use the system default microphone
return sr.Microphone(sample_rate=self._model_params.sample_rate)
elif self._model_params.mic_device.isdigit():
return sr.Microphone(
device_index=int(self._model_params.mic_device),
sample_rate=self._model_params.sample_rate,
)
elif isinstance(self._model_params.mic_device, str):
else:
microphones = enumerate(sr.Microphone.list_microphone_names())
for index, name in microphones:
if self._model_params.mic_device in name:
Expand All @@ -85,8 +95,6 @@ class TranscribeSpeechAction(object):
raise ValueError(
f"Could not find microphone with name: {self._model_params.mic_device}"
)
# If no microphone device is specified, use the system default microphone
return sr.Microphone(sample_rate=self._model_params.sample_rate)

def _configure_recogniser(self, ambient_adj: bool = True) -> sr.Recognizer:
"""Configures the speech recogniser object.
Expand Down Expand Up @@ -186,8 +194,9 @@ def parse_args() -> dict:
)
parser.add_argument(
"--mic_device",
type=str,
default=None,
help="Microphone device index or name. Can be a string or an integer.",
help="Microphone device index or name",
)

return vars(parser.parse_args())
Expand Down

0 comments on commit e1d7069

Please sign in to comment.