-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
75 lines (58 loc) · 1.87 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from dataclasses import dataclass
import io
import os
import uuid
import numpy as np
from typing import Any, List
import tensorflow as tf
from tqdm.std import trange
import math
def tile_images(images):
    """Arrange a batch of images into a single square grid image.

    The grid side ``n`` is the largest integer with ``n * n`` not exceeding
    the batch size. Images beyond the first ``n * n`` are dropped so that the
    batch always fits the grid. Images fill the grid row by row: image ``i``
    lands in grid row ``i // n`` and grid column ``i % n``.

    Args:
        images: Batch of images indexed as (batch, height, width, channels).
            The grid side is converted to a Python ``int``, so the input
            must be a concrete (eager) value.

    Returns:
        A tensor of shape ``(n * height, n * width, channels)``.
    """
    n_images = tf.cast(tf.shape(images)[0], tf.float32)
    # Convert to side of square
    n = int(tf.math.floor(tf.math.sqrt(n_images)))
    # A batch that is not a perfect square cannot be reshaped to
    # (n, n, ...); keep only the images that fit the grid.
    images = images[: n * n]
    _, height, width, channels = tf.shape(images)
    images = tf.reshape(images, [n, n, height, width, channels])
    # (grid_row, grid_col, h, w, c) -> (grid_row, h, grid_col, w, c).
    # Each grid axis must sit directly in front of the pixel axis it is merged
    # with; otherwise the final reshape interleaves pixels from different
    # images instead of laying the images side by side.
    images = tf.transpose(images, perm=[0, 2, 1, 3, 4])
    return tf.reshape(images, [n * height, n * width, channels])
def sample_to_dir(model, batch_size, sample_size, temperature, output_dir, binary=False):
    """Draw samples from ``model`` batch by batch and save them as images.

    ``sample_size // batch_size`` batches are drawn, but never fewer than
    one, and each batch requests ``batch_size`` samples. The number of
    images written is therefore a multiple of ``batch_size``.

    Args:
        model: Object exposing ``sample(n_samples=..., greyscale=...,
            temperature=...)``; the first element of its result is used as
            the image batch.
        batch_size: Number of samples requested per call to ``model.sample``.
        sample_size: Target number of samples; determines the batch count.
        temperature: Forwarded unchanged to ``model.sample``.
        output_dir: Directory passed to ``save_images_to_dir``.
        binary: When true, ``greyscale=False`` is passed to ``model.sample``.
    """
    n_batches = sample_size // batch_size
    if n_batches < 1:
        # Always draw at least one batch.
        n_batches = 1

    for _ in trange(n_batches, desc="Generating samples"):
        sampled = model.sample(
            n_samples=batch_size,
            greyscale=not binary,
            temperature=temperature,
        )
        # Only the first returned element holds the images.
        images, *_extras = sampled
        save_images_to_dir(images, output_dir)
def save_images_to_dir(images, dir):
    """Write every image in ``images`` to ``dir`` as a uniquely named PNG.

    Floating-point batches are multiplied by 255 and cast to ``uint8``
    before encoding; other dtypes are encoded as given. Each file is named
    with a fresh random UUID and the ``.png`` extension.

    Args:
        images: Iterable batch of images; must expose ``dtype.is_floating``.
        dir: Destination directory for the PNG files.
    """
    pixels = tf.cast(images * 255, tf.uint8) if images.dtype.is_floating else images
    for single_image in pixels:
        png_bytes = tf.io.encode_png(single_image)
        destination = os.path.join(dir, f"{uuid.uuid4()}.png")
        tf.io.write_file(destination, png_bytes)
def calculate_log_p(z, mu, sigma):
    """Return the element-wise log-density of ``z`` under Normal(mu, sigma).

    Computes ``-0.5 * ((z - mu) / sigma) ** 2 - 0.5 * log(2 * pi) - log(sigma)``.

    Args:
        z: Values at which the density is evaluated.
        mu: Mean of the normal distribution.
        sigma: Standard deviation; used as a divisor and passed to ``log``.

    Returns:
        The log-density, with the shape produced by broadcasting the inputs.
    """
    normalized_z = (z - mu) / sigma
    # Keep the constant as a Python float so it adopts the dtype of the
    # tensors it is combined with. A float32 ``tf.constant`` here would raise
    # a dtype mismatch for float64 or float16 inputs.
    half_log_two_pi = 0.5 * math.log(2.0 * math.pi)
    log_p = (
        -0.5 * normalized_z * normalized_z
        - half_log_two_pi
        - tf.math.log(sigma)
    )
    return log_p
def softclamp5(x):
    """Smoothly squash ``x`` towards the range [-5, 5] with a scaled tanh."""
    bound = 5.0
    squashed = tf.math.tanh(x / bound)
    return bound * squashed  # differentiable clamp [-5, 5]
@dataclass
class Metric:
    """Mean and standard deviation summarising a collection of values."""

    mean: float
    stddev: float

    @classmethod
    def from_list(cls, l):
        """Build a metric from the values in ``l``.

        Args:
            l: Values accepted by ``numpy.mean`` and ``numpy.std``.

        Returns:
            An instance of the class this is called on. ``mean`` and
            ``stddev`` are the NumPy mean and (population) standard
            deviation of ``l``, converted to built-in ``float`` so the
            fields match their declared types whatever the input dtype.
        """
        return cls(mean=float(np.mean(l)), stddev=float(np.std(l)))
@dataclass
class Metrics:
    """Group of quality metrics recorded together with a temperature value."""

    # Temperature these metrics belong to.
    # NOTE(review): presumably the sampling temperature — confirm against caller.
    temperature: float
    # NOTE(review): presumably the Fréchet Inception Distance — confirm.
    fid: float
    # NOTE(review): presumably perceptual path length, as mean/stddev — confirm.
    ppl: Metric
    # Precision, summarised as a mean/stddev Metric.
    precision: Metric
    # Recall, summarised as a mean/stddev Metric.
    recall: Metric
@dataclass
class ModelEvaluation:
    """Container pairing an NLL summary with a list of sample metrics."""

    # NOTE(review): presumably negative log-likelihood, as mean/stddev — confirm.
    nll: Metric
    # List of Metrics entries; each entry carries its own temperature value.
    sample_metrics: List[Metrics]