Skip to content

Commit

Permalink
add ability to train only super-resoluting unets, for @hmza09
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 16, 2022
1 parent a936ea7 commit ae5a1e1
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 9 deletions.
56 changes: 56 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,62 @@ trainer.update(unet_number = 1)
images = trainer.sample(batch_size = 16) # (16, 3, 128, 128)
```

Or train only super-resoluting unets

```python
import torch
from imagen_pytorch import Unet, NullUnet, Imagen

# unet for imagen

unet1 = NullUnet() # add a placeholder "null" unet for the base unet

unet2 = Unet(
dim = 32,
cond_dim = 512,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = (2, 4, 8, 8),
layer_attns = (False, False, False, True),
layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
unets = (unet1, unet2),
image_sizes = (64, 256),
timesteps = 250,
cond_drop_prob = 0.1
).cuda()

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

loss = imagen(images, text_embeds = text_embeds, unet_number = 2)
loss.backward()

# do the above for many many many many steps
# now you can sample an image based on the text embeddings as well as low resolution images

lowres_images = torch.randn(3, 3, 64, 64).cuda() # starting un-resoluted images

images = imagen.sample(
texts = [
'a whale breaching from afar',
'young girl blowing out candles on her birthday cake',
'fireworks with blue and green sparkles'
],
start_at_unet_number = 2, # start at unet number 2
start_image_or_video = lowres_images, # pass in low resolution images to be resoluted
cond_scale = 3.)

