
Commit

Add LoRA+ support
rockerBOO committed Apr 1, 2024
1 parent f931705 commit f99fe28
Showing 4 changed files with 71 additions and 32 deletions.
2 changes: 2 additions & 0 deletions library/train_util.py
@@ -2789,6 +2789,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
+    parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
+    parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
 
 
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
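The two new flags only store a multiplier: LoRA+ trains the lora_up ("B") weights at a higher learning rate than the lora_down weights. As a quick illustration of the arithmetic (not part of the commit; the base LR and the ratio of 16 below are example values, and leaving the flag at None keeps the old single-LR behaviour):

# Effective per-group learning rates implied by --loraplus_unet_lr_ratio (illustrative values).
unet_lr = 1e-4                # example base UNet learning rate
loraplus_unet_lr_ratio = 16   # example ratio; None disables the split

lora_group_lr = unet_lr                           # "lora" group (lora_down etc.) keeps the base LR
plus_group_lr = unet_lr * loraplus_unet_lr_ratio  # "plus" group (lora_up) is scaled by the ratio

print(lora_group_lr, plus_group_lr)  # 0.0001 0.0016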
45 changes: 33 additions & 12 deletions networks/dylora.py
@@ -406,27 +406,48 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
         logger.info(f"weights are merged")
     """
 
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    # It might be good to allow setting separate learning rates for the two Text Encoders
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []
 
-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params
 
         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)
 
         if self.unet_loras:
-            param_data = {"params": enumerate_params(self.unet_loras)}
-            if unet_lr is not None:
-                param_data["lr"] = unet_lr
-            all_params.append(param_data)
+            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            all_params.extend(params)
 
         return all_params

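The core of assemble_params (in both dylora.py and lora.py) is a purely name-based split: parameters whose names contain "lora_up" land in the "plus" group, which later receives lr * lora_plus_ratio, while everything else stays in the "lora" group at the base LR, and any group whose resolved LR is 0 is skipped. A standalone toy version of that split, using made-up parameter names rather than real module parameters, behaves like this:

# Toy re-implementation of the grouping rule used by assemble_params above.
# Parameter names here are made up; the real code prefixes them with lora.lora_name.
def split_by_loraplus(named_params, lora_plus_ratio):
    groups = {"lora": {}, "plus": {}}
    for name, param in named_params:
        if lora_plus_ratio is not None and "lora_up" in name:
            groups["plus"][name] = param
        else:
            groups["lora"][name] = param
    return groups

fake_params = [("lora_down.weight", "w_down"), ("lora_up.weight", "w_up")]
print(split_by_loraplus(fake_params, 16))
# {'lora': {'lora_down.weight': 'w_down'}, 'plus': {'lora_up.weight': 'w_up'}}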
54 changes: 35 additions & 19 deletions networks/lora.py
@@ -1035,21 +1035,43 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight
 
     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []
 
-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params
 
         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)
 
         if self.unet_loras:
             if self.block_lr:
@@ -1063,21 +1085,15 @@ def enumerate_params(loras):
 
                 # Set parameters per block
                 for idx, block_loras in block_idx_to_lora.items():
-                    param_data = {"params": enumerate_params(block_loras)}
-
                     if unet_lr is not None:
-                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                        params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
                     elif default_lr is not None:
-                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                    if ("lr" in param_data) and (param_data["lr"] == 0):
-                        continue
-                    all_params.append(param_data)
+                        params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                    all_params.extend(params)
 
             else:
-                param_data = {"params": enumerate_params(self.unet_loras)}
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr
-                all_params.append(param_data)
+                params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+                all_params.extend(params)
 
         return all_params

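For context, prepare_optimizer_params still returns a plain list of PyTorch param-group dicts, so its result drops straight into any torch optimizer, with per-group "lr" entries overriding the optimizer default. The self-contained sketch below is not repository code; it mirrors the two-group layout with a dummy module and example learning rates:

# Self-contained sketch of the LoRA+ two-group layout fed to a PyTorch optimizer.
# DummyLoRA and the learning rates are illustrative stand-ins, not repository code.
import torch
import torch.nn as nn

class DummyLoRA(nn.Module):
    def __init__(self, dim=4, rank=2):
        super().__init__()
        self.lora_down = nn.Linear(dim, rank, bias=False)
        self.lora_up = nn.Linear(rank, dim, bias=False)

module = DummyLoRA()
base_lr, ratio = 1e-4, 16
param_groups = [
    {"params": [p for n, p in module.named_parameters() if "lora_up" not in n], "lr": base_lr},
    {"params": [p for n, p in module.named_parameters() if "lora_up" in n], "lr": base_lr * ratio},
]
optimizer = torch.optim.AdamW(param_groups)
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 0.0016]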
2 changes: 1 addition & 1 deletion train_network.py
@@ -339,7 +339,7 @@ def train(self, args):
 
         # Ensure backward compatibility
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
+            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_unet_lr_ratio, args.loraplus_text_encoder_lr_ratio)
         except TypeError:
             accelerator.print(
                 "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
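The try/except above exists because third-party network modules may still implement an older prepare_optimizer_params signature; calling them with the extra LoRA+ ratios raises TypeError, and the except branch falls back to the legacy call noted in the deprecation message. A simplified, self-contained sketch of that fallback pattern (the class and helper here are illustrative, not repository code):

# Simplified illustration of the backward-compatibility fallback used above.
# OldNetwork stands in for a network module that predates the LoRA+ arguments.
class OldNetwork:
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        return [{"params": [], "lr": unet_lr}]

def get_trainable_params(network, te_lr=5e-5, unet_lr=1e-4, lr=1e-4, unet_ratio=16, te_ratio=4):
    try:
        # new signature: (..., unet_lora_plus_ratio, text_encoder_lora_plus_ratio)
        return network.prepare_optimizer_params(te_lr, unet_lr, lr, unet_ratio, te_ratio)
    except TypeError:
        # older signature: retry without the LoRA+ ratios
        return network.prepare_optimizer_params(te_lr, unet_lr, lr)

print(get_trainable_params(OldNetwork()))  # [{'params': [], 'lr': 0.0001}]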
