forked from google-research/vet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
parameterized_sample_lib.py
481 lines (396 loc) · 14.8 KB
/
parameterized_sample_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
"""Copyright 2022 Google LLC.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Code support for parameterized stochastic models for probabilistic response.
This library supports the simulation of machine learning experiments in order
to improve experimental design. The tools presented here can help answer
questions such as:
+ How many human annotators are needed in order to create
reliable gold standard test data?
+ How many samples per prompt are needed in order to reliably distinguish
the performance of one generative model from another? How many prompts?
+ What is the best way to measure "reliability"?
This library provides stochastic parameterized **response grids** for
+ Two stochastic response models, such as two machine-learning-based models,
say machine 1 and machine 2.
+ One "true" stochastic model, say a crowd of human annotators.
A typical use case involves evaluating the machine responses against the
human responses over a common set of items, so in simulations we need to sample
from all three models at the same time. This library does exactly that.
"""
import datetime
import enum
import functools
import json
import pickle
import random as rand
from typing import Any, Callable, List, Tuple
from absl import logging
import numpy as np
import datatypes
def toxicity_mean_dist() -> float:
    """Draw the mean of one item's human toxicity-rating distribution.

    For use with the toxicity dataset.

    Returns:
        The absolute value of a draw from a normal distribution with mean 0
        and standard deviation 0.28, clipped to the interval [0, 0.8].
    """
    # A fresh, unseeded generator is created on every call.
    rng = np.random.default_rng()
    folded_draw = abs(rng.normal(0, 0.28))
    return clamp(folded_draw, max_value=0.8)
def toxicity_stdev_dist() -> float:
    """Draw the standard deviation of one item's human toxicity ratings.

    For use with the toxicity dataset.

    Returns:
        A draw from a triangular distribution with left edge -0.06, mode 0.21
        and right edge 0.45, clipped to the interval [0, 1] so that negative
        draws become 0.
    """
    # A fresh, unseeded generator is created on every call.
    rng = np.random.default_rng()
    triangular_draw = rng.triangular(-0.06, 0.21, 0.45)
    return clamp(triangular_draw)
def clamp(num: float, min_value: float = 0.0, max_value: float = 1.0) -> float:
    """Clip a number so that it lies between min_value and max_value.

    Args:
        num: The number to clamp.
        min_value: The lower bound; smaller inputs are replaced by it.
        max_value: The upper bound; larger inputs are replaced by it.

    Returns:
        The clipped number, as produced by numpy's clip.
    """
    clipped = np.clip(num, min_value, max_value)
    return clipped
def distort_shape(s_param: float, diff: float) -> float:
    """Shift a scalar parameter by a random amount and clip the result.

    Args:
        s_param: A scalar parameter.
        diff: The maximum magnitude of the shift; the shift is drawn uniformly
            from [-diff, diff].

    Returns:
        s_param plus the random shift, clamped with clamp's default bounds
        (the interval [0, 1]).
    """
    offset = rand.uniform(-diff, diff)
    return clamp(s_param + offset)
def sample_from(distr: Callable[[], float], num: int) -> List[float]:
    """Draw num independent samples by calling distr repeatedly.

    Args:
        distr: A zero-argument sampling function; every call yields one draw.
        num: How many draws to collect.

    Returns:
        A list holding the num values returned by distr, in call order.
    """
    draw_indices = range(num)
    return [distr() for _draw in draw_indices]
def norm_distr_factory(
    mean: float,
    stdev: float,
    h_dist: Callable[[float, float], float] = rand.normalvariate,
) -> Callable[[], float]:
    """Bind mean and stdev to a two-parameter sampler.

    Helper function for gen_alt_h_distrs_norm.

    Args:
        mean: The mean parameter handed to h_dist.
        stdev: The standard deviation parameter handed to h_dist.
        h_dist: A sampler taking (mean, stdev) and returning one draw.
            Defaults to the standard library's normalvariate.

    Returns:
        A zero-argument function that returns h_dist(mean, stdev) each time it
        is called.
    """

    def draw() -> float:
        return h_dist(mean, stdev)

    return draw