images.shape # (3, 3, 256, 256)
```

At any time you can save and load the trainer and all associated states with the `save` and `load` methods. It is recommended you use these methods instead of manually saving with a `state_dict` call, as there are some device memory management being done underneath the hood within the trainer.

ex.
Expand Down
1 change: 1 addition & 0 deletions imagen_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from imagen_pytorch.imagen_pytorch import Imagen, Unet
from imagen_pytorch.imagen_pytorch import NullUnet
from imagen_pytorch.imagen_pytorch import BaseUnet64, SRUnet256, SRUnet1024
from imagen_pytorch.trainer import ImagenTrainer
from imagen_pytorch.version import __version__
Expand Down
23 changes: 21 additions & 2 deletions imagen_pytorch/elucidated_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from imagen_pytorch.imagen_pytorch import (
GaussianDiffusionContinuousTimes,
Unet,
NullUnet,
first,
exists,
identity,
Expand Down Expand Up @@ -138,7 +139,7 @@ def __init__(
self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment

for ind, one_unet in enumerate(unets):
assert isinstance(one_unet, (Unet, Unet3D))
assert isinstance(one_unet, (Unet, Unet3D, NullUnet))
is_first = ind == 0

one_unet = one_unet.cast_model_parameters(
Expand Down Expand Up @@ -507,6 +508,8 @@ def sample(
batch_size = 1,
cond_scale = 1.,
lowres_sample_noise_level = None,
start_at_unet_number = 1,
start_image_or_video = None,
stop_at_unet_number = None,
return_all_unet_outputs = False,
return_pil_images = False,
Expand Down Expand Up @@ -562,9 +565,23 @@ def sample(
sigma_min = cast_tuple(sigma_min, num_unets)
sigma_max = cast_tuple(sigma_max, num_unets)

# handle starting at a unet greater than 1, for training only-upscaler training

if start_at_unet_number > 1:
assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets'
assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'

prev_image_size = self.image_sizes[start_at_unet_number - 2]
img = self.resize_to(start_image_or_video, prev_image_size)

# go through each unet in cascade

for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm):
if unet_number < start_at_unet_number:
continue

assert not isinstance(unet, NullUnet), 'cannot sample from null unet'

context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext()

Expand Down Expand Up @@ -637,7 +654,7 @@ def noise_distribution(self, P_mean, P_std, batch_size):
def forward(
self,
images,
unet: Union[Unet, Unet3D] = None,
unet: Union[Unet, Unet3D, NullUnet] = None,
texts: List[str] = None,
text_embeds = None,
text_masks = None,
Expand All @@ -656,6 +673,8 @@ def forward(

unet = default(unet, lambda: self.get_unet(unet_number))

assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'

target_image_size = self.image_sizes[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
Expand Down
37 changes: 35 additions & 2 deletions imagen_pytorch/imagen_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,6 +1641,20 @@ def forward(

return self.final_conv(x)

# null unet

class NullUnet(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.lowres_cond = False
self.dummy_parameter = nn.Parameter(torch.tensor([0.]))

def cast_model_parameters(self, *args, **kwargs):
return self

def forward(self, x, *args, **kwargs):
return x

# predefined unets, with configs lining up with hyperparameters in appendix of paper

class BaseUnet64(Unet):
Expand Down Expand Up @@ -1791,7 +1805,7 @@ def __init__(
self.only_train_unet_number = only_train_unet_number

for ind, one_unet in enumerate(unets):
assert isinstance(one_unet, (Unet, Unet3D))
assert isinstance(one_unet, (Unet, Unet3D, NullUnet))
is_first = ind == 0

one_unet = one_unet.cast_model_parameters(
Expand Down Expand Up @@ -2100,6 +2114,8 @@ def sample(
batch_size = 1,
cond_scale = 1.,
lowres_sample_noise_level = None,
start_at_unet_number = 1,
start_image_or_video = None,
stop_at_unet_number = None,
return_all_unet_outputs = False,
return_pil_images = False,
Expand Down Expand Up @@ -2155,10 +2171,25 @@ def sample(

skip_steps = cast_tuple(skip_steps, num_unets)

# handle starting at a unet greater than 1, for training only-upscaler training

if start_at_unet_number > 1:
assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets'
assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'

prev_image_size = self.image_sizes[start_at_unet_number - 2]
img = self.resize_to(start_image_or_video, prev_image_size)

# go through each unet in cascade

for unet_number, unet, channel, image_size, noise_scheduler, pred_objective, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.noise_schedulers, self.pred_objectives, self.dynamic_thresholding, cond_scale, init_images, skip_steps), disable = not use_tqdm):

if unet_number < start_at_unet_number:
continue

assert not isinstance(unet, NullUnet), 'one cannot sample from null / placeholder unets'

context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext()

with context:
Expand Down Expand Up @@ -2313,7 +2344,7 @@ def p_losses(
def forward(
self,
images,
unet: Unet = None,
unet: Union[Unet, Unet3D, NullUnet] = None,
texts: List[str] = None,
text_embeds = None,
text_masks = None,
Expand All @@ -2332,6 +2363,8 @@ def forward(

unet = default(unet, lambda: self.get_unet(unet_number))

assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'

noise_scheduler = self.noise_schedulers[unet_index]
p2_loss_weight_gamma = self.p2_loss_weight_gamma[unet_index]
pred_objective = self.pred_objectives[unet_index]
Expand Down
12 changes: 8 additions & 4 deletions imagen_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pytorch_warmup as warmup

from imagen_pytorch.imagen_pytorch import Imagen
from imagen_pytorch.imagen_pytorch import Imagen, NullUnet
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
from imagen_pytorch.data import cycle

Expand Down Expand Up @@ -486,12 +486,16 @@ def num_steps_taken(self, unet_number = None):
return self.steps[unet_number - 1].item()

def print_untrained_unets(self):
for ind, steps in enumerate(self.steps.tolist()):
if steps > 0:
print_final_error = False

for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
if steps > 0 or isinstance(unet, NullUnet):
continue

self.print(f'unet {ind + 1} has not been trained')
print_final_error = True

if torch.any(self.steps == 0):
if print_final_error:
self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')

# data related functions
Expand Down
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.7.2'
__version__ = '1.8.0'

0 comments on commit ae5a1e1

Please sign in to comment.