generate.py
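"""Generate samples from a trained UViT diffusion model.

Loads the Hydra config and checkpoint from the experiment directory, runs the
reverse diffusion process from pure noise, and saves the result as an image grid
(generation.png) together with a gif of the denoising steps (generation.gif).
"""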
from pathlib import Path

import equinox as eqx
from einops import rearrange
from PIL import Image
import numpy as np
import jax.numpy as jnp
import jax
import jax.random as random

from configs import (
    AnimeDatasetConfig,
    DiffusionConfig,
    MainConfig,
    ModelConfig,
    TrainerConfig,
)
from hydra import compose, initialize_config_dir
from jaxtyping import Array, Float, UInt

from src.model import UViT
from src.diffusion import decode_all, scheduler, CLIP

def load_model(config: Path, checkpoint: Path) -> tuple[UViT, MainConfig]:
    """Rebuild the model from its saved Hydra config and load the trained weights."""
    # Load the initial configurations.
    with initialize_config_dir(str(config.parent), version_base="1.1"):
        config = compose(str(config.name))

    config = MainConfig(
        dataset=AnimeDatasetConfig(**config.dataset),
        diffusion=DiffusionConfig(**config.diffusion),
        model=ModelConfig(**config.model),
        trainer=TrainerConfig(**config.trainer),
        mode=config.mode,
    )
    schedule = scheduler(config.diffusion.steps)

    # Instantiate the corresponding model with the right hyperparameters.
    model = UViT(
        num_channels=config.dataset.n_channels,
        num_positions=(config.dataset.image_size // config.model.patch_size) ** 2,
        num_timesteps=len(schedule),
        patch_size=config.model.patch_size,
        d_model=config.model.d_model,
        num_heads=config.model.num_heads,
        num_layers=config.model.num_layers,
        key=random.key(config.model.seed),
    )

    # Load the trained weights.
    model = eqx.tree_deserialise_leaves(checkpoint, model)
    return model, config

def to_gif(images: UInt[Array, "steps channels height width"], path: Path):
    """Save a gif animation showing the transition from the steps."""
    assert path.suffix == ".gif"

    images = jnp.transpose(images, (0, 2, 3, 1))
    images = np.array(images)
    images = [Image.fromarray(im) for im in images]
    images = [
        im.resize((244, 244)) for im in images
    ]  # Resize so that final size is < 100MB.
    images = [images[i] for i in range(0, len(images), 10)]
    images[0].save(path, save_all=True, append_images=images[1:], duration=10)

def make_grid(
    images: UInt[Array, "batch channels height width"],
) -> UInt[Array, "grid_height grid_width channels"]:
    """Arrange a batch of images into a single (roughly square) image grid."""
    nrows = int(jnp.sqrt(len(images)))
    ncols = len(images) // nrows
    grid = rearrange(images, "(nr nc) c h w -> nr nc h w c", nr=nrows, nc=ncols)
    grid = [
        jnp.concat(col, axis=0)  # Shape of [(h nc), w, c].
        for col in grid
    ]
    grid = jnp.concat(grid, axis=1)  # Shape of [(h nc), (w nr), c].
    return grid

def generate(
    model: UViT,
    image_shape: tuple[int, int, int],
    n_images: int,
    schedule: Float[Array, " steps"],
    image_path: Path,
    gif_path: Path,
    key: Array,
):
    """Sample `n_images` from pure noise and save the final grid and the denoising gif."""
    model = eqx.nn.inference_mode(model)
    decode_fn = jax.vmap(decode_all, in_axes=(0, None, None, 0))
    make_grid_fn = jax.vmap(make_grid)

    # Start from clipped gaussian noise.
    sk_decode, sk_normal = random.split(key)
    xT = random.normal(sk_normal, (n_images, *image_shape))
    xT = jnp.clip(xT, -CLIP, CLIP)

    # Decode every sample and keep all intermediate steps.
    x_dec = decode_fn(xT, model, schedule, random.split(sk_decode, n_images))
    x_dec = rearrange(x_dec, "b s c h w -> s b c h w")
    grid = make_grid_fn(x_dec)

    # Save the last decoding step as the final image and the whole trajectory as a gif.
    Image.fromarray(np.array(grid[-1])).save(image_path)
    to_gif(rearrange(grid, "s h w c -> s c h w"), gif_path)

if __name__ == "__main__":
    exp_dir = Path("./final-run/").absolute()
    model, config = load_model(
        exp_dir / ".hydra/config.yaml", exp_dir / "checkpoint.eqx"
    )

    image_shape = (
        config.dataset.n_channels,
        config.dataset.image_size,
        config.dataset.image_size,
    )
    schedule = scheduler(config.diffusion.steps)
    n_images = 9

    generate(
        model,
        image_shape,
        n_images,
        schedule,
        Path("generation.png"),
        Path("generation.gif"),
        random.key(42),
    )