-
Notifications
You must be signed in to change notification settings - Fork 92
/
sample_seq2seq.py
executable file
·224 lines (174 loc) · 6.87 KB
/
sample_seq2seq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""
Generate a large batch of image samples from a model and save them as a large
numpy array. This can be used to produce samples for FID evaluation.
"""
import argparse
import os, json
from tracemalloc import start
import numpy as np
import torch as th
import torch.distributed as dist
from transformers import set_seed
from diffuseq.rounding import denoised_fn_round
from diffuseq.text_datasets import load_data_text
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import time
from diffuseq.utils import dist_util, logger
from functools import partial
from basic_utils import (
load_defaults_config,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
load_tokenizer
)
def create_argparser():
    """Build the command-line parser for decoding.

    Options are layered in three steps:
    1. Script-specific options: checkpoint path, step count, output directory, top-p.
    2. The project defaults returned by ``load_defaults_config()``.
    3. Decoding options: data split, clamp step, seed, clip flag.
    Later layers override earlier ones on key collisions.
    """
    options = {'model_path': '', 'step': 0, 'out_dir': '', 'top_p': 0}
    options.update(load_defaults_config())
    options.update(split='valid', clamp_step=0, seed2=105, clip_denoised=False)
    cli = argparse.ArgumentParser()
    add_dict_to_argparser(cli, options)
    return cli
def _resolve_output_path(model_path, out_root, seed, clamp_step):
    """Create the sample directory for a checkpoint and return the JSON-lines path.

    The run directory name and checkpoint file name are joined as
    ``<run_dir>.<ckpt_file>``. The part before the first ``.ema`` names a
    sub-directory of ``out_root``. The part after it names an
    ``ema<...>.samples`` directory inside that sub-directory.

    Raises:
        ValueError: if the combined name contains no ``.ema`` marker.
    """
    run_dir, ckpt_file = os.path.split(model_path)
    model_base_name = os.path.basename(run_dir) + f'.{ckpt_file}'
    name_parts = model_base_name.split('.ema')
    if len(name_parts) < 2:
        raise ValueError(
            f"Cannot derive output directory: {model_base_name!r} contains no '.ema' marker"
        )
    samples_dir = os.path.join(out_root, name_parts[0], f"ema{name_parts[1]}.samples")
    # exist_ok avoids a FileExistsError race when several ranks get here at once,
    # and makedirs also creates a missing out_root.
    os.makedirs(samples_dir, exist_ok=True)
    return os.path.join(samples_dir, f"seed{seed}_step{clamp_step}.json")


def _shard_batches(data_iter, world_size, rank):
    """Return this rank's share of the condition dicts, assigned round-robin by batch index.

    Some ranks receive one batch fewer than others. Those ranks get an
    empty-dict placeholder appended, so every rank runs the same number of
    loop iterations and reaches the same number of ``dist.barrier()`` calls.
    """
    shard = []
    num_batches = 0
    for _batch, cond in data_iter:
        if num_batches % world_size == rank:
            shard.append(cond)
        num_batches += 1
    remainder = num_batches % world_size
    if remainder and rank >= remainder:
        shard.append({})
    return shard


def _decode_batch(tokenizer, seq_len, cand_ids, input_ids, input_masks):
    """Return one JSON string per example with the recovered, reference and source text.

    For each row, ``len_x = seq_len - mask.sum()`` splits the sequence:
    positions before ``len_x`` are the source, positions from ``len_x`` on
    are the target.
    """
    lines = []
    for cand, ids, mask in zip(cand_ids, input_ids, input_masks):
        len_x = seq_len - int(mask.sum())
        lines.append(json.dumps({
            "recover": tokenizer.decode_token(cand[len_x:]),
            "reference": tokenizer.decode_token(ids[len_x:]),
            "source": tokenizer.decode_token(ids[:len_x]),
        }))
    return lines


def _append_in_rank_order(out_path, lines, world_size, rank):
    """Let each rank append ``lines`` to ``out_path`` in rank order.

    Every rank calls ``dist.barrier()`` exactly ``world_size`` times, whether or
    not it has lines to write, so padded ranks stay in step with the others.
    """
    for turn in range(world_size):
        if turn == rank and lines:
            with open(out_path, 'a', encoding='utf-8') as fout:
                fout.writelines(line + '\n' for line in lines)
        dist.barrier()


@th.no_grad()
def main():
    """Decode ``args.split`` with a trained checkpoint and append JSON lines to disk.

    Steps:
    1. Read ``training_args.json`` from the checkpoint's directory. The
       command-line ``batch_size`` takes precedence over the saved one.
    2. Rebuild the model and split the data batches across distributed ranks.
    3. Run the diffusion sampler on each batch.
    4. Write one ``{"recover", "reference", "source"}`` record per example to
       ``<out_dir>/<run>/ema<...>.samples/seed<seed2>_step<clamp_step>.json``.

    Raises:
        ValueError: if ``step`` differs from ``diffusion_steps`` but is not
            positive, or the checkpoint name has no ``.ema`` marker.
    """
    args = create_argparser().parse_args()
    dist_util.setup_dist()
    logger.configure()
    world_size = dist.get_world_size() or 1
    rank = dist.get_rank() or 0

    # Load the configuration saved next to the checkpoint.
    config_path = os.path.join(os.path.split(args.model_path)[0], "training_args.json")
    print(config_path)
    with open(config_path, 'rb') as f:
        training_args = json.load(f)
    # Keep the batch size requested on the command line, not the training one.
    training_args['batch_size'] = args.batch_size
    args.__dict__.update(training_args)

    logger.log("### Creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, load_defaults_config().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    logger.log(f'### The parameter count is {pytorch_total_params}')
    model.eval().requires_grad_(False).to(dist_util.dev())

    tokenizer = load_tokenizer(args)
    # Frozen copy of the model's word-embedding weights, built on CPU.
    model_emb = th.nn.Embedding(
        num_embeddings=tokenizer.vocab_size,
        embedding_dim=args.hidden_dim,
        _weight=model.word_embedding.weight.clone().cpu()
    ).eval().requires_grad_(False)

    set_seed(args.seed2)

    print("### Sampling...on", args.split)
    data_valid = load_data_text(
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        deterministic=True,
        data_args=args,
        split=args.split,
        loaded_vocab=tokenizer,
        model_emb=model_emb.cpu(),  # same embedding weights as the trained model
        loop=False
    )

    start_t = time.time()
    out_path = _resolve_output_path(args.model_path, args.out_dir, args.seed2, args.clamp_step)

    all_test_data = _shard_batches(data_valid, world_size, rank)
    print('### End of reading iteration...')
    model_emb.to(dist_util.dev())

    # The sampler choice does not change between batches, so it is decided once.
    # p_sample_loop is used when step == diffusion_steps; otherwise
    # ddim_sample_loop is used with a gap of diffusion_steps // step.
    if args.step == args.diffusion_steps:
        args.use_ddim = False
        step_gap = 1
    else:
        if args.step <= 0:
            raise ValueError(
                f"--step must be positive (or equal to diffusion_steps="
                f"{args.diffusion_steps}), got {args.step}"
            )
        args.use_ddim = True
        step_gap = args.diffusion_steps // args.step
    sample_fn = diffusion.ddim_sample_loop if args.use_ddim else diffusion.p_sample_loop
    denoised_fn = partial(denoised_fn_round, args, model_emb)

    if rank == 0:
        from tqdm import tqdm
        iterator = tqdm(all_test_data)
    else:
        iterator = iter(all_test_data)

    for cond in iterator:
        if not cond:
            # Padding entry: join the write barriers without writing anything.
            _append_in_rank_order(out_path, [], world_size, rank)
            continue

        input_ids_x = cond.pop('input_ids').to(dist_util.dev())
        x_start = model.get_embeds(input_ids_x)
        input_ids_mask_ori = cond.pop('input_mask')
        noise = th.randn_like(x_start)
        input_ids_mask = th.broadcast_to(
            input_ids_mask_ori.unsqueeze(dim=-1), x_start.shape
        ).to(dist_util.dev())
        # Keep embeddings where the mask is 0; start from noise where it is non-zero.
        x_noised = th.where(input_ids_mask == 0, x_start, noise)

        sample_shape = (x_start.shape[0], args.seq_len, args.hidden_dim)
        samples = sample_fn(
            model,
            sample_shape,
            noise=x_noised,
            clip_denoised=args.clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs={},
            top_p=args.top_p,
            clamp_step=args.clamp_step,
            clamp_first=True,
            mask=input_ids_mask,
            x_start=x_start,
            gap=step_gap
        )
        # Only the last element of the returned samples is decoded.
        logits = model.get_logits(samples[-1])
        cands = th.topk(logits, k=1, dim=-1)

        lines = _decode_batch(tokenizer, args.seq_len, cands.indices, input_ids_x, input_ids_mask_ori)
        _append_in_rank_order(out_path, lines, world_size, rank)

    print('### Total takes {:.2f}s .....'.format(time.time() - start_t))
    print(f'### Written the decoded output to {out_path}')
if __name__ == "__main__":
main()