Integration test with metadata
mjdenkowski committed Sep 8, 2022
1 parent 667b003 commit 4ee4d01
Showing 2 changed files with 66 additions and 17 deletions.
45 changes: 40 additions & 5 deletions sockeye/test_utils.py
@@ -40,8 +40,14 @@ def generate_digits_file(source_path: str,
line_length: int = 9,
sort_target: bool = False,
line_count_empty: int = 0,
seed=13):
seed=13,
metadata_path: Optional[str] = None,
p_reverse_with_metadata_tag: float = 0.):
assert line_count_empty <= line_count
metadata_out = None
if p_reverse_with_metadata_tag > 0.:
assert metadata_path is not None
metadata_out = open(metadata_path, "w")
random_gen = random.Random(seed)
with open(source_path, "w") as source_out, open(target_path, "w") as target_out:
all_digits = []
@@ -55,7 +61,15 @@
print(C.TOKEN_SEPARATOR.join(digits), file=source_out)
if sort_target:
digits.sort()
if p_reverse_with_metadata_tag > 0.:
if random_gen.random() < p_reverse_with_metadata_tag:
digits.reverse()
print(r'{"reversed": 1}', file=metadata_out)
else:
print('', file=metadata_out)
print(C.TOKEN_SEPARATOR.join(digits), file=target_out)
if metadata_out is not None:
metadata_out.close()


def generate_json_input_file_with_tgt_prefix(src_path:str, tgt_path: str, json_file_with_tgt_prefix_path: str, \
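For orientation, here is a minimal sketch of how the extended helper could be exercised on its own. The paths and line count are illustrative (not taken from the test suite), and the behavior described in the comments follows the logic in the hunk above.

from sockeye.test_utils import generate_digits_file

# Illustrative call: paths are made up and /tmp is assumed writable.
# With probability 0.5 a target line is written reversed and its metadata
# line carries the JSON tag {"reversed": 1}; otherwise the metadata line is
# left empty. Exactly one metadata line is written per source/target pair.
generate_digits_file(source_path="/tmp/train.src",
                     target_path="/tmp/train.tgt",
                     line_count=4,
                     metadata_path="/tmp/train.md",
                     p_reverse_with_metadata_tag=0.5)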
@@ -146,7 +160,8 @@ def tmp_digits_dataset(prefix: str,
sort_target: bool = False,
seed_train: int = 13, seed_dev: int = 13,
with_n_source_factors: int = 0,
with_n_target_factors: int = 0) -> Dict[str, Any]:
with_n_target_factors: int = 0,
p_reverse_with_metadata_tag: float = 0.) -> Dict[str, Any]:
"""
Creates a temporary dataset with train, dev, and test. Returns a dictionary with paths to the respective temporary
files.
@@ -155,17 +170,23 @@
# Simple digits files for train/dev data
train_source_path = os.path.join(work_dir, "train.src")
train_target_path = os.path.join(work_dir, "train.tgt")
train_metadata_path = os.path.join(work_dir, "train.md")
dev_source_path = os.path.join(work_dir, "dev.src")
dev_target_path = os.path.join(work_dir, "dev.tgt")
dev_metadata_path = os.path.join(work_dir, "dev.md")
test_source_path = os.path.join(work_dir, "test.src")
test_target_path = os.path.join(work_dir, "test.tgt")
test_metadata_path = os.path.join(work_dir, "test.md")
test_source_with_target_prefix_path = os.path.join(work_dir, "test_source_with_target_prefix.json")
generate_digits_file(train_source_path, train_target_path, train_line_count, train_max_length,
line_count_empty=train_line_count_empty, sort_target=sort_target, seed=seed_train)
line_count_empty=train_line_count_empty, sort_target=sort_target, seed=seed_train,
metadata_path=train_metadata_path, p_reverse_with_metadata_tag=p_reverse_with_metadata_tag)
generate_digits_file(dev_source_path, dev_target_path, dev_line_count, dev_max_length, sort_target=sort_target,
seed=seed_dev)
seed=seed_dev, metadata_path=dev_metadata_path,
p_reverse_with_metadata_tag=p_reverse_with_metadata_tag)
generate_digits_file(test_source_path, test_target_path, test_line_count, test_max_length,
line_count_empty=test_line_count_empty, sort_target=sort_target, seed=seed_dev)
line_count_empty=test_line_count_empty, sort_target=sort_target, seed=seed_dev,
metadata_path=test_metadata_path, p_reverse_with_metadata_tag=p_reverse_with_metadata_tag)
data = {'work_dir': work_dir,
'train_source': train_source_path,
'train_target': train_target_path,
@@ -205,6 +226,11 @@ def tmp_digits_dataset(prefix: str,
data['dev_target_factors'].append(dev_factor_path)
data['test_target_factors'].append(test_factor_path)

if p_reverse_with_metadata_tag > 0.:
data['train_metadata'] = train_metadata_path
data['dev_metadata'] = dev_metadata_path
data['test_metadata'] = test_metadata_path

source_factors_path = None if 'test_source_factors' not in data else data['test_source_factors']
target_factors_path = None if 'test_target_factors' not in data else data['test_target_factors']
generate_json_input_file_with_tgt_prefix(test_source_path, test_target_path, test_source_with_target_prefix_path, \
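Consumers reach the new files through the returned dictionary. A rough sketch of lining up the three parallel files; only the dict keys are taken from the code above, everything else is illustrative.

# Sketch only: assumes `data` came from tmp_digits_dataset with
# p_reverse_with_metadata_tag > 0, so the *_metadata keys are present.
with open(data['train_source']) as src_in, \
        open(data['train_target']) as tgt_in, \
        open(data['train_metadata']) as metadata_in:
    for source_line, target_line, metadata_line in zip(src_in, tgt_in, metadata_in):
        # Each metadata line is either empty or the JSON tag {"reversed": 1}.
        target_was_reversed = metadata_line.strip() == '{"reversed": 1}'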
@@ -224,6 +250,9 @@
TRAIN_WITH_TARGET_FACTORS_COMMON = " --target-factors {target_factors}"
DEV_WITH_TARGET_FACTORS_COMMON = " --validation-target-factors {dev_target_factors}"

TRAIN_WITH_METADATA_COMMON = " --metadata {train_metadata}"
DEV_WITH_METADATA_COMMON = " --validation-metadata {dev_metadata}"

TRAIN_PARAMS_PREPARED_DATA_COMMON = "--use-cpu --max-seq-len {max_len} --prepared-data {prepared_data}" \
" --validation-source {dev_source} --validation-target {dev_target} " \
"--output {model}"
@@ -279,6 +308,9 @@ def run_train_translate(train_params: str,
prepare_params += TRAIN_WITH_TARGET_FACTORS_COMMON.format(
target_factors=" ".join(data['train_target_factors']))

if 'train_metadata' in data:
prepare_params += TRAIN_WITH_METADATA_COMMON.format(train_metadata=data['train_metadata'])

if '--weight-tying-type src_trg' in train_params:
prepare_params += ' --shared-vocab'

@@ -299,6 +331,9 @@
if 'dev_target_factors' in data:
params += DEV_WITH_TARGET_FACTORS_COMMON.format(dev_target_factors=" ".join(data['dev_target_factors']))

if 'dev_metadata' in data:
params += DEV_WITH_METADATA_COMMON.format(dev_metadata=data['dev_metadata'])

logger.info("Starting training with parameters %s.", train_params)
with patch.object(sys, "argv", params.split()):
sockeye.train.main()
38 changes: 26 additions & 12 deletions test/integration/test_seq_copy_int.py
@@ -52,7 +52,19 @@
# max updates independent of the checkpoint interval
" --checkpoint-interval 20 --optimizer adam --initial-learning-rate 0.01 --learning-rate-scheduler none",
"--beam-size 2 --nbest-size 2",
False, 0, 0),
False, 0, 0, 0.),
# Basic transformer with metadata support: target reversed/tagged with p=0.5
("--encoder transformer --decoder {decoder}"
" --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
" --transformer-feed-forward-num-hidden 16"
" --transformer-dropout-prepost 0.1 --transformer-preprocess n --transformer-postprocess dr"
" --weight-tying-type src_trg_softmax"
" --batch-size 2 --max-updates 2 --batch-type sentence --decode-and-evaluate 2"
# Note: We set the checkpoint interval > max updates in order to make sure we create a checkpoint when reaching
# max updates independent of the checkpoint interval
" --checkpoint-interval 20 --optimizer adam --initial-learning-rate 0.01 --learning-rate-scheduler none",
"--beam-size 2 --nbest-size 2",
False, 0, 0, 0.5),
# Basic transformer w/ Neural Vocabulary Selection
("--encoder transformer --decoder {decoder}"
" --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
@@ -63,7 +75,7 @@
" --checkpoint-interval 2 --optimizer adam --initial-learning-rate 0.01"
" --neural-vocab-selection logit_max --bow-task-weight 2",
"--beam-size 2 --nbest-size 2",
False, 0, 0),
False, 0, 0, 0.),
# Basic transformer w/ prepared data & greedy decoding
("--encoder transformer --decoder {decoder}"
" --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
@@ -73,7 +85,7 @@
" --batch-size 2 --max-updates 2 --batch-type sentence --decode-and-evaluate 2"
" --checkpoint-interval 2 --optimizer adam --initial-learning-rate 0.01",
"--beam-size 1 --greedy",
True, 0, 0),
True, 0, 0, 0.),
# Basic transformer with source and target factors, beam-search-stop first decoding
("--encoder transformer --decoder {decoder}"
" --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
@@ -87,7 +99,7 @@
" --target-factors-combine sum --target-factors-share-embedding false"
" --target-factors-num-embed 8",
"--beam-size 2 --beam-search-stop first",
True, 3, 1),
True, 3, 1, 0.),
# Basic transformer with LHUC DISABLE FOR MX2 FOR NOW (UNKNOWN FAILURE)
("--encoder transformer --decoder transformer"
" --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
@@ -97,7 +109,7 @@
" --batch-size 2 --max-updates 2 --batch-type sentence --decode-and-evaluate 2"
" --checkpoint-interval 2 --optimizer adam --initial-learning-rate 0.01 --lhuc all",
"--beam-size 2",
False, 0, 0),
False, 0, 0, 0.),
# Basic transformer and length ratio prediction, and learned brevity penalty during inference
("--encoder transformer --decoder {decoder}"
" --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
@@ -109,7 +121,7 @@
" --length-task ratio --length-task-weight 1.0 --length-task-layers 1",
"--beam-size 2"
" --brevity-penalty-type learned --brevity-penalty-weight 1.0",
True, 0, 0),
True, 0, 0, 0.),
# Basic transformer and absolute length prediction, and constant brevity penalty during inference
("--encoder transformer --decoder {decoder}"
" --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
@@ -121,7 +133,7 @@
" --length-task length --length-task-weight 1.0 --length-task-layers 1",
"--beam-size 2"
" --brevity-penalty-type constant --brevity-penalty-weight 2.0 --brevity-penalty-constant-length-ratio 1.5",
False, 0, 0),
False, 0, 0, 0.),
# Basic transformer with clamp-to-dtype during training and inference
("--encoder transformer --decoder {decoder}"
" --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
@@ -131,7 +143,7 @@
" --batch-size 2 --max-updates 2 --batch-type sentence --decode-and-evaluate 2"
" --checkpoint-interval 2 --optimizer adam --initial-learning-rate 0.01 --clamp-to-dtype",
"--beam-size 2 --clamp-to-dtype",
False, 0, 0),
False, 0, 0, 0.),
# Basic transformer, training only the decoder
("--encoder transformer --decoder {decoder}"
" --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
@@ -142,7 +154,7 @@
" --checkpoint-interval 2 --optimizer adam --initial-learning-rate 0.01"
" --fixed-param-strategy " + C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_DECODER,
"--beam-size 2",
False, 0, 0),
False, 0, 0, 0.),
]

# expand test cases across transformer & ssru, as well as use_pytorch true/false
@@ -152,12 +164,13 @@


@pytest.mark.parametrize("train_params, translate_params, use_prepared_data,"
"n_source_factors, n_target_factors", TEST_CASES)
"n_source_factors, n_target_factors, p_reverse_with_metadata_tag", TEST_CASES)
def test_seq_copy(train_params: str,
translate_params: str,
use_prepared_data: bool,
n_source_factors: int,
n_target_factors: int):
n_target_factors: int,
p_reverse_with_metadata_tag: float):
"""
Task: copy short sequences of digits
"""
@@ -173,7 +186,8 @@ def test_seq_copy(train_params: str,
test_max_length=_TEST_MAX_LENGTH,
sort_target=False,
with_n_source_factors=n_source_factors,
with_n_target_factors=n_target_factors) as data:
with_n_target_factors=n_target_factors,
p_reverse_with_metadata_tag=p_reverse_with_metadata_tag) as data:
# TODO: Here we temporarily switch off comparing translation and scoring scores, which
# sometimes produces inconsistent results for --batch-size > 1 (see issue #639 on github).
check_train_translate(train_params=train_params,
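To exercise the new case along with the rest of the suite, the integration module can be run directly; a minimal programmatic invocation from the repository root (plain pytest on the same path works just as well) might look like this.

import pytest

# Runs every parametrized case in test_seq_copy_int.py, including the new
# metadata variant added above.
pytest.main(["test/integration/test_seq_copy_int.py"])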
