ENH: Different initialization methods for LoRA (#1189)
This PR adds the possibility to use different initialization methods for
LoRA, which is a requirement for a completely backwards-compatible adoption
of PEFT in diffusers.

The default is still the same as always, namely the one from the
reference implementation by Microsoft. On top of that, it is now
possible to pass `init_lora_weights='gaussian'` to initialize the LoRA
weights in the same way as is the default in diffusers, namely from a
normal distribution scaled by 1/r.

The init method currently applies to LoRA linear and conv layers, but
not to embedding layers, which are always initialized from a normal
distribution (and are probably irrelevant for diffusers).

In the future, similar extensions could be added for other adapter
methods.
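
For illustration, a minimal sketch of the new option; the toy module, rank, and target module name below are placeholders (mirroring the style of the new tests), not part of the PR:

import torch.nn as nn

from peft import LoraConfig, get_peft_model

# toy module standing in for a real model (illustration only)
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 64)

    def forward(self, x):
        return self.linear(x)

# init_lora_weights=True (the default) keeps the Microsoft reference init;
# "gaussian" draws lora_A from a normal distribution with std 1/r and keeps lora_B at zero
config = LoraConfig(r=8, target_modules=["linear"], init_lora_weights="gaussian")
model = get_peft_model(ToyModel(), config)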
BenjaminBossan authored Nov 29, 2023
1 parent 04c4110 commit f0fb951
Showing 4 changed files with 258 additions and 12 deletions.
4 changes: 3 additions & 1 deletion setup.py
@@ -18,7 +18,9 @@
extras["quality"] = ["black ~= 22.0", "ruff>=0.0.241", "urllib3<=2.0.0"]
extras["docs_specific"] = ["hf-doc-builder"]
extras["dev"] = extras["quality"] + extras["docs_specific"]
extras["test"] = extras["dev"] + ["pytest", "pytest-cov", "pytest-xdist", "parameterized", "datasets", "diffusers<0.21.0"]
extras["test"] = extras["dev"] + [
"pytest", "pytest-cov", "pytest-xdist", "parameterized", "datasets", "diffusers<0.21.0", "scipy"
]

setup(
name="peft",
12 changes: 8 additions & 4 deletions src/peft/tuners/lora/config.py
@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Optional, Union
from typing import List, Literal, Optional, Union

from peft.config import PeftConfig
from peft.utils import PeftType
@@ -76,12 +78,14 @@ class LoraConfig(PeftConfig):
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
},
)
init_lora_weights: bool = field(
init_lora_weights: bool | Literal["gaussian"] = field(
default=True,
metadata={
"help": (
"Whether to initialize the weights of the Lora layers with their default initialization. Don't change "
"this setting, except if you know exactly what you're doing."
"How to initialize the weights of the LoRA layers. Passing True (default) results in the default "
"initialization from the reference implementation from Microsoft. Passing 'gaussian' results "
"in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization "
"to False leads to completely random initialization and is discouraged."
),
},
)
22 changes: 15 additions & 7 deletions src/peft/tuners/lora/layer.py
@@ -84,7 +84,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
self.scaling[adapter_name] = lora_alpha / r
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
self.reset_lora_parameters(adapter_name, init_lora_weights)

weight = getattr(self.get_base_layer(), "weight", None)
if weight is not None:
@@ -116,7 +116,7 @@ def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)
self.scaling[adapter_name] = lora_alpha / r
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
self.reset_lora_parameters(adapter_name, init_lora_weights)

weight = getattr(base_layer, "weight", None)
if weight is not None:
@@ -142,8 +142,7 @@ def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A)
self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B)
self.scaling[adapter_name] = lora_alpha / r
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
self.reset_lora_parameters(adapter_name, init_lora_weights)

base_layer = self.get_base_layer()
weight = getattr(base_layer, "weight", None)
@@ -152,10 +151,19 @@
self.to(base_layer.weight.device, dtype=weight.dtype)
self.set_adapter(self.active_adapters)

def reset_lora_parameters(self, adapter_name):
def reset_lora_parameters(self, adapter_name, init_lora_weights):
if init_lora_weights is False:
return

if adapter_name in self.lora_A.keys():
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
if init_lora_weights is True:
# initialize A the same way as the default for nn.Linear and B to zero
# https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
elif init_lora_weights.lower() == "gaussian":
nn.init.normal_(self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name])
else:
raise ValueError(f"Unknown initialization {init_lora_weights=}")
nn.init.zeros_(self.lora_B[adapter_name].weight)
if adapter_name in self.lora_embedding_A.keys():
# initialize a the same way as the default for nn.linear and b to zero
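For reference, a standalone sketch of what the two initialization schemes amount to for a linear LoRA pair; the shapes and variable names below are illustrative assumptions, not the PEFT internals:

import math

from torch import nn

# illustrative shapes; r is the LoRA rank
in_features, out_features, r = 128, 128, 8
lora_A = nn.Linear(in_features, r, bias=False)
lora_B = nn.Linear(r, out_features, bias=False)

# init_lora_weights=True (default): Kaiming-uniform A (Microsoft reference init), zero B
nn.init.kaiming_uniform_(lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(lora_B.weight)

# init_lora_weights="gaussian": normal A with std 1/r, zero B
nn.init.normal_(lora_A.weight, std=1 / r)
nn.init.zeros_(lora_B.weight)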
232 changes: 232 additions & 0 deletions tests/test_initialization.py
@@ -0,0 +1,232 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from scipy import stats
from torch import nn

from peft import LoraConfig, get_peft_model
from peft.utils import infer_device


class InitializationTest(unittest.TestCase):
"""Test class to check the initialization of adapters."""

torch_device = infer_device()

def get_uniform(self, amin, amax, size=(10000,)):
unif = torch.distributions.uniform.Uniform(amin, amax)
samples = unif.sample(size)
return samples

def get_normal(self, mean, std, size=(10000,)):
normal = torch.distributions.normal.Normal(mean, std)
samples = normal.sample(size)
return samples

def get_model(self):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
# choose a large weight so that averages are close to expected values
self.linear = nn.Linear(1000, 1000)
self.embed = nn.Embedding(1000, 1000)
self.conv2d = nn.Conv2d(100, 100, 3)

def forward(self, x):
return self.linear(x)

return MyModule().eval().to(self.torch_device)

def test_lora_linear_init_default(self):
# default is True
torch.manual_seed(0)

model = self.get_model()
config = LoraConfig(target_modules=["linear"])
model = get_peft_model(model, config)
weight_A = model.linear.lora_A["default"].weight
weight_B = model.linear.lora_B["default"].weight

# use statistical test to check if weight A is from a uniform distribution
unif = self.get_uniform(weight_A.min().item(), weight_A.max().item())
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy())
self.assertGreater(p_value, 0.5)

# check that weight A is *not* from a normal distribution
normal = self.get_normal(weight_A.mean().item(), weight_A.std().item())
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy())
self.assertLess(p_value, 0.05)

# check that weight B is zero
self.assertTrue((weight_B == 0.0).all())

def test_lora_linear_init_gaussian(self):
# use gaussian init
torch.manual_seed(0)

model = self.get_model()
config = LoraConfig(target_modules=["linear"], init_lora_weights="gaussian")
model = get_peft_model(model, config)
weight_A = model.linear.lora_A["default"].weight
weight_B = model.linear.lora_B["default"].weight

# use statistical test to check if weight A is from a normal distribution
normal = self.get_normal(0.0, 1 / config.r)
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy())

# import matplotlib.pyplot as plt
# x = weight_A.detach().flatten().cpu().numpy()
# breakpoint()

self.assertGreater(p_value, 0.5)

# check that weight A is *not* from a uniform distribution
unif = self.get_uniform(weight_A.min().item(), weight_A.max().item())
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy())
self.assertLess(p_value, 0.05)

# check that weight B is zero
self.assertTrue((weight_B == 0.0).all())

def test_lora_linear_false(self):
torch.manual_seed(0)

model = self.get_model()
config = LoraConfig(target_modules=["linear"], init_lora_weights=False)
model = get_peft_model(model, config)
weight_B = model.linear.lora_B["default"].weight

# with init_lora_weights=False, weight B should *not* be zero. We don't care so much about the actual values
# as long as they are not zero, in order to avoid identity transformation.
self.assertFalse(torch.allclose(weight_B, torch.zeros_like(weight_B)))

def test_lora_embedding_default(self):
# embedding is initialized as a normal distribution, not kaiming uniform
torch.manual_seed(0)

model = self.get_model()
config = LoraConfig(target_modules=["embed"])
model = get_peft_model(model, config)
weight_A = model.embed.lora_embedding_A["default"]
weight_B = model.embed.lora_embedding_B["default"]

# use statistical test to check if weight B is from a normal distribution
normal = self.get_normal(0.0, 1.0)
_, p_value = stats.kstest(weight_B.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy())
self.assertGreater(p_value, 0.5)

# check that weight B is *not* from a uniform distribution
unif = self.get_uniform(weight_B.min().item(), weight_B.max().item())
_, p_value = stats.kstest(weight_B.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy())
self.assertLess(p_value, 0.05)

# check that weight A is zero
self.assertTrue((weight_A == 0.0).all())

def test_lora_embedding_gaussian(self):
# embedding does not change with init_lora_weights="gaussian" vs True
torch.manual_seed(0)

model = self.get_model()
config = LoraConfig(target_modules=["embed"], init_lora_weights="gaussian")
model = get_peft_model(model, config)
weight_A = model.embed.lora_embedding_A["default"]
weight_B = model.embed.lora_embedding_B["default"]

# use statistical test to check if weight B is from a normal distribution
normal = self.get_normal(0.0, 1.0)
_, p_value = stats.kstest(weight_B.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy())
self.assertGreater(p_value, 0.5)

# check that weight B is *not* from a uniform distribution
unif = self.get_uniform(weight_B.min().item(), weight_B.max().item())
_, p_value = stats.kstest(weight_B.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy())
self.assertLess(p_value, 0.05)

# check that weight A is zero
self.assertTrue((weight_A == 0.0).all())

def test_lora_embedding_false(self):
torch.manual_seed(0)

model = self.get_model()
config = LoraConfig(target_modules=["embed"], init_lora_weights=False)
model = get_peft_model(model, config)
weight_A = model.embed.lora_embedding_A["default"]

# with init_lora_weights=False, weight A should *not* be zero. We don't care so much about the actual values
# as long as they are not zero, in order to avoid identity transformation.
self.assertFalse(torch.allclose(weight_A, torch.zeros_like(weight_A)))

def test_lora_conv2d_default(self):
# default is True
torch.manual_seed(0)

model = self.get_model()
config = LoraConfig(target_modules=["conv2d"])
model = get_peft_model(model, config)
weight_A = model.conv2d.lora_A["default"].weight
weight_B = model.conv2d.lora_B["default"].weight

# use statistical test to check if weight A is from a uniform distribution
unif = self.get_uniform(weight_A.min().item(), weight_A.max().item())
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy())
self.assertGreater(p_value, 0.5)

# check that weight A is *not* from a normal distribution
normal = self.get_normal(weight_A.mean().item(), weight_A.std().item())
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy())
self.assertLess(p_value, 0.05)

# check that weight B is zero
self.assertTrue((weight_B == 0.0).all())

def test_lora_conv2d_init_gaussian(self):
# use gaussian init
torch.manual_seed(0)

model = self.get_model()
config = LoraConfig(target_modules=["conv2d"], init_lora_weights="gaussian")
model = get_peft_model(model, config)
weight_A = model.conv2d.lora_A["default"].weight
weight_B = model.conv2d.lora_B["default"].weight

# use statistical test to check if weight A is from a normal distribution
normal = self.get_normal(0.0, 1 / config.r)
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy())
self.assertGreater(p_value, 0.5)

# check that weight A is *not* from a uniform distribution
unif = self.get_uniform(weight_A.min().item(), weight_A.max().item())
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy())
self.assertLess(p_value, 0.05)

# check that weight B is zero
self.assertTrue((weight_B == 0.0).all())

def test_lora_conv2d_false(self):
torch.manual_seed(0)

model = self.get_model()
config = LoraConfig(target_modules=["conv2d"], init_lora_weights=False)
model = get_peft_model(model, config)
weight_B = model.conv2d.lora_B["default"].weight

# with init_lora_weights=False, weight B should *not* be zero. We don't care so much about the actual values
# as long as they are not zero, in order to avoid identity transformation.
self.assertFalse(torch.allclose(weight_B, torch.zeros_like(weight_B)))
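
These tests can be run with, e.g., `pytest tests/test_initialization.py`; they rely on scipy's Kolmogorov-Smirnov test to compare the sampled weights against reference distributions, which is why scipy was added to the test extras in setup.py.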