def gen_alt_h_distrs_norm(
    mean_distr: Callable[[], float],
    stdev_distr: Callable[[], float],
    n: int,
    alt_distortion: float = 0.1,
    h_dist: Callable[[float, float], float] = rand.normalvariate,
) -> Tuple[
    List[Callable[[], float]],
    List[Callable[[], float]],
    List[Callable[[], float]],
]:
    """Build per-item response distributions for humans and two machines.

    Each item gets a mean drawn from mean_distr and a standard deviation drawn
    from stdev_distr. Humans and machine 1 use those parameters unchanged;
    machine 2 uses the same standard deviations but means that have been
    randomly shifted by distort_shape.

    Args:
        mean_distr: A generating function for the mean parameter of each
            item's response distribution.
        stdev_distr: A generating function for the standard deviation
            parameter of each item's response distribution.
        n: Number of items.
        alt_distortion: The maximum shift applied to each mean to obtain
            machine 2's mean.
        h_dist: The (mean, stdev) sampler used for every human or machine
            responder.

    Returns:
        A triple of lists (human, machine 1, machine 2), each containing n
        zero-argument sampling functions, one per item.
    """
    item_means = sample_from(mean_distr, n)
    item_stdevs = sample_from(stdev_distr, n)
    shifted_means = [distort_shape(m, alt_distortion) for m in item_means]

    def build_distrs(means: List[float]) -> List[Callable[[], float]]:
        # All three responder groups share the per-item standard deviations.
        return [
            norm_distr_factory(item_mean, item_stdev, h_dist)
            for item_mean, item_stdev in zip(means, item_stdevs)
        ]

    return (
        build_distrs(item_means),
        build_distrs(item_means),
        build_distrs(shifted_means),
    )
