diff --git a/flux_train_network.py b/flux_train_network.py index 704c4d32e..679db62b6 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -445,6 +445,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t if len(diff_output_pr_indices) > 0: network.set_multiplier(0.0) + unet.prepare_block_swap_before_forward() with torch.no_grad(): model_pred_prior = call_dit( img=packed_noisy_model_input[diff_output_pr_indices],