# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: CC-BY-4.0
import argparse
import numpy as np
from constants import seg_token, verbosity_tokens, LC_RANGE
# CLI definition: two required input files plus the counting granularity.
parser = argparse.ArgumentParser()
parser.add_argument('--source', '-s', required=True,
                    help='Source file to compute length compliance.')
parser.add_argument('--target', '-t', required=True,
                    help='Target file to compute length compliance.')
# 'char' (default) counts characters after filtering; 'tok' counts whitespace tokens.
parser.add_argument('--level', choices=['tok', 'char'], default='char',
                    help='Granularity level to compute length compliance.')
def get_length(sample, replace_):
    """Return the character length of ``sample`` after stripping every
    string in ``replace_`` (e.g. verbosity-control tokens and spaces)."""
    stripped = sample
    for token in replace_:
        stripped = stripped.replace(token, '')
    return len(stripped)
def sample_length(sample, level, replace_):
    """Length of one sample at the requested granularity.

    :param sample: sentence string
    :param level: 'tok' for whitespace-token count, 'char' for character count
    :param replace_: strings removed before character counting (unused for 'tok')
    :return: integer length
    :raises ValueError: if ``level`` is neither 'tok' nor 'char'
    """
    if level == 'tok':
        len_ = len(sample.split(' '))
    elif level == 'char':
        # Bug fix: previously tested the global `args.level` instead of the
        # `level` parameter, coupling this helper to the argparse namespace
        # (NameError when called without running the CLI).
        len_ = get_length(sample, replace_)
    else:
        # Previously fell through and raised UnboundLocalError on `len_`.
        raise ValueError(f"unknown granularity level: {level!r}")
    return len_
def compute_len_compliance(source_samples, target_samples, level, replace_, lc_range_threshold=10):
    """
    Compute target-to-source length ratio and length compliance within a +-10% range, following
    https://arxiv.org/abs/2112.08682 or see https://iwslt.org/2022/isometric#length-compliance-lc
    :param source_samples: list of source sentences
    :param target_samples: list of translation sentences
    :param level: granularity level for computing the length metrics (default=character)
    :param replace_: char/tokens to remove before counting the length
    :param lc_range_threshold: max absolute % deviation still counted as compliant
    :return: mean of target to source length ratio and % of translations within +-LC_RANGE
    """
    ratios = []
    compliant = 0
    for source, target in zip(source_samples, target_samples):
        source_len = float(sample_length(source, level, replace_))
        target_len = float(sample_length(target, level, replace_))
        # NOTE(review): a zero-length source (empty line at char level) would
        # raise ZeroDivisionError here — presumably inputs are non-empty; verify.
        ratios.append(target_len / source_len)
        if level == "char" and source_len <= 10:
            # LC is only evaluated for sources longer than 10 chars;
            # shorter sources are counted as compliant by convention.
            compliant += 1
        elif abs((target_len - source_len) * 100 / source_len) <= lc_range_threshold:
            compliant += 1
    return np.mean(np.array(ratios)), compliant * 100 / len(ratios)
def read_samples(file_):
    """Read ``file_`` and return a list of its lines, whitespace-stripped."""
    with open(file_, 'r') as handle:
        return [line.strip() for line in handle]
if __name__ == '__main__':
    args = parser.parse_args()
    # tokens to filter out for char length count (see ./scripts/constants.py)
    replace_ = verbosity_tokens + [seg_token, ' ']
    source_samples, target_samples = read_samples(args.source), read_samples(args.target)
    # Explicit validation instead of `assert`, which is stripped under `python -O`.
    if len(source_samples) != len(target_samples):
        raise ValueError("source to target sample size mismatch.")
    # get stat
    length_ratio, length_range = compute_len_compliance(source_samples, target_samples,
                                                        args.level, replace_,
                                                        lc_range_threshold=int(LC_RANGE))
    print(f"Length Ratio: {length_ratio:.3f}, Length Range +-{LC_RANGE}%: {length_range:.3f}")