-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Different initialization methods for LoRA #1189
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
# coding=utf-8 | ||
# Copyright 2023-present the HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
|
||
import torch | ||
from scipy import stats | ||
from torch import nn | ||
|
||
from peft import LoraConfig, get_peft_model | ||
from peft.utils import infer_device | ||
|
||
|
||
class InitializationTest(unittest.TestCase): | ||
"""Test class to check the initialization of adapters.""" | ||
|
||
torch_device = infer_device() | ||
|
||
def get_uniform(self, amin, amax, size=(10000,)): | ||
unif = torch.distributions.uniform.Uniform(amin, amax) | ||
samples = unif.sample(size) | ||
return samples | ||
|
||
def get_normal(self, mean, std, size=(10000,)): | ||
normal = torch.distributions.normal.Normal(mean, std) | ||
samples = normal.sample(size) | ||
return samples | ||
|
||
def get_model(self): | ||
class MyModule(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
# choose a large weight so that averages are close to expected values | ||
self.linear = nn.Linear(1000, 1000) | ||
self.embed = nn.Embedding(1000, 1000) | ||
self.conv2d = nn.Conv2d(100, 100, 3) | ||
|
||
def forward(self, x): | ||
return self.linear(x) | ||
|
||
return MyModule().eval().to(self.torch_device) | ||
|
||
def test_lora_linear_init_default(self): | ||
# default is True | ||
torch.manual_seed(0) | ||
|
||
model = self.get_model() | ||
config = LoraConfig(target_modules=["linear"]) | ||
model = get_peft_model(model, config) | ||
weight_A = model.linear.lora_A["default"].weight | ||
weight_B = model.linear.lora_B["default"].weight | ||
|
||
# use statistical test to check if weight A is from a uniform distribution | ||
unif = self.get_uniform(weight_A.min().item(), weight_A.max().item()) | ||
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy()) | ||
self.assertGreater(p_value, 0.5) | ||
|
||
# check that weight A is *not* from a normal distribution | ||
normal = self.get_normal(weight_A.mean().item(), weight_A.std().item()) | ||
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy()) | ||
self.assertLess(p_value, 0.05) | ||
|
||
# check that weight B is zero | ||
self.assertTrue((weight_B == 0.0).all()) | ||
|
||
def test_lora_linear_init_gaussian(self): | ||
# use gaussian init | ||
torch.manual_seed(0) | ||
|
||
model = self.get_model() | ||
config = LoraConfig(target_modules=["linear"], init_lora_weights="gaussian") | ||
model = get_peft_model(model, config) | ||
weight_A = model.linear.lora_A["default"].weight | ||
weight_B = model.linear.lora_B["default"].weight | ||
|
||
# use statistical test to check if weight A is from a normal distribution | ||
normal = self.get_normal(0.0, 1 / config.r) | ||
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy()) | ||
|
||
# import matplotlib.pyplot as plt | ||
# x = weight_A.detach().flatten().cpu().numpy() | ||
# breakpoint() | ||
|
||
self.assertGreater(p_value, 0.5) | ||
|
||
# check that weight A is *not* from a uniform distribution | ||
unif = self.get_uniform(weight_A.min().item(), weight_A.max().item()) | ||
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy()) | ||
self.assertLess(p_value, 0.05) | ||
|
||
# check that weight B is zero | ||
self.assertTrue((weight_B == 0.0).all()) | ||
|
||
def test_lora_linear_false(self): | ||
torch.manual_seed(0) | ||
|
||
model = self.get_model() | ||
config = LoraConfig(target_modules=["linear"], init_lora_weights=False) | ||
model = get_peft_model(model, config) | ||
weight_B = model.linear.lora_B["default"].weight | ||
|
||
# with init_lora_weights=False, weight B should *not* be zero. We don't care so much about the actual values | ||
# as long as they are not zero, in order to avoid identity transformation. | ||
self.assertFalse(torch.allclose(weight_B, torch.zeros_like(weight_B))) | ||
|
||
def test_lora_embedding_default(self): | ||
# embedding is initialized as a normal distribution, not kaiming uniform | ||
torch.manual_seed(0) | ||
|
||
model = self.get_model() | ||
config = LoraConfig(target_modules=["embed"]) | ||
model = get_peft_model(model, config) | ||
weight_A = model.embed.lora_embedding_A["default"] | ||
weight_B = model.embed.lora_embedding_B["default"] | ||
|
||
# use statistical test to check if weight B is from a normal distribution | ||
normal = self.get_normal(0.0, 1.0) | ||
_, p_value = stats.kstest(weight_B.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy()) | ||
self.assertGreater(p_value, 0.5) | ||
|
||
# check that weight B is *not* from a uniform distribution | ||
unif = self.get_uniform(weight_B.min().item(), weight_B.max().item()) | ||
_, p_value = stats.kstest(weight_B.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy()) | ||
self.assertLess(p_value, 0.05) | ||
|
||
# check that weight A is zero | ||
self.assertTrue((weight_A == 0.0).all()) | ||
|
||
def test_lora_embedding_gaussian(self): | ||
# embedding does not change with init_lora_weights="gaussian" vs True | ||
torch.manual_seed(0) | ||
|
||
model = self.get_model() | ||
config = LoraConfig(target_modules=["embed"], init_lora_weights="gaussian") | ||
model = get_peft_model(model, config) | ||
weight_A = model.embed.lora_embedding_A["default"] | ||
weight_B = model.embed.lora_embedding_B["default"] | ||
|
||
# use statistical test to check if weight B is from a normal distribution | ||
normal = self.get_normal(0.0, 1.0) | ||
_, p_value = stats.kstest(weight_B.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy()) | ||
self.assertGreater(p_value, 0.5) | ||
|
||
# check that weight B is *not* from a uniform distribution | ||
unif = self.get_uniform(weight_B.min().item(), weight_B.max().item()) | ||
_, p_value = stats.kstest(weight_B.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy()) | ||
self.assertLess(p_value, 0.05) | ||
|
||
# check that weight A is zero | ||
self.assertTrue((weight_A == 0.0).all()) | ||
|
||
def test_lora_embedding_false(self): | ||
torch.manual_seed(0) | ||
|
||
model = self.get_model() | ||
config = LoraConfig(target_modules=["embed"], init_lora_weights=False) | ||
model = get_peft_model(model, config) | ||
weight_A = model.embed.lora_embedding_B["default"] | ||
|
||
# with init_lora_weights=False, weight A should *not* be zero. We don't care so much about the actual values | ||
# as long as they are not zero, in order to avoid identity transformation. | ||
self.assertFalse(torch.allclose(weight_A, torch.zeros_like(weight_A))) | ||
|
||
def test_lora_conv2d_default(self): | ||
# default is True | ||
torch.manual_seed(0) | ||
|
||
model = self.get_model() | ||
config = LoraConfig(target_modules=["conv2d"]) | ||
model = get_peft_model(model, config) | ||
weight_A = model.conv2d.lora_A["default"].weight | ||
weight_B = model.conv2d.lora_B["default"].weight | ||
|
||
# use statistical test to check if weight A is from a uniform distribution | ||
unif = self.get_uniform(weight_A.min().item(), weight_A.max().item()) | ||
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy()) | ||
self.assertGreater(p_value, 0.5) | ||
|
||
# check that weight A is *not* from a normal distribution | ||
normal = self.get_normal(weight_A.mean().item(), weight_A.std().item()) | ||
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy()) | ||
self.assertLess(p_value, 0.05) | ||
|
||
# check that weight B is zero | ||
self.assertTrue((weight_B == 0.0).all()) | ||
|
||
def test_lora_conv2d_init_gaussian(self): | ||
# use gaussian init | ||
torch.manual_seed(0) | ||
|
||
model = self.get_model() | ||
config = LoraConfig(target_modules=["conv2d"], init_lora_weights="gaussian") | ||
model = get_peft_model(model, config) | ||
weight_A = model.conv2d.lora_A["default"].weight | ||
weight_B = model.conv2d.lora_B["default"].weight | ||
|
||
# use statistical test to check if weight A is from a normal distribution | ||
normal = self.get_normal(0.0, 1 / config.r) | ||
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy()) | ||
self.assertGreater(p_value, 0.5) | ||
|
||
# check that weight A is *not* from a uniform distribution | ||
unif = self.get_uniform(weight_A.min().item(), weight_A.max().item()) | ||
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy()) | ||
self.assertLess(p_value, 0.05) | ||
|
||
# check that weight B is zero | ||
self.assertTrue((weight_B == 0.0).all()) | ||
|
||
def test_lora_conv2d_false(self): | ||
torch.manual_seed(0) | ||
|
||
model = self.get_model() | ||
config = LoraConfig(target_modules=["conv2d"], init_lora_weights=False) | ||
model = get_peft_model(model, config) | ||
weight_B = model.conv2d.lora_B["default"].weight | ||
|
||
# with init_lora_weights=False, weight B should *not* be zero. We don't care so much about the actual values | ||
# as long as they are not zero, in order to avoid identity transformation. | ||
self.assertFalse(torch.allclose(weight_B, torch.zeros_like(weight_B))) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why's the version restriction needed on the
diffusers
side?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was pinned in #936 and can probably be unpinned now.