Skip to content

Commit

Permalink
Use self.get_noise_pred_and_target and drop fixed timesteps
Browse files Browse the repository at this point in the history
  • Loading branch information
rockerBOO committed Jan 6, 2025
1 parent 1c0ae30 commit bbf6bbd
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 86 deletions.
7 changes: 5 additions & 2 deletions flux_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def get_noise_pred_and_target(
network,
weight_dtype,
train_unet,
is_train=True
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
Expand Down Expand Up @@ -375,7 +376,7 @@ def get_noise_pred_and_target(
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:
# normal forward
with accelerator.autocast():
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
img=img,
Expand Down Expand Up @@ -420,7 +421,9 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
with torch.set_grad_enabled(is_train and train_unet):
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""

return model_pred
Expand Down
3 changes: 2 additions & 1 deletion sd3_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def get_noise_pred_and_target(
network,
weight_dtype,
train_unet,
is_train=True
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
Expand Down Expand Up @@ -339,7 +340,7 @@ def get_noise_pred_and_target(
t5_attn_mask = None

# call model
with accelerator.autocast():
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
# TODO support attention mask
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)

Expand Down
116 changes: 33 additions & 83 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def get_noise_pred_and_target(
network,
weight_dtype,
train_unet,
is_train=True
):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
Expand All @@ -236,7 +237,7 @@ def get_noise_pred_and_target(
t.requires_grad_(True)

# Predict the noise residual
with accelerator.autocast():
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
noise_pred = self.call_unet(
args,
accelerator,
Expand Down Expand Up @@ -317,7 +318,7 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch,

# endregion

def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor:
def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor:

with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
Expand Down Expand Up @@ -372,91 +373,40 @@ def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: Au

batch_size = latents.shape[0]

# Sample noise,
noise = train_util.make_noise(args, latents)

def pick_timesteps_list() -> torch.IntTensor:
if timesteps_list is None or timesteps_list == []:
return typing.cast(torch.IntTensor, train_util.make_random_timesteps(args, noise_scheduler, batch_size, latents.device).unsqueeze(1))
else:
return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device))

chosen_timesteps_list = pick_timesteps_list()
total_loss = torch.zeros((batch_size, 1)).to(latents.device)

# Use input timesteps_list or use described timesteps above
for fixed_timesteps in chosen_timesteps_list:
fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps)

# Predict the noise residual
# and add noise to the latents
# with noise offset and/or multires noise if specified
noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps)

# ensure the hidden state will require grad
if args.gradient_checkpointing:
for x in noisy_latents:
x.requires_grad_(True)
for t in text_encoder_conds:
t.requires_grad_(True)

with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
noise_pred = self.call_unet(
args,
accelerator,
unet,
noisy_latents.requires_grad_(train_unet),
fixed_timesteps,
text_encoder_conds,
batch,
weight_dtype,
)

if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps)
else:
target = noise

# differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)

if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
noise_pred_prior = self.call_unet(
args,
accelerator,
unet,
noisy_latents,
fixed_timesteps,
text_encoder_conds,
batch,
weight_dtype,
indices=diff_output_pr_indices,
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)

huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし
# Predict the noise residual
# and add noise to the latents
# with noise offset and/or multires noise if specified

if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
# sample noise, call unet, get target
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
is_train=is_train
)

loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
loss = loss * loss_weights
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])

loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler)
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights

total_loss += loss
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)

return total_loss / len(chosen_timesteps_list)
return loss.mean()

def train(self, args):
session_id = random.randint(0, 2**32)
Expand Down Expand Up @@ -1416,7 +1366,7 @@ def remove_model(old_ckpt_name):
if val_step >= validation_steps:
break

loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)

val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
val_progress_bar.update(1)
Expand Down Expand Up @@ -1447,7 +1397,7 @@ def remove_model(old_ckpt_name):
if val_step >= validation_steps:
break

loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)

current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
Expand Down

0 comments on commit bbf6bbd

Please sign in to comment.