Skip to content

Commit

Permalink
api/docs modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed Dec 17, 2021
1 parent cf32cb4 commit 5981015
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 21 deletions.
12 changes: 6 additions & 6 deletions torchaudio/csrc/decoder/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Flashlight Decoder Binding
CTC Decoder with KenLM and lexicon support based on [flashlight](https://github.com/flashlight/flashlight) decoder implementation
and fairseq [KenLMDecoder](https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53)
and fairseq [KenLMDecoder](https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53)
Python wrapper

## Setup
### Build KenLM
- Install KenLM following the instructions [here](https://github.com/kpu/kenlm#compiling)
- Install KenLM in your audio directory following the instructions [here](https://github.com/kpu/kenlm#compiling)
- set `KENLM_ROOT` variable to the KenLM installation path
### Build torchaudio with decoder support
```
Expand All @@ -17,7 +17,7 @@ BUILD_CTC_DECODER=1 python setup.py develop
from torchaudio.prototype import kenlm_lexicon_decoder
decoder = kenlm_lexicon_decoder(args...)
results = decoder(emissions) # dim (B, nbest) of dictionary of "tokens", "score", "words" keys
best_transcript = " ".join(results[0][0]["words"]).strip()
best_transcripts = [" ".join(results[i][0].words).strip() for i in range(B)]
```

## Required Files
Expand All @@ -26,11 +26,11 @@ best_transcript = " ".join(results[0][0]["words"]).strip()
- language model: n-gram KenLM model

## Experiment Results
LibriSpeech dev-other and test-other results using pretrained [Wav2Vec2](https://arxiv.org/pdf/2006.11477.pdf) models of
BASE configuration.
LibriSpeech dev-other and test-other results using pretrained [Wav2Vec2](https://arxiv.org/pdf/2006.11477.pdf) models of
BASE configuration.

| Model | Decoder | dev-other | test-other | beam search params |
| ----------- | ---------- | ----------- | ---------- | ------------------------------------------- |
| ----------- | ---------- | ----------- | ---------- |-------------------------------------------- |
| BASE_10M | Greedy | 51.6 | 51 | |
| | 4-gram LM | 15.95 | 15.9 | LM weight=3.23, word score=-0.26, beam=1500 |
| BASE_100H | Greedy | 13.6 | 13.3 | |
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/csrc/decoder/src/decoder/lm/KenLM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

#include <stdexcept>

#include <kenlm/lm/model.hh>
#include "kenlm/lm/model.hh"

namespace torchaudio {
namespace lib {
Expand Down
6 changes: 6 additions & 0 deletions torchaudio/prototype/ctc_decoder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .ctc_decoder import KenLMLexiconDecoder, kenlm_lexicon_decoder

__all__ = [
"KenLMLexiconDecoder",
"kenlm_lexicon_decoder",
]
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import itertools as it
from typing import List, Optional, Dict
from collections import namedtuple

import torchaudio

Expand All @@ -26,6 +27,8 @@
__all__ = ["KenLMLexiconDecoder", "kenlm_lexicon_decoder"]


Hypothesis = namedtuple("Hypothesis", ["tokens", "words", "score"])

class KenLMLexiconDecoder:
def __init__(
self,
Expand All @@ -38,9 +41,11 @@ def __init__(
blank_token: str,
sil_token: str,
) -> None:

"""
Construct a KenLM CTC Lexcion Decoder.
KenLM CTC Decoder with Lexicon constraint.
Note:
To build the decoder, please use the factory function kenlm_lexicon_decoder.
Args:
nbest (int): number of best decodings to return
Expand Down Expand Up @@ -107,13 +112,13 @@ def decode(
in time axis of the output Tensor in each batch
Returns:
List[List[Dict]]:
List[Hypothesis]:
List of sorted best hypotheses for each audio sequence in the batch.
Each hypothesis is dictionary with the following mapping:
"tokens": torch.LongTensor of raw token IDs
"score": hypothesis score
"words": list of decoded words
Each hypothesis is named tuple with the following fields:
tokens: torch.LongTensor of raw token IDs
score: hypothesis score
words: list of decoded words
"""
B, T, N = emissions.size()
if lengths is None:
Expand All @@ -128,13 +133,11 @@ def decode(
nbest_results = results[: self.nbest]
hypos.append(
[
{
"tokens": self._get_tokens(result.tokens),
"score": result.score,
"words": [
self.word_dict.get_entry(x) for x in result.words if x >= 0
]
}
Hypothesis(
self._get_tokens(result.tokens), # token ids
list(self.word_dict.get_entry(x) for x in result.words if x >= 0), # words
result.score, # score
)
for result in nbest_results
]
)
Expand Down

0 comments on commit 5981015

Please sign in to comment.