Commit

Added scaling and full matrix
yaswanth19 committed Oct 20, 2024
1 parent 8ac5414 commit 4501e28
Showing 3 changed files with 68 additions and 33 deletions.
28 changes: 18 additions & 10 deletions src/peft/tuners/lokrv2/config.py
@@ -41,6 +41,8 @@ class LoKrV2Config(PeftConfig):
Perform rank decomposition of left kronecker product matrix.
decompose_factor (`int`):
Kronecker product decomposition factor.
rank_dropout_scale (`bool`):
Whether to scale the rank dropout mask during training.
target_modules (`Optional[Union[List[str], str]]`):
The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
names will be replaced. When passing a string, a regex match will be performed. When passing a list of
@@ -49,6 +51,10 @@ class LoKrV2Config(PeftConfig):
excluding the output layer. If this is not specified, modules will be chosen according to the model
architecture. If the architecture is not known, an error will be raised -- in this case, you should specify
the target modules manually.
bias (`str`):
Bias type for LoKr. Can be 'none', 'all' or 'lokr_only'. If 'all' or 'lokr_only', the corresponding biases
will be updated during training. Be aware that this means that, even when disabling the adapters, the model
will not produce the same output as the base model would have without adaptation.
init_weights (`bool`):
Whether to perform initialization of adapter weights. This defaults to `True`, passing `False` is
discouraged.
@@ -66,8 +72,13 @@ class LoKrV2Config(PeftConfig):
specified by `alpha`.
modules_to_save (`Optional[List[str]]`):
List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
use_upstream ('bool'):
Use the latest version of LoKr from Lycoris Implementation.
use_scalar (`Optional[bool]`):
Whether to use scalar multiplication for LoKR. If `True`, a scalar value will be learned and multiplied
with the adapter weights.
use_full_matrix (`Optional[bool]`):
Whether to use the full matrix instead of performing Low-Rank Decomposition for the LoKR layers.
use_upstream (`Optional[bool]`):
Whether to use the latest version of the LoKR module from the `Lycoris` repository.
"""

r: int = field(default=8, metadata={"help": "LoKr rank"})
@@ -93,6 +104,7 @@ class LoKrV2Config(PeftConfig):
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
)
decompose_factor: int = field(default=-1, metadata={"help": "Kronecker product decomposition factor."})
rank_dropout_scale: bool = field(default=False, metadata={"help": "Whether to scale the rank dropout mask during training."})
target_modules: Optional[Union[list[str], str]] = field(
default=None,
metadata={
@@ -101,7 +113,7 @@ class LoKrV2Config(PeftConfig):
"This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer."
},
)
bias: str = field(default="none", metadata={"help": "Bias type for Vera. Can be 'none', 'all' or 'lokr_only'"})
bias: str = field(default="none", metadata={"help": "Bias type for LoKr. Can be 'none', 'all' or 'lokr_only'"})
init_weights: bool = field(
default=True,
metadata={
@@ -153,19 +165,15 @@ class LoKrV2Config(PeftConfig):
)
use_scalar: Optional[bool] = field(
default=False,
metadata={"help": "Use scalar for multiplication instead of vector for LoKR."},
)
weight_decompose: Optional[bool] = field(
default=False,
metadata={"help": "Weight decompositon for LoKr matrices."},
metadata={"help": "Use scalar multiplication for LoKR."},
)
use_full_matrix: Optional[bool] = field(
default=False,
metadata={"help": "Use full matrix decompositon for full matrix."},
metadata={"help": "Use full matrix instead of Low-Rank Decomposition."},
)
use_upstream: Optional[bool] = field(
default=False,
metadata={"help": "Whether to use latest version of LoKr module or not."},
metadata={"help": "Use the latest version of LoKr module from `Lycoris` repository."},
)

def __post_init__(self):
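For orientation, the sketch below shows how the options added in this config (rank_dropout_scale, use_scalar, use_full_matrix, use_upstream) would be used end to end. It is not part of the diff: the import path, the base model, the target module names, and the alpha field are assumptions for illustration, and it presumes this branch registers LoKrV2 as a PEFT method.

from transformers import AutoModelForCausalLM
from peft import get_peft_model
from peft.tuners.lokrv2 import LoKrV2Config  # import path assumed from this branch's layout

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # any causal LM works

config = LoKrV2Config(
    r=8,
    alpha=8,                               # assumed field, referenced by the docstring above
    target_modules=["q_proj", "v_proj"],   # module names depend on the base model
    rank_dropout=0.1,
    rank_dropout_scale=True,   # new: rescale the dropout mask so its mean stays 1
    use_scalar=False,          # new: learn a scalar multiplier for the adapter weights
    use_full_matrix=False,     # new: skip the low-rank decomposition entirely
    use_upstream=False,        # new: build weights with Lycoris' lokr.weight_gen
)

peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()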
65 changes: 45 additions & 20 deletions src/peft/tuners/lokrv2/layer.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# weight decompose, full matrix and rslora remaining

import math
import warnings
from typing import Optional, Set, Tuple
@@ -38,15 +36,14 @@ class LoKrLayerv2(BaseTunerLayer):
"lokr_t2",
)
# Other params which may contain adapter related keys
other_param_names = ("r", "alpha", "scaling", "rank_dropout", "module_dropout")
other_param_names = ("r", "alpha", "scale", "rank_dropout", "module_dropout")

def __init__(self, base_layer: nn.Module, **kwargs) -> None:
super().__init__()
self.base_layer = base_layer
self.r = {}
self.alpha = {}
self.scale = {}
self.scaling = {}
self.rank_dropout = {}
self.module_dropout = {}

@@ -106,8 +103,8 @@ def update_layer(
use_effective_conv2d: bool,
decompose_both: bool,
decompose_factor: int,
full_matrix: bool,
use_upstream: bool = False,
**kwargs,
) -> None:
"""Internal function to create lokr adapter
@@ -121,13 +118,13 @@ def update_layer(
use_effective_conv2d (`bool`): Use parameter effective decomposition for Conv2d with ksize > 1.
decompose_both (`bool`): Perform rank decomposition of left kronecker product matrix.
decompose_factor (`int`): Kronecker product decomposition factor.
full_matrix (`bool`): Use full matrix instead of Low-Rank Decomposition.
use_upstream (`bool`): Use the weight initialization from the Lycoris library.
"""
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

self.r[adapter_name] = r
self.alpha[adapter_name] = alpha
self.scaling[adapter_name] = alpha / r
self.rank_dropout[adapter_name] = rank_dropout
self.module_dropout[adapter_name] = module_dropout
base_layer = self.get_base_layer()
@@ -137,21 +134,34 @@ def update_layer(

if isinstance(base_layer, nn.Conv2d): # For Conv2d
shape = ((out_l, out_k), (in_m, in_n), *self.kernel_size)
use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
use_w2 = r >= max(shape[0][1], shape[1][1]) / 2 or full_matrix
use_effective_conv2d = use_effective_conv2d and self.kernel_size != (1, 1)
else:
shape = ((out_l, out_k), (in_m, in_n))
use_w2 = not (r < max(shape[0][1], shape[1][1]) / 2)
use_w2 = not (r < max(shape[0][1], shape[1][1]) / 2 and not full_matrix)

if use_w2 and not full_matrix:
warnings.warn(
f"Lora dim {r} is too large for dim={max(self.in_features,self.out_features)} and factor={decompose_factor}."
"Hence using full matrix mode."
)

use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2 and not full_matrix)

use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
if (use_w1 and use_w2) or alpha is None or alpha == 0:
alpha = r

self.alpha[adapter_name] = alpha
self.scale[adapter_name] = alpha / r

if use_upstream:
# Create dummy weights of the required shape, since weight_gen expects a weight tensor as its first argument.
if self.kernel_size:
dummy_weights = torch.rand((self.in_features, self.out_features, *self.kernel_size))
else:
dummy_weights = torch.rand((self.in_features, self.out_features))

weights = lokr.weight_gen(dummy_weights, rank=r, decompose_both=decompose_both, **kwargs)
weights = lokr.weight_gen(dummy_weights, rank=r, factor=decompose_factor, decompose_both=decompose_both, tucker=use_effective_conv2d, full_matrix=full_matrix)
attributes = [
"lokr_w1",
"lokr_w1_a",
@@ -312,9 +322,17 @@ def __init__(
LoKrLayerv2.__init__(self, base_layer, **kwargs)
self.fan_in_fan_out = fan_in_fan_out

update_layer_kwargs = {
'use_effective_conv2d': kwargs.get('use_effective_conv2d', False),
'decompose_both': kwargs.get('decompose_both', False),
'decompose_factor': kwargs.get('decompose_factor', 1),
'full_matrix': kwargs.get('full_matrix', False),
'use_upstream': kwargs.get('use_upstream', False),
}

# Create adapter and set it active
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **kwargs)
self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **update_layer_kwargs)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
@@ -399,17 +417,16 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
w2 = w2.float()

# Make weights with Kronecker product
weight = self.make_kron(w1, w2)
weight = self.make_kron(w1, w2, self.scale[adapter_name])
weight = weight.reshape(self.get_base_layer().weight.shape)

# Perform rank dropout during training - drop rows of addition weights
rank_dropout = self.rank_dropout[adapter_name]
if self.training and rank_dropout:
drop = (torch.rand(weight.size(0)) > rank_dropout).float()
drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
# consider adapter name check
# if self.kwargs["rank_dropout_scale"]:
drop /= drop.mean()
if self.kwargs["rank_dropout_scale"]:
drop /= drop.mean()
weight *= drop

if cast_to_fp32:
@@ -467,10 +484,18 @@ def __init__(
LoKrLayerv2.__init__(self, base_layer, **kwargs)
self.fan_in_fan_out = fan_in_fan_out

update_layer_kwargs = {
'use_effective_conv2d': kwargs.get('use_effective_conv2d', False),
'decompose_both': kwargs.get('decompose_both', False),
'decompose_factor': kwargs.get('decompose_factor', 1),
'full_matrix': kwargs.get('full_matrix', False),
'use_upstream': kwargs.get('use_upstream', False),
}

# Create adapter and set it active
self._active_adapter = adapter_name
self.update_layer(
adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **update_layer_kwargs
)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
@@ -555,16 +580,16 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
w2 = w2.float()

# Make weights with Kronecker product
weight = self.make_kron(w1, w2)
weight = self.make_kron(w1, w2, self.scale[adapter_name])
weight = weight.reshape(self.get_base_layer().weight.shape)

# Perform rank dropout during training - drop rows of addition weights
rank_dropout = self.rank_dropout[adapter_name]
if self.training and rank_dropout:
drop = (torch.rand(weight.size(0)) > rank_dropout).float()
drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
# if self.kwargs["rank_dropout_scale"]:
drop /= drop.mean()
if self.kwargs["rank_dropout_scale"]:
drop /= drop.mean()
weight *= drop

if cast_to_fp32:
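To make the "scaling" half of this commit easier to follow, here is a self-contained sketch of how the scale = alpha / r factor and the new rank_dropout_scale flag combine. It simplifies get_delta_weight above to the Linear case with plain 2-D factors; the real make_kron also handles the Conv2d reshaping, so this is illustrative rather than the actual implementation.

import torch

def delta_weight_sketch(w1, w2, scale, rank_dropout=0.0, rank_dropout_scale=False, training=True):
    # Kronecker product of the two factors, scaled by alpha / r
    # (stored as self.scale[adapter_name] in update_layer).
    weight = torch.kron(w1, w2) * scale

    # Rank dropout: randomly zero whole output rows while training.
    if training and rank_dropout:
        drop = (torch.rand(weight.size(0)) > rank_dropout).float()
        drop = drop.view(-1, *[1] * (weight.dim() - 1)).to(weight.device)
        if rank_dropout_scale:
            # Rescale so the expected magnitude of the delta weight is preserved.
            drop /= drop.mean()
        weight = weight * drop
    return weight

# A 3x4 and a 4x3 factor give a 12x12 delta weight.
delta = delta_weight_sketch(torch.randn(3, 4), torch.randn(4, 3), scale=8 / 8)
print(delta.shape)  # torch.Size([12, 12])

Dividing alpha by r keeps the magnitude of the update roughly independent of the chosen rank, the same convention LoRA uses for lora_alpha.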
8 changes: 5 additions & 3 deletions src/peft/tuners/lokrv2/model.py
@@ -136,17 +136,19 @@ def _create_and_replace(
raise ValueError("Current Key shouldn't be `None`")

r = lokr_config.r
# bias = hasattr(target, "bias") and target.bias is not None
bias = hasattr(target, "bias") and target.bias is not None
kwargs = {
"r": r,
"rank_dropout": lokr_config.rank_dropout,
"fan_in_fan_out": lokr_config.fan_in_fan_out,
"init_weights": lokr_config.init_weights,
"rank_dropout_scale": lokr_config.rank_dropout_scale,
"use_effective_conv2d": lokr_config.use_effective_conv2d,
"decompose_both": lokr_config.decompose_both,
"decompose_factor": lokr_config.decompose_factor,
"use_scalar": lokr_config.use_scalar,
}
# kwargs["bias"] = bias
kwargs["bias"] = bias

if isinstance(target, LoKrLayerv2):
target.update_layer(
@@ -159,8 +161,8 @@
use_effective_conv2d=lokr_config.use_effective_conv2d,
decompose_both=lokr_config.decompose_both,
decompose_factor=lokr_config.decompose_factor,
full_matrix=lokr_config.use_full_matrix,
use_upstream=lokr_config.use_upstream,
**kwargs,
)
else:
new_module = self._create_new_module(lokr_config, adapter_name, target, **kwargs)
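The model-level changes above mostly forward the new options (rank_dropout_scale, use_full_matrix, use_upstream and the real bias flag) down to the layers; whether the second Kronecker factor is actually stored as a low-rank pair is decided in update_layer. The sketch below restates that condition for a Linear layer with hypothetical block sizes; it is a simplification, not the layer code itself.

def uses_low_rank_w2(r, out_k, in_n, full_matrix=False):
    # Mirrors the condition in update_layer: a rank-r pair of factors only
    # saves parameters when r is below half of the larger block dimension,
    # and the user did not request the full matrix explicitly.
    return r < max(out_k, in_n) / 2 and not full_matrix

# Hypothetical 768 -> 768 Linear whose dimensions factor into 16 x 48 blocks:
print(uses_low_rank_w2(r=8, out_k=48, in_n=16))   # True  -> low-rank pair is used
print(uses_low_rank_w2(r=32, out_k=48, in_n=16))  # False -> full w2 matrix, triggers the warning above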