def sample_h(
    hum_h_distrs: List[Callable[[], float]],
    mach1_h_distrs: List[Callable[[], float]],
    mach2_h_distrs: List[Callable[[], float]],
    resps_per_item: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Sample responses in the "horizontal" direction for every item.

    Args:
        hum_h_distrs: List of human response distributions, one per item.
        mach1_h_distrs: List of machine 1 response distributions.
        mach2_h_distrs: List of machine 2 response distributions.
        resps_per_item: The number of responses drawn for each item.

    Returns:
        Three numpy arrays (humans, machine 1, machine 2). Each array is built
        from one list of resps_per_item draws per distribution.
    """

    def draw_table(item_distrs: List[Callable[[], float]]) -> np.ndarray:
        rows = [sample_from(item_distr, resps_per_item) for item_distr in item_distrs]
        return np.array(rows)

    # Sampling order (humans, machine 1, machine 2) is kept as-is so that the
    # sequence of random draws is unchanged.
    gold = draw_table(hum_h_distrs)
    preds1 = draw_table(mach1_h_distrs)
    preds2 = draw_table(mach2_h_distrs)
    return gold, preds1, preds2
def null_hypothesis_generator(
    distr1: Callable[[], float], distr2: Callable[[], float]
) -> Callable[[], float]:
    """Create a null hypothesis generator from two distributions.

    Args:
        distr1: One distribution, representing the response distribution of
            one machine.
        distr2: Another distribution, representing the response distribution
            of another machine.

    Returns:
        A new zero-argument sampler. On every call it picks one of the two
        given distributions uniformly at random and returns a draw from it.
    """
    candidates = [distr1, distr2]

    def null_dist():
        chosen = rand.choice(candidates)
        return chosen()

    return null_dist
def alt_distr_gen(n: int, distortion: float = 0.3) -> Tuple[
    List[Callable[[], float]],
    List[Callable[[], float]],
    List[Callable[[], float]],
]:
    """Generate n alternative-hypothesis distribution triples.

    Item means are drawn uniformly from [0, 1] and item standard deviations
    uniformly from [0, 0.3]; the rest is delegated to gen_alt_h_distrs_norm
    with its default responder distribution.

    Args:
        n: The number of triples to generate. It is the number of items in a
            simulated response set.
        distortion: The distortion value passed on as alt_distortion.

    Returns:
        A 3-tuple of lists of distribution functions, exactly as returned by
        gen_alt_h_distrs_norm. Each list is of length n.
    """
    mean_sampler = uniform_dist_factory(0, 1)
    stdev_sampler = uniform_dist_factory(0, 0.3)
    return gen_alt_h_distrs_norm(
        mean_sampler,
        stdev_sampler,
        n,
        alt_distortion=distortion,
    )
def likert_norm_dist(mean: float, std: float, rate: int = 5) -> float:
    """Sample from a normal distribution discretized to a likert-like scale.

    Args:
        mean: The mean of the generating normal distribution.
        std: The standard deviation of the generating normal distribution.
        rate: The number of levels in the likert-like domain.

    Returns:
        A normal draw clipped to [0, 1] and then truncated down to a multiple
        of 1/rate, i.e. the lower edge of the interval the draw falls in. A
        draw that lands on exactly 1 is assigned to the top level,
        (rate - 1) / rate. E.g., for the default rate=5 the possible return
        values are {0, 0.2, 0.4, 0.6, 0.8}.
    """
    # A fresh, unseeded generator is created on every call.
    rng = np.random.default_rng()
    clipped_draw = clamp(rng.normal(mean, std))
    level_floor = int(clipped_draw * rate) / rate
    if level_floor < 1:
        return level_floor
    # Only reachable when the clipped draw is exactly 1.
    return (rate - 1) / rate
def toxicity_distr_gen(n: int, distortion: float) -> Tuple[
    List[Callable[[], float]],
    List[Callable[[], float]],
    List[Callable[[], float]],
]:
    """Generate n alternative distribution triples for toxicity-like data.

    This specific generator is based on the toxicity dataset from:
    https://data.esrg.stanford.edu/study/toxicity-perspectives

    Item means come from toxicity_mean_dist, item standard deviations from
    toxicity_stdev_dist, and every responder samples via likert_norm_dist.

    Args:
        n: The number of triples to generate. It is the number of items in a
            simulated response set.
        distortion: The distortion value passed on as alt_distortion.

    Returns:
        A 3-tuple of lists of distribution functions, exactly as returned by
        gen_alt_h_distrs_norm. Each list is of length n.
    """
    generator_kwargs = {
        "alt_distortion": distortion,
        "h_dist": likert_norm_dist,
    }
    return gen_alt_h_distrs_norm(
        toxicity_mean_dist, toxicity_stdev_dist, n, **generator_kwargs
    )
def generate_response_tables(
    n_items: int = 1000,
    k_responses: int = 5,
    distortion: float = 0.3,
    num_samples: int = 1000,
    alt_distr_generator: Callable[
        [int, float],
        Tuple[
            List[Callable[[], float]],
            List[Callable[[], float]],
            List[Callable[[], float]],
        ],
    ] = alt_distr_gen,
) -> datatypes.ResponseSets:
    """Generate a collection of simulated human and machine responses.

    For every sample, one table set is drawn under the alternative hypothesis
    (machine 1 and machine 2 use their own distributions) and one under the
    null hypothesis (both machines sample from a per-item generator that picks
    uniformly between the machine 1 and machine 2 distributions).

    Args:
        n_items: Number of items per set.
        k_responses: Number of responses drawn per item.
        distortion: Distortion value handed to alt_distr_generator.
        num_samples: Number of table sets to generate for each hypothesis.
        alt_distr_generator: Function that, given (n_items, distortion),
            returns the human, machine 1 and machine 2 per-item distributions.

    Returns:
        A datatypes.ResponseSets whose alt_data_list and null_data_list each
        hold num_samples datatypes.ResponseData entries with gold, preds1 and
        preds2 tables.
    """
    alt_tables = []
    null_tables = []
    for _ in range(num_samples):
        human_distrs, machine1_distrs, machine2_distrs = alt_distr_generator(
            n_items, distortion
        )

        # Alternative hypothesis: each machine samples from its own
        # distributions.
        alt_gold, alt_preds1, alt_preds2 = sample_h(
            human_distrs,
            machine1_distrs,
            machine2_distrs,
            resps_per_item=k_responses,
        )

        # Null hypothesis: both machines share one pooled per-item generator.
        pooled_distrs = [
            null_hypothesis_generator(distr1, distr2)
            for distr1, distr2 in zip(machine1_distrs, machine2_distrs)
        ]
        null_gold, null_preds1, null_preds2 = sample_h(
            human_distrs,
            pooled_distrs,
            pooled_distrs,
            resps_per_item=k_responses,
        )

        alt_entry = datatypes.ResponseData(
            gold=alt_gold, preds1=alt_preds1, preds2=alt_preds2
        )
        null_entry = datatypes.ResponseData(
            gold=null_gold, preds1=null_preds1, preds2=null_preds2
        )
        alt_tables.append(alt_entry)
        null_tables.append(null_entry)

    return datatypes.ResponseSets(
        alt_data_list=alt_tables, null_data_list=null_tables
    )
def uniform_dist_factory(minimum: float, maximum: float) -> Callable[[], float]:
    """Bind the bounds of a uniform distribution.

    Helper function for passing to gen_alt_h_distrs_norm.

    Args:
        minimum: The min parameter for the distribution.
        maximum: The max parameter for the distribution.

    Returns:
        A zero-argument function that returns a fresh draw from
        uniform[minimum, maximum] whenever it is called.
    """

    def draw() -> float:
        return rand.uniform(minimum, maximum)

    return draw
def _strip_leading_zero(value: float) -> str:
    """Format value compactly by dropping only the integer-part zero.

    E.g. 0.25 -> ".25" and -0.5 -> "-.5", while 10.5 stays "10.5".
    """
    text = f"{value}"
    if text.startswith("0."):
        return text[1:]
    if text.startswith("-0."):
        return "-" + text[2:]
    return text


def norm_generator(
    min_mean: float,
    max_mean: float,
    min_std: float,
    max_std: float,
    dist: float,
) -> Callable[[Any], Any]:
    """Helper function for generating triples of related norm distributions.

    (I.e., related to human and machs 1 and 2's responses to the same data
    item.) The mean and std_dev parameters are drawn from uniform intervals.
    The human and mach 1 share the same parameters. An amount of distortion
    can be added to mach 2's means; it is also drawn from a uniform interval.

    Args:
        min_mean: The minimum value the mean may take.
        max_mean: The maximum value the mean may take.
        min_std: The minimum value the std_dev may take.
        max_std: The maximum value the std_dev may take.
        dist: The distortion interval, passed on as alt_distortion.

    Returns:
        A function taking the number of items and returning the result of
        gen_alt_h_distrs_norm for these parameters. Its __name__ encodes the
        parameters as
        "gen_alt_h_distrs_norm(min_mean,max_mean,min_std,max_std,dist)".
    """

    def fn(x):
        return gen_alt_h_distrs_norm(
            (lambda: rand.uniform(min_mean, max_mean)),
            (lambda: rand.uniform(min_std, max_std)),
            x,
            alt_distortion=dist,
        )

    # Ugly, but we must save space: the leading zero is dropped from three of
    # the parameters. Only a *leading* "0." is stripped; a blanket
    # str.replace("0.", ".") would also turn 10.5 into "1.5" and so mislabel
    # (or collide) generators. max_mean and min_std are deliberately left
    # unabbreviated so that previously generated names stay the same.
    min_mean_str = _strip_leading_zero(min_mean)
    max_std_str = _strip_leading_zero(max_std)
    dist_str = _strip_leading_zero(dist)
    fn.__name__ = f"gen_alt_h_distrs_norm({min_mean_str},{max_mean},{min_std},{max_std_str},{dist_str})"
    return fn
def read_samples_from_file(
    input_filename: str,
    use_pickle: bool,
) -> datatypes.ResponseSets:
    """Load previously saved sample data sets from a file.

    Args:
        input_filename: The input filename.
        use_pickle: If true, the file is opened in binary mode and
            deserialized with pickle. Otherwise it is opened in text mode and
            parsed as json.

    Returns:
        The datasets, rebuilt from the loaded dictionary via
        datatypes.ResponseSets.from_dict.
    """
    if use_pickle:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(input_filename, "rb") as f:
            response_sets_dict = pickle.load(f)
    else:
        with open(input_filename, "r") as f:
            response_sets_dict = json.load(f)
    return datatypes.ResponseSets.from_dict(response_sets_dict)
def write_samples_to_file(
    response_sets: datatypes.ResponseSets,
    output_filename: str,
    use_pickle: bool,
) -> None:
    """Serialize the sample data to a file and log how long it took.

    Args:
        response_sets: The sample datasets to output; their to_dict() form is
            what gets written.
        output_filename: The output filename.
        use_pickle: If true use pickle (binary mode) to serialize the data.
            Otherwise use json (text mode). Pickle serialization is in binary
            format so it is more efficient.
    """
    started_at = datetime.datetime.now()
    if use_pickle:
        open_mode, serialize = "wb", pickle.dump
    else:
        open_mode, serialize = "w", json.dump
    with open(output_filename, open_mode) as f:
        serialize(response_sets.to_dict(), f)
    # The elapsed time is measured after the file has been closed.
    elapsed_time = datetime.datetime.now() - started_at
    logging.info("File writing time=%f", elapsed_time.total_seconds())
@enum.unique
class GeneratorType(enum.Enum):
    """Types of generator functions.

    Each member wraps one distribution-generator function, and calling the
    member calls the wrapped function, e.g. GeneratorType.ALT_DISTR_GEN(n,
    distortion).
    """

    def __call__(self, *args, **kwargs):
        """Forward positional and keyword arguments to the wrapped generator."""
        return self.value(*args, **kwargs)

    # The functions are wrapped in functools.partial so that they are stored as
    # enum member values; a plain function assigned in an Enum body is treated
    # as a method instead of a member.
    # NOTE(review): newer Python releases warn that functools.partial will
    # become a method descriptor; enum.member() (Python 3.11+) is the
    # forward-compatible wrapper -- confirm against the target Python version.
    ALT_DISTR_GEN: Callable[
        ...,
        Tuple[
            List[Callable[[], float]],
            List[Callable[[], float]],
            List[Callable[[], float]],
        ],
    ] = functools.partial(alt_distr_gen)
    TOXICITY_DISTR_GEN: Callable[
        ...,
        Tuple[
            List[Callable[[], float]],
            List[Callable[[], float]],
            List[Callable[[], float]],
        ],
    ] = functools.partial(toxicity_distr_gen)