-
Notifications
You must be signed in to change notification settings - Fork 3
/
discrete_diffusion_base.py
224 lines (192 loc) · 9.52 KB
/
discrete_diffusion_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import math
import numpy as np
import torch
import torch.nn as nn
from discrete_diffusions.utils import (
topk_masking
)
class DiscreteDiffusion(nn.Module):
"""
The parent class for discrete denoising diffusion probabilistic models.
It should support the following methods:
- q_sample()
Sample x_t ~ q(x_t | x_0) to construct noisy Transformer inputs.
- compute_losses()
Compute the loss L_t = KL(q||p) at t-th time step.
- sample_step()
Sample x_t ~ p(x_{t-1} | x_t, x_0) at t-th time step.
"""
def __init__(self, num_timesteps):
super().__init__()
self.num_timesteps = num_timesteps
def q_sample(self, x_0, t, **kwargs):
"""
Sample from q(x_t | x_0), which is used as the model inputs.
Args:
x_0: token ids with shape [B, N]
t: current time step, tensor with shape [B]
Returns:
would possibly return a dict of relevant outputs, including x_t.
"""
raise NotImplementedError
def compute_losses(self, inputs, **kwargs):
"""
Compute the loss objective KL(q||p) to train our generative process.
Args:
inputs: a dict that contains input types specific to different diffusion processes
Returns:
would possibly return a dict of relevant outputs, including loss used for training.
"""
raise NotImplementedError
def sample_step(self, decoder_out, denoising_fn, **kwargs):
"""
Given a time step t, start from x_t and sample x_{t-k} from q(x_{t-k} | x_t).
Args:
decoder_out: a namedtuple that contains decoding info, including
- x_t: token ids with shape [B, N]
- t: scalar timesteps
- max_steps: the maximum number of decoding steps
- ...
denoising_fn: a function that takes in x_t and t and returns model logits
kwargs: other arguments that are used to control decoding.
Returns:
return a new decoder_out namedtuple.
"""
raise NotImplementedError
def _get_batched_decoding_strategy(self, t, max_step):
"""
This function is used to compute the step size and the number of steps
when t and max_step are both tensors with shape [B].
It would be invoked in the case of adaptive decoding, where the sentences
within the batch will be generated with different time steps.
"""
_step_size = torch.div(self.num_timesteps, max_step, rounding_mode="floor")
b = self.num_timesteps - _step_size * max_step
step_size = torch.where(t < b, _step_size + 1, _step_size)
step = torch.where(t < b, self.num_timesteps - t * step_size, self.num_timesteps - b - t * step_size)
return step.long(), step_size.long()
def _get_decoding_strategy(self, t, decoding_strategy, max_step):
"""
This function is used to compute the step size and the number of steps
when max_step is a scalar.
"""
assert hasattr(self, "num_timesteps"), "num_timesteps is not set."
if getattr(self, 'step_array', None) is None:
if decoding_strategy == "spline":
B = math.floor(max_step / 3)
C = (self.num_timesteps - B) / (max_step - B)**2
linear, quad = np.arange(start=0, stop=B + 1), np.arange(start=1, stop=max_step - B + 1)
self.step_array = np.concatenate([linear, np.floor(B + C * quad** 2)])[::-1].astype(int)
else:
# t passed here ranges from 0 to max_step - 1.
if max_step > self.num_timesteps:
raise NotImplementedError(
"will run {} empty steps. "
"Consider increase diffusion time-steps or decrease max_step."
.format(max_step - self.num_timesteps)
)
elif self.num_timesteps % max_step == 0:
step_size = int(self.num_timesteps // max_step)
# terminates at max_step to ease computing the step_size.
self.step_array = np.array([(max_step - t) * step_size for t in range(max_step + 1)]).astype(int)
else:
# why this is the case? we are given total diffusion steps and max_step budget.
# first we compute the approximate step size, which is also given for now.
# then we assume uneven strides: either steps with step-size or (step-size + 1)
# as a result, we can solve this system:
# step-size * a + (step-size + 1) * b = total_diffusion_steps
# a + b = max_step.
# where a and b are number of steps of corresponding strides resp.
step_size = int(self.num_timesteps // max_step)
b = self.num_timesteps - step_size * max_step
self.step_array = np.concatenate(
[
[self.num_timesteps - t * (step_size + 1) for t in range(b)],
[step_size * (max_step - t) for t in range(b, max_step+1)]
]
).astype(int)
return self.step_array[t], (self.step_array[t] - self.step_array[t+1])
def _reparam_decoding(
self,
output_tokens,
output_scores,
cur_tokens,
cur_scores,
decoding_strategy,
xt_neq_x0,
non_special_sym_mask,
t,
max_step,
noise
):
"""
This function is used to perform reparameterized decoding.
"""
# output_tokens: [B, N]
# output_scores: [B, N]
# cur_tokens: [B, N]
# cur_scores: [B, N]
# xt_neq_x0: equivalent to not_b_t [B, N]
# non_special_sym_mask: [B, N]
# noise: either [B, N] or scalar (if using the mask noise)
# decoding_strategy needs to take the form of "reparam-<conditioning>-<topk_mode>-<schedule>"
_, condition, topk_mode, schedule = decoding_strategy.split("-")
# first set the denoising rate according to the schedule
if schedule == "linear":
rate = 1 - (t + 1) / max_step
elif schedule == "cosine":
rate = np.cos((t + 1) / max_step * np.pi * 0.5)
else:
raise NotImplementedError
# compute the cutoff length for denoising top-k positions
cutoff_len = (
non_special_sym_mask.sum(1, keepdim=True).type_as(output_scores) * rate
).long()
# set the scores of special symbols to a large value so that they will never be selected
_scores_for_topk = cur_scores.masked_fill(~non_special_sym_mask, 1000.0)
# the top-k selection can be done in two ways: stochastic by injecting Gumbel noise or deterministic
if topk_mode.startswith("stochastic"):
noise_scale = float(topk_mode.replace("stochastic", ""))
lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=True, temp=noise_scale * rate)
elif topk_mode == "deterministic":
lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=False)
else:
raise NotImplementedError
# Various choices to generate v_t := [v1_t, v2_t].
# Note that
# v1_t governs the outcomes of tokens where b_t = 1,
# v2_t governs the outcomes of tokens where b_t = 0.
# #### the `uncond` mode ####
# In our reparameterized decoding,
# both v1_t and v2_t can be fully determined by the current token scores .
# #### the `cond` mode ####
# However, we can also impose some conditional constraints on v1_t so that
# the decoding can be performed in a more conservative manner.
# For example, we can set v1_t = 0 only when
# (the newly output tokens are the same as previous denoised results, AND
# the current token score becomes lower, AND
# the current token score is not in the top-k share among all tokens).
if condition == "cond":
not_v1_t = (cur_tokens == output_tokens) & (cur_scores < output_scores) & lowest_k_mask
elif condition == "uncond":
not_v1_t = lowest_k_mask
else:
raise NotImplementedError
# for b_t = 0, the token is set to noise if it is in the lowest k scores.
not_v2_t = lowest_k_mask
masked_to_noise = (~xt_neq_x0 & not_v1_t) | (xt_neq_x0 & not_v2_t)
if isinstance(noise, torch.Tensor):
output_tokens.masked_scatter_(masked_to_noise, noise[masked_to_noise])
elif isinstance(noise, (int, float)):
output_tokens.masked_fill_(masked_to_noise, noise)
else:
raise NotImplementedError("noise should be either a tensor or a scalar")
output_scores.masked_fill_(masked_to_noise, -math.inf)
masked_to_x0 = xt_neq_x0 & ~not_v2_t
output_tokens.masked_scatter_(masked_to_x0, cur_tokens[masked_to_x0])
output_scores.masked_scatter_(masked_to_x0, cur_scores[masked_to_x0])
# b_{t} = (b_{t+1} & u_t) | v_t
# For convenience, save the NOT of b_t for the next iteration
# NOT_b_{t} = (NOT_b_{t+1} | not_v1_t) & not_v2_t
new_xt_neq_x0 = (xt_neq_x0 | not_v1_t) & not_v2_t
return new_xt_neq_x0