Skip to content

Commit

Permalink
move load scheduler to method
Browse files Browse the repository at this point in the history
  • Loading branch information
KohakuBlueleaf committed Jun 25, 2024
1 parent 945688d commit c0a936e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
11 changes: 11 additions & 0 deletions hunyuan_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import argparse

import torch
from diffusers import DDPMScheduler
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()
Expand Down Expand Up @@ -77,6 +78,16 @@ def load_tokenizer(self, args):
tokenizer = hunyuan_utils.load_tokenizers()
return tokenizer

def load_noise_scheduler(self):
return DDPMScheduler(
beta_start=0.00085,
beta_end=0.03,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
steps_offset=1
)

def is_text_encoder_outputs_cached(self, args):
return args.cache_text_encoder_outputs

Expand Down
10 changes: 7 additions & 3 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def load_tokenizer(self, args):
tokenizer = train_util.load_tokenizer(args)
return tokenizer

def load_noise_scheduler(self):
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
return noise_scheduler

def is_text_encoder_outputs_cached(self, args):
return False

Expand Down Expand Up @@ -839,9 +845,7 @@ def load_model_hook(models, input_dir):

global_step = 0

noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
noise_scheduler = self.load_noise_scheduler()
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
Expand Down

0 comments on commit c0a936e

Please sign in to comment.