[scripts,src] Fix potential issue in scripts; minor fixes. (#2724)
The use of Python's default split() on latin-1-decoded data (which might actually hold other ASCII-compatible encoded data such as utf-8) is not right, because character 160 (expressed here in decimal) is an NBSP in latin-1 encoding and also falls in the byte range UTF-8 uses for continuation bytes, so split() can tear multi-byte characters apart. The same goes for strip(). Thanks @ChunChiehChang for finding the issue.
danpovey authored Sep 19, 2018
1 parent 9b9196b commit 69cd717
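
A minimal demonstration of the failure mode described in the commit message (illustrative, not part of the commit): when UTF-8 bytes are decoded as latin-1, byte 0xa0 becomes U+00A0 (no-break space), which Python's default split() and strip() treat as whitespace. Restricting delimiters to ASCII space and tab, as these fixes do, keeps the bytes intact.

import re

raw = "à la".encode("utf-8")        # b'\xc3\xa0 la'; 'à' is the two bytes 0xc3 0xa0
text = raw.decode("latin-1")        # one character per byte, never fails

# Default split() treats U+00A0 (byte 0xa0) as whitespace and corrupts the data:
print(text.split())                 # ['Ã', 'la']

# Splitting only on ASCII space/tab, as in this commit, preserves the bytes:
whitespace = re.compile("[ \t]+")
print(whitespace.split(text.strip(" \t\r\n")))   # ['Ã\xa0', 'la']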
Showing 12 changed files with 99 additions and 82 deletions.
1 change: 0 additions & 1 deletion egs/wsj/s5/steps/train_sat.sh
@@ -276,4 +276,3 @@ steps/info/gmm_dir_info.pl $dir
echo "$0: done training SAT system in $dir"

exit 0

82 changes: 47 additions & 35 deletions egs/wsj/s5/utils/apply_map.pl
@@ -9,47 +9,59 @@
# be sequences of tokens. See the usage message.


if (@ARGV > 0 && $ARGV[0] eq "-f") {
shift @ARGV;
$field_spec = shift @ARGV;
if ($field_spec =~ m/^\d+$/) {
$field_begin = $field_spec - 1; $field_end = $field_spec - 1;
}
if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10)
if ($1 ne "") {
$field_begin = $1 - 1; # Change to zero-based indexing.
$permissive = 0;

for ($x = 0; $x <= 2; $x++) {

if (@ARGV > 0 && $ARGV[0] eq "-f") {
shift @ARGV;
$field_spec = shift @ARGV;
if ($field_spec =~ m/^\d+$/) {
$field_begin = $field_spec - 1; $field_end = $field_spec - 1;
}
if ($2 ne "") {
$field_end = $2 - 1; # Change to zero-based indexing.
if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10)
if ($1 ne "") {
$field_begin = $1 - 1; # Change to zero-based indexing.
}
if ($2 ne "") {
$field_end = $2 - 1; # Change to zero-based indexing.
}
}
if (!defined $field_begin && !defined $field_end) {
die "Bad argument to -f option: $field_spec";
}
}
if (!defined $field_begin && !defined $field_end) {
die "Bad argument to -f option: $field_spec";
}
}

# Mapping is obligatory
$permissive = 0;
if (@ARGV > 0 && $ARGV[0] eq '--permissive') {
shift @ARGV;
# Mapping is optional (missing key is printed to output)
$permissive = 1;
if (@ARGV > 0 && $ARGV[0] eq '--permissive') {
shift @ARGV;
# Mapping is optional (missing key is printed to output)
$permissive = 1;
}
}

if(@ARGV != 1) {
print STDERR "Invalid usage: " . join(" ", @ARGV) . "\n";
print STDERR "Usage: apply_map.pl [options] map <input >output\n" .
"options: [-f <field-range> ]\n" .
"Applies the map 'map' to all input text, where each line of the map\n" .
"is interpreted as a map from the first field to the list of the other fields\n" .
"Note: <field-range> can look like 4-5, or 4-, or 5-, or 1, it means the field\n" .
"range in the input to apply the map to.\n" .
"e.g.: echo A B | apply_map.pl a.txt\n" .
"where a.txt is:\n" .
"A a1 a2\n" .
"B b\n" .
"will produce:\n" .
"a1 a2 b\n";
print STDERR <<'EOF';
Usage: apply_map.pl [options] map <input >output
options: [-f <field-range> ] [--permissive]
This applies a map to some specified fields of some input text:
For each line in the map file: the first field is the thing we
map from, and the remaining fields are the sequence we map it to.
The -f (field-range) option says which fields of the input file the
map should apply to.
If the --permissive option is supplied, fields which are not present
in the map will be left as they were.
Applies the map 'map' to all input text, where each line of the map
is interpreted as a map from the first field to the list of the other fields.
Note: <field-range> can look like 4-5, or 4-, or 5-, or 1; it means the field
range in the input to apply the map to.
e.g.: echo A B | apply_map.pl a.txt
where a.txt is:
A a1 a2
B b
will produce:
a1 a2 b
EOF
exit(1);
}

@@ -72,12 +84,12 @@
$a = $A[$x];
if (!defined $map{$a}) {
if (!$permissive) {
die "apply_map.pl: undefined key $a in $map_file\n";
die "apply_map.pl: undefined key $a in $map_file\n";
} else {
print STDERR "apply_map.pl: warning! missing key $a in $map_file\n";
}
} else {
$A[$x] = $map{$a};
$A[$x] = $map{$a};
}
}
}
4 changes: 2 additions & 2 deletions egs/wsj/s5/utils/data/perturb_speed_to_allowed_lengths.py
@@ -66,7 +66,7 @@ class Utterance:
"""

def __init__(self, uid, wavefile, speaker, transcription, dur):
self.wavefile = (wavefile if wavefile.rstrip().endswith('|') else
self.wavefile = (wavefile if wavefile.rstrip(" \t\r\n").endswith('|') else
'cat {} |'.format(wavefile))
self.speaker = speaker
self.transcription = transcription
@@ -130,7 +130,7 @@ def read_kaldi_mapfile(path):
m = {}
with open(path, 'r', encoding='latin-1') as f:
for line in f:
line = line.strip()
line = line.strip(" \t\r\n")
sp_pos = line.find(' ')
key = line[:sp_pos]
val = line[sp_pos+1:]
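
As a quick illustration (hypothetical wav.scp-style entry, not from the diff) of the parsing convention in read_kaldi_mapfile: the key is everything before the first ASCII space, and the value is the rest of the line, pipes and all.

line = "utt1 sox input.wav -t wav - |\n".strip(" \t\r\n")
sp_pos = line.find(' ')
key, val = line[:sp_pos], line[sp_pos + 1:]
print(key)   # utt1
print(val)   # sox input.wav -t wav - |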
10 changes: 6 additions & 4 deletions egs/wsj/s5/utils/lang/bpe/prepend_words.py
@@ -4,11 +4,13 @@
# the beginning of the words for finding the initial-space of every word
# after decoding.

import sys, io
import sys
import io
import re

whitespace = re.compile("[ \t]+")
infile = io.TextIOWrapper(sys.stdin.buffer, encoding='latin-1')
output = io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')
for line in infile:
output.write(' '.join([ "|"+word for word in line.split()]) + '\n')


words = whitespace.split(line.strip(" \t\r\n"))
output.write(' '.join([ "|"+word for word in words]) + '\n')
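
As a quick sketch of what this script does (illustrative input, not part of the diff): each word gets '|' prepended, so the original word boundaries can be recovered after BPE decoding.

import io
import re

whitespace = re.compile("[ \t]+")
infile = io.StringIO("hello world\nfoo bar\n")   # stands in for the latin-1 stdin wrapper
for line in infile:
    words = whitespace.split(line.strip(" \t\r\n"))
    print(' '.join("|" + word for word in words))
# |hello |world
# |foo |bar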
12 changes: 6 additions & 6 deletions egs/wsj/s5/utils/lang/compute_sentence_probs_arpa.py
@@ -99,13 +99,13 @@ def compute_begin_prob(sub_list):
for i in range(1, len(sub_list) - 1):
logprob += compute_sublist_prob(sub_list[:i + 1])
return logprob

# The probability is computed in this way:
# p(word_N | word_N-1 ... word_1) = ngram_dict[word_1 ... word_N][0].
# Here ngram_dict is a dictionary that stores a tuple for each ngram.
# The first element of the tuple is the probability and the second is the backoff probability (if it exists).
# If the particular ngram (word_1 ... word_N) is not in the dictionary, then
# p(word_N | word_N-1 ... word_1) = p(word_N | word_(N-1) ... word_2) * backoff_weight(word_(N-1) | word_(N-2) ... word_1)
# p(word_N | word_N-1 ... word_1) = p(word_N | word_(N-1) ... word_2) * backoff_weight(word_(N-1) | word_(N-2) ... word_1)
# If the sequence (word_(N-1) ... word_1) is not in the dictionary, then the backoff_weight gets replaced with 0.0 (log 1)
# More details can be found in https://cmusphinx.github.io/wiki/arpaformat/
def compute_sentence_prob(sentence, ngram_order):
@@ -127,7 +127,7 @@ def compute_sentence_prob(sentence, ngram_order):
logprob += compute_sublist_prob(cur_sublist)

return logprob


def output_result(text_in_handle, output_file_handle, ngram_order):
lines = text_in_handle.readlines()
@@ -139,8 +139,8 @@ def output_result(text_in_handle, output_file_handle, ngram_order):
output_file_handle.write("{}\n".format(new_logprob))
text_in_handle.close()
output_file_handle.close()


if __name__ == "__main__":
check_args(args)
ngram_dict, tot_num = load_model(args.arpa_lm)
@@ -149,7 +149,7 @@ def output_result(text_in_handle, output_file_handle, ngram_order):
if not num_valid:
sys.exit("compute_sentence_probs_arpa.py: Wrong loading model.")
if args.ngram_order <= 0 or args.ngram_order > max_ngram_order:
sys.exit("compute_sentence_probs_arpa.py: " +
sys.exit("compute_sentence_probs_arpa.py: " +
"Invalid ngram_order (either negative or greater than maximum ngram number ({}) allowed)".format(max_ngram_order))

output_result(args.text_in_handle, args.prob_file_handle, args.ngram_order)
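
To make the backoff rule described in the comment near the top of this file concrete, here is a minimal sketch (toy probabilities and illustrative names, not code from this file):

def ngram_logprob(ngram_dict, words):
    """log10 p(words[-1] | words[:-1]) with ARPA-style backoff."""
    words = tuple(words)
    if words in ngram_dict:
        return ngram_dict[words][0]
    history = words[:-1]
    # Backoff weight of the history; log10(1) = 0.0 if the history is unseen.
    # (Assumes every unigram is present, as real ARPA LMs guarantee via <unk>.)
    backoff = ngram_dict[history][1] if history in ngram_dict else 0.0
    return backoff + ngram_logprob(ngram_dict, words[1:])

# p(c | a b) is unseen, so it backs off: backoff(a b) + p(c | b) = -0.05 + -0.4.
lm = {("a",): (-0.5, -0.1), ("b",): (-0.7, -0.2), ("c",): (-0.9, 0.0),
      ("a", "b"): (-0.3, -0.05), ("b", "c"): (-0.4, 0.0)}
print(ngram_logprob(lm, ["a", "b", "c"]))   # -0.45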
10 changes: 6 additions & 4 deletions egs/wsj/s5/utils/lang/grammar/augment_phones_txt.py
@@ -2,6 +2,7 @@


import argparse
import re
import os
import sys

@@ -34,11 +35,12 @@ def read_phones_txt(filename):
# with utf-8 encoding as well as other encodings such as gbk, as long as the
# spaces are also spaces in ascii (which we check). It is basically how we
# emulate the behavior of python before python3.
whitespace = re.compile("[ \t]+")
with open(filename, 'r', encoding='latin-1') as f:
lines = [line.strip() for line in f]
lines = [line.strip(" \t\r\n") for line in f]
highest_numbered_symbol = 0
for line in lines:
s = line.split()
s = whitespace.split(line)
try:
i = int(s[1])
if i > highest_numbered_symbol:
@@ -57,9 +59,9 @@ def read_nonterminals(filename):
it has the expected format and has no duplicates, and returns the nonterminal
symbols as a list of strings, e.g.
['#nonterm:contact_list', '#nonterm:phone_number', ... ]. """
ans = [line.strip() for line in open(filename, 'r', encoding='latin-1')]
ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')]
if len(ans) == 0:
raise RuntimeError("The file {0} contains no nonterminals symbols.".format(filename))
raise RuntimeError("The file {0} contains no nonterminal symbols.".format(filename))
for nonterm in ans:
if nonterm[:9] != '#nonterm:':
raise RuntimeError("In file '{0}', expected nonterminal symbols to start with '#nonterm:', found '{1}'"
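
The comment in read_phones_txt above describes the latin-1 trick; a small illustration (not part of the diff) of why it is safe for any ASCII-compatible encoding:

raw = "中文 abc".encode("gbk")        # bytes in a non-utf-8 encoding
text = raw.decode("latin-1")          # always succeeds: one character per byte
assert text.encode("latin-1") == raw  # lossless round trip back to bytes
# Processing may only split/strip on ASCII delimiters (space, tab, \r, \n),
# which is exactly what the strip(" \t\r\n") and "[ \t]+" changes ensure.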
9 changes: 5 additions & 4 deletions egs/wsj/s5/utils/lang/grammar/augment_words_txt.py
@@ -35,11 +35,12 @@ def read_words_txt(filename):
# with utf-8 encoding as well as other encodings such as gbk, as long as the
# spaces are also spaces in ascii (which we check). It is basically how we
# emulate the behavior of python before python3.
whitespace = re.compile("[ \t]+")
with open(filename, 'r', encoding='latin-1') as f:
lines = [line.strip() for line in f]
lines = [line.strip(" \t\r\n") for line in f]
highest_numbered_symbol = 0
for line in lines:
s = line.split()
s = whitespace.split(line)
try:
i = int(s[1])
if i > highest_numbered_symbol:
@@ -58,9 +59,9 @@ def read_nonterminals(filename):
it has the expected format and has no duplicates, and returns the nonterminal
symbols as a list of strings, e.g.
['#nonterm:contact_list', '#nonterm:phone_number', ... ]. """
ans = [line.strip() for line in open(filename, 'r', encoding='latin-1')]
ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')]
if len(ans) == 0:
raise RuntimeError("The file {0} contains no nonterminals symbols.".format(filename))
raise RuntimeError("The file {0} contains no nonterminal symbols.".format(filename))
for nonterm in ans:
if nonterm[:9] != '#nonterm:':
raise RuntimeError("In file '{0}', expected nonterminal symbols to start with '#nonterm:', found '{1}'"
5 changes: 3 additions & 2 deletions egs/wsj/s5/utils/lang/limit_arpa_unk_history.py
@@ -58,6 +58,7 @@ def get_ngram_stats(old_lm_lines):

def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows):
ngram_diffs = defaultdict(int)
whitespace_pattern = re.compile("[ \t]+")
unk_pattern = re.compile(
"[0-9.-]+(?:[\s\\t]\S+){1,3}[\s\\t]" + args.oov_dict_entry +
"[\s\\t](?!-[0-9]+\.[0-9]+).*")
@@ -70,7 +71,7 @@ def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows):
new_lm_lines = old_lm_lines[:skip_rows]

for i in range(skip_rows, len(old_lm_lines)):
line = old_lm_lines[i].strip()
line = old_lm_lines[i].strip(" \t\r\n")

if "\{}-grams:".format(3) in line:
passed_2grams = True
@@ -101,7 +102,7 @@ def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows):
if not last_ngram:
g_backoff = backoff_pattern.search(line)
if g_backoff:
updated_row = g_backoff.group(0).split()[:-1]
updated_row = whitespace_pattern.split(g_backoff.group(0))[:-1]
updated_row = updated_row[0] + \
"\t" + " ".join(updated_row[1:]) + "\n"
new_lm_lines.append(updated_row)
20 changes: 11 additions & 9 deletions egs/wsj/s5/utils/lang/make_lexicon_fst.py
@@ -72,28 +72,28 @@ def read_lexiconp(filename):
with open(filename, 'r', encoding='latin-1') as f:
whitespace = re.compile("[ \t]+")
for line in f:
a = whitespace.split(line.strip())
a = whitespace.split(line.strip(" \t\r\n"))
if len(a) < 2:
print("{0}: error: found bad line '{1}' in lexicon file {2} ".format(
sys.argv[0], line.strip(), filename), file=sys.stderr)
sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr)
sys.exit(1)
word = a[0]
if word == "<eps>":
# This would clash with the epsilon symbol normally used in OpenFst.
print("{0}: error: found <eps> as a word in lexicon file "
"{1}".format(line.strip(), filename), file=sys.stderr)
"{1}".format(line.strip(" \t\r\n"), filename), file=sys.stderr)
sys.exit(1)
try:
pron_prob = float(a[1])
except:
print("{0}: error: found bad line '{1}' in lexicon file {2}, 2nd field "
"should be pron-prob".format(sys.argv[0], line.strip(), filename),
"should be pron-prob".format(sys.argv[0], line.strip(" \t\r\n"), filename),
file=sys.stderr)
sys.exit(1)
prons = a[2:]
if pron_prob <= 0.0:
print("{0}: error: invalid pron-prob in line '{1}' of lexicon file {1} ".format(
sys.argv[0], line.strip(), filename), file=sys.stderr)
sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr)
sys.exit(1)
if len(prons) == 0:
found_empty_prons = True
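
For reference, a small sketch (hypothetical lexicon entry, not from the diff) of how a line is interpreted once split on ASCII whitespace:

import re

whitespace = re.compile("[ \t]+")
a = whitespace.split("hello 1.0 HH AH0 L OW1".strip(" \t\r\n"))
word, pron_prob, prons = a[0], float(a[1]), a[2:]
print(word, pron_prob, prons)   # hello 1.0 ['HH', 'AH0', 'L', 'OW1']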
@@ -324,7 +324,7 @@ def read_nonterminals(filename):
it has the expected format and has no duplicates, and returns the nonterminal
symbols as a list of strings, e.g.
['#nonterm:contact_list', '#nonterm:phone_number', ... ]. """
ans = [line.strip() for line in open(filename, 'r', encoding='latin-1')]
ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')]
if len(ans) == 0:
raise RuntimeError("The file {0} contains no nonterminals symbols.".format(filename))
for nonterm in ans:
@@ -338,11 +338,12 @@ def read_nonterminals(filename):
def read_left_context_phones(filename):
"""Reads, checks, and returns a list of left-context phones, in text form, one
per line. Returns a list of strings, e.g. ['a', 'ah', ..., '#nonterm_bos' ]"""
ans = [line.strip() for line in open(filename, 'r', encoding='latin-1')]
ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')]
if len(ans) == 0:
raise RuntimeError("The file {0} contains no left-context phones.".format(filename))
whitespace = re.compile("[ \t]+")
for s in ans:
if len(s.split()) != 1:
if len(whitespace.split(s)) != 1:
raise RuntimeError("The file {0} contains an invalid line '{1}'".format(filename, s) )

if len(set(ans)) != len(ans):
@@ -354,7 +355,8 @@ def is_token(s):
"""Returns true if s is a string and is space-free."""
if not isinstance(s, str):
return False
split_str = s.split()
whitespace = re.compile("[ \t\r\n]+")
split_str = whitespace.split(s)
return len(split_str) == 1 and s == split_str[0]
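
A few hypothetical calls, assuming the is_token defined above is in scope, to show its contract:

print(is_token("AH0"))    # True  -- a single whitespace-free string
print(is_token("AH 0"))   # False -- contains a space
print(is_token(" AH0"))   # False -- leading whitespace fails the split round trip
print(is_token(42))       # False -- not a string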


