
Commit

v0.2 (#4)
Removed the torch C++ dependency, as it wasn't used, and replaced it with a pybind11 NumPy interface.
The package size is now 10 times smaller and the calculation is 50% faster.
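As a rough illustration of what the new interface means for callers — a minimal sketch only; the top-level import and the dummy shapes are assumptions rather than something taken from the project's docs (the CI workflow further down does test importing forced_align from the package root):

import numpy as np
from ctc_forced_aligner import forced_align

# Dummy emissions: (batch, time, vocab) log-probabilities, float32 as the new wrapper requires.
B, T, C = 1, 200, 32
log_probs = np.full((B, T, C), np.log(1.0 / C), dtype=np.float32)
# (batch, length) target token ids; must not contain the blank id (0 by default).
targets = np.array([[5, 3, 7, 2]], dtype=np.int64)

paths, scores = forced_align(log_probs, targets, blank=0)  # frame-level labels and their log-prob scores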
MahmoudAshraf97 authored May 29, 2024
1 parent 0aa47ff commit e2384a8
Showing 6 changed files with 131 additions and 217 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/test_build.yml
@@ -32,9 +32,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest ninja wheel setuptools
python -m pip install flake8 pytest
# This is to avoid installing cuda dependencies
python -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
python -m pip install -r requirements.txt
- name: Lint with flake8
run: |
@@ -44,7 +44,7 @@
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Install package
run: |
python -m pip install -e . --no-build-isolation
python -m pip install -e .
- name: Test importing compiled extension
run: |
python -c "from ctc_forced_aligner import forced_align"
2 changes: 1 addition & 1 deletion ctc_forced_aligner/align.py
@@ -85,7 +85,7 @@ def cli():
parser.add_argument(
"--compute_dtype",
type=str,
default="float32",
default="float16" if torch.cuda.is_available() else "float32",
choices=["bfloat16", "float16", "float32"],
help="Compute dtype for alignment model inference. Helps with speed and memory usage.",
)
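For reference, a common way such a string choice gets resolved to a torch dtype before loading the alignment model — an illustrative idiom only, not the repository's own resolution code:

import torch

compute_dtype = "float16" if torch.cuda.is_available() else "float32"  # mirrors the new default above
dtype = getattr(torch, compute_dtype)  # torch.float16 on CUDA machines, torch.float32 otherwise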
95 changes: 33 additions & 62 deletions ctc_forced_aligner/alignment_utils.py
@@ -8,24 +8,10 @@
__version__ as transformers_version,
)
from transformers.utils import is_flash_attn_2_available

try:
from .ctc_forced_aligner import forced_align as forced_align_cpp
except Exception as e:
if all(
substring in e.__repr__()
for substring in ["ctc_forced_aligner", "undefined symbol"]
):
raise ImportError(
"ctc-forced-aligner package was build using a different version of "
"torch than the one currently installed, reinstall the package again using: \n"
"pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git --force-reinstall --no-deps"
)
else:
raise e

from .ctc_forced_aligner import forced_align as forced_align_cpp
from typing import Optional, Tuple
from packaging import version
import numpy as np

SAMPLING_FREQ = 16000

@@ -169,35 +155,30 @@ def generate_emissions(


def forced_align(
log_probs: torch.Tensor,
targets: torch.Tensor,
input_lengths: Optional[torch.Tensor] = None,
target_lengths: Optional[torch.Tensor] = None,
log_probs: np.ndarray,
targets: np.ndarray,
input_lengths: Optional[np.ndarray] = None,
target_lengths: Optional[np.ndarray] = None,
blank: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[np.ndarray, np.ndarray]:
r"""Align a CTC label sequence to an emission.
.. devices:: CPU CUDA
.. properties:: TorchScript
Args:
log_probs (Tensor): log probability of CTC emission output.
Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
log_probs (NDArray): log probability of CTC emission output.
NDArray of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
`C` is the number of characters in alphabet including blank.
targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
targets (NDArray): Target sequence. NDArray of shape `(B, L)`,
where `L` is the target length.
input_lengths (Tensor or None, optional):
Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
target_lengths (Tensor or None, optional):
Lengths of the targets. 1-D Tensor of shape `(B,)`.
input_lengths (NDArray or None, optional):
Lengths of the inputs (max value must each be <= `T`). 1-D NDArray of shape `(B,)`.
target_lengths (NDArray or None, optional):
Lengths of the targets. 1-D NDArray of shape `(B,)`.
blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0)
Returns:
Tuple(Tensor, Tensor):
Tensor: Label for each time step in the alignment path computed using forced alignment.
Tuple(NDArray, NDArray):
NDArray: Label for each time step in the alignment path computed using forced alignment.
Tensor: Log probability scores of the labels for each time step.
NDArray: Log probability scores of the labels for each time step.
Note:
The sequence length of `log_probs` must satisfy:
@@ -216,26 +197,16 @@ def forced_align(
raise ValueError(
f"targets Tensor shouldn't contain blank index. Found {targets}."
)
if torch.max(targets) >= log_probs.shape[-1]:
raise ValueError("targets values must be less than the CTC dimension")

if input_lengths is None:
batch_size, length = log_probs.size(0), log_probs.size(1)
input_lengths = torch.full(
(batch_size,), length, dtype=torch.int64, device=log_probs.device
)
if target_lengths is None:
batch_size, length = targets.size(0), targets.size(1)
target_lengths = torch.full(
(batch_size,), length, dtype=torch.int64, device=targets.device
)

# For TorchScript compatibility
assert input_lengths is not None
assert target_lengths is not None
if blank >= log_probs.shape[-1] or blank < 0:
raise ValueError("blank must be within [0, log_probs.shape[-1])")
if np.max(targets) >= log_probs.shape[-1] and np.min(targets) >= 0:
raise ValueError("targets values must be within [0, log_probs.shape[-1])")
assert log_probs.dtype == np.float32, "log_probs must be float32"

paths, scores = forced_align_cpp(
log_probs, targets, input_lengths, target_lengths, blank
log_probs,
targets,
blank,
)
return paths, scores
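Downstream of forced_align, the frame-level path is typically collapsed into per-token segments. A minimal sketch of that step, assuming a 1-D path (e.g. the squeezed output above) and the default blank id of 0 — this helper is illustrative and not part of the package, which wraps forced_align in its own get_alignments shown below:

def path_to_segments(path, blank=0):
    # Collapse consecutive identical labels into (label, start_frame, end_frame) runs,
    # dropping runs of the blank symbol. end_frame is exclusive.
    segments = []
    t, n = 0, len(path)
    while t < n:
        label = int(path[t])
        start = t
        while t < n and int(path[t]) == label:
            t += 1
        if label != blank:
            segments.append((label, start, t))
    return segments

# Frame indices can then be converted to seconds by multiplying by the acoustic model's
# frame stride (often around 20 ms for wav2vec2-style encoders; check the model actually used).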

@@ -256,15 +227,11 @@ def get_alignments(
blank_id = dictionary.get("<pad>") if blank_id is None else blank_id
if emissions.is_cuda:
emissions = emissions.cpu()
targets = torch.tensor(token_indices, dtype=torch.int32)
targets = np.asarray([token_indices], dtype=np.int64)

input_lengths = torch.tensor(emissions.shape[0]).unsqueeze(-1)
target_lengths = torch.tensor(targets.shape[0]).unsqueeze(-1)
path, _ = forced_align(
emissions.unsqueeze(0).float(),
targets.unsqueeze(0),
input_lengths,
target_lengths,
emissions.unsqueeze(0).float().numpy(),
targets,
blank=blank_id,
)
path = path.squeeze().tolist()
@@ -282,7 +249,11 @@ def load_alignment_model(
if attn_implementation is None:
if version.parse(transformers_version) < version.parse("4.41.0"):
attn_implementation = "eager"
elif is_flash_attn_2_available():
elif (
is_flash_attn_2_available()
and device == "cuda"
and dtype in [torch.float16, torch.bfloat16]
):
attn_implementation = "flash_attention_2"
else:
attn_implementation = "sdpa"
(Diffs for the remaining 3 changed files are not shown.)
