🐛 minor refactoring and bug fixes
yaswanth19 committed Oct 6, 2024
1 parent 9a079da · commit 3cf6418
Showing 2 changed files with 51 additions and 53 deletions.
46 changes: 22 additions & 24 deletions src/peft/tuners/lokrv2/layer.py
@@ -16,13 +16,11 @@

import math
import warnings
from typing import Any, Optional, Set, Tuple
from typing import Optional, Set, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from lycoris.functional import lokr
from transformers.pytorch_utils import Conv1D

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

@@ -76,7 +74,6 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self._disable_adapters = False
self.merged_adapters = []
self.kwargs = kwargs
print(kwargs)

@property
def merged(self) -> bool:
@@ -133,7 +130,7 @@ def update_layer(
in_m, in_n = self.factorization(self.in_features, decompose_factor)
out_l, out_k = self.factorization(self.out_features, decompose_factor)

if hasattr(self, "kernel_size"): # For Conv2d
if hasattr(base_layer, "kernel_size"): # For Conv2d
k_size = base_layer.kernel_size
shape = ((out_l, out_k), (in_m, in_n), *k_size)
use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
@@ -214,13 +211,13 @@ def factorization(self, dimension: int, factor: int = -1) -> Tuple[int, int]:
Args:
dimension (`int`): The number that needs to be factorized.
factor (`int`, optional):
Factorization divider. The algorithm will try to output two numbers, one of each will be as close to the
factor as possible. If -1 is provided, the decomposition algorithm would try to search dividers near the
square root of the dimension. Defaults to -1.
Factorization divider. The algorithm will try to output two numbers, one of each will be as close to
the factor as possible. If -1 is provided, the decomposition algorithm would try to search dividers
near the square root of the dimension. Defaults to -1.
Returns:
Tuple[`int`, `int`]: A tuple of two numbers, whose product is equal to the provided number. The first number is
always less than or equal to the second.
Tuple[`int`, `int`]: A tuple of two numbers, whose product is equal to the provided number. The first
number is always less than or equal to the second.
Example:
```py
@@ -271,7 +268,7 @@ def make_kron(self, w1, w2, scale=1.0):
return rebuild * scale


class Linear(nn.Module, LoKrLayerv2):
class Linear(nn.Linear, LoKrLayerv2):
"""LoKr implemented in Linear layer"""

def __init__(
@@ -283,10 +280,10 @@ def __init__(
rank_dropout: float = 0.0,
module_dropout: float = 0.0,
init_weights: bool | str = True,
fan_in_fan_out:bool = False,
fan_in_fan_out: bool = False,
**kwargs,
):
super().__init__()
super(nn.Linear,self).__init__()
LoKrLayerv2.__init__(self, base_layer, **kwargs)
self.fan_in_fan_out = fan_in_fan_out

@@ -307,7 +304,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
adapter_names = check_adapters_to_merge(adapter_names)
adapter_names = check_adapters_to_merge(self,adapter_names)

if not adapter_names:
return
@@ -368,8 +365,8 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
else:
w2 = self.lokr_w2_a[adapter_name] @ self.lokr_w2_b[adapter_name]

device = w1.weight.device
dtype = w1.weight.dtype
device = w1.device
dtype = w1.dtype

cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
if cast_to_fp32:
@@ -386,8 +383,8 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
drop = (torch.rand(weight.size(0)) > rank_dropout).float()
drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
# consider adapter name check
if self.kwargs["rank_dropout_scale"]:
drop /= drop.mean()
# if self.kwargs["rank_dropout_scale"]:
drop /= drop.mean()
weight *= drop

if cast_to_fp32:
@@ -425,7 +422,7 @@ def __repr__(self) -> str:
return "lokr." + rep


class Conv2d(LoKrLayerv2):
class Conv2d(nn.Module, LoKrLayerv2):
"""LoKr implemented in Conv2d layer"""

def __init__(
@@ -447,7 +444,9 @@ def __init__(

# Create adapter and set it active
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout,init_weights, use_effective_conv2d,**kwargs)
self.update_layer(
adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
@@ -462,7 +461,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
adapter_names = check_adapters_to_merge(adapter_names)
adapter_names = check_adapters_to_merge(self,adapter_names)

if not adapter_names:
return
@@ -522,8 +521,8 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
else:
w2 = self.lokr_w2_a[adapter_name] @ self.lokr_w2_b[adapter_name]

device = w1.weight.device
dtype = w1.weight.dtype
device = w1.device
dtype = w1.dtype

cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
if cast_to_fp32:
@@ -583,4 +582,3 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
def __repr__(self) -> str:
rep = super().__repr__()
return "lokr." + rep

58 changes: 29 additions & 29 deletions src/peft/tuners/lokrv2/model.py
@@ -14,17 +14,15 @@

from __future__ import annotations

import math
import warnings
from dataclasses import asdict
from enum import Enum
from typing import Optional, Union
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.init import _calculate_correct_fan
from tqdm import tqdm
from transformers.pytorch_utils import Conv1D

from peft.tuners.tuners_utils import (
BaseTuner,
@@ -120,11 +118,11 @@ def _check_new_adapter_config(self, config: LoKrConfigv2) -> None:
f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, "
"set bias to 'none' for all adapters."
)

@staticmethod
def _check_target_module_exists(LoKrConfigv2, key):
return check_target_module_exists(LoKrConfigv2, key)

def _create_and_replace(
self,
lokr_config: LoKrConfigv2,
@@ -140,35 +138,37 @@ def _create_and_replace(

r = lokr_config.r
bias = hasattr(target, "bias") and target.bias is not None

# Can the config be directly converted to dict and passed
kwargs = {
"r": r,
"rank_dropout": lokr_config.rank_dropout,
"fan_in_fan_out": lokr_config.fan_in_fan_out,
"init_weights": lokr_config.init_weights,
'use_effective_conv2d': lokr_config.use_effective_conv2d,
'decompose_both':lokr_config.decompose_both,
'decompose_factor':lokr_config.decompose_factor
"use_effective_conv2d": lokr_config.use_effective_conv2d,
"decompose_both": lokr_config.decompose_both,
"decompose_factor": lokr_config.decompose_factor,
}
kwargs["bias"] = bias

if not isinstance(target,LoKrLayerv2):
new_module= self._create_new_module(lokr_config,adapter_name,target,**kwargs)
if adapter_name not in self.active_adapters:
new_module.required_grad(False)
self._replace_module(parent, target_name,new_module, target)

if isinstance(target, LoKrLayerv2):
target.update_layer(
adapter_name,
r=lokr_config.r,
alpha=lokr_config.alpha,
rank_dropout=lokr_config.rank_dropout,
module_dropout=lokr_config.module_dropout,
init_weights=lokr_config.init_weights,
use_effective_conv2d=lokr_config.use_effective_conv2d,
decompose_both=lokr_config.decompose_both,
decompose_factor=lokr_config.decompose_factor,
)
else:
target.update_layer(adapter_name,
r=lokr_config.r,
alpha=lokr_config.alpha,
rank_dropout=lokr_config.rank_dropout,
module_dropout=lokr_config.module_dropout,
init_weights=lokr_config.init_weights,
use_effective_conv2d=lokr_config.use_effective_conv2d,
decompose_both=lokr_config.decompose_both,
decompose_factor=lokr_config.decompose_factor
)

new_module = self._create_new_module(lokr_config, adapter_name, target, **kwargs)
if adapter_name not in self.active_adapters:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)

def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
for n, p in model.named_parameters():
if self.prefix not in n:
@@ -189,7 +189,7 @@ def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
m.bias.requires_grad = True
else:
raise NotImplementedError(f"Requested bias: {bias}, is not implemented.")

def _replace_module(self, parent, child_name, new_module, child):
setattr(parent, child_name, new_module)
# It's not necessary to set requires_grad here, as that is handled by
@@ -355,8 +355,8 @@ def merge_and_unload(
self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> torch.nn.Module:
r"""
This method merges the LoKr layers into the base model. This is needed if someone wants to use the base model as
a standalone model.
This method merges the LoKr layers into the base model. This is needed if someone wants to use the base model
as a standalone model.
Args:
progressbar (`bool`):
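
The model.py side wires these layers into PEFT's tuner machinery (`_create_and_replace`, `_replace_module`, `merge_and_unload`). For context, a typical end-to-end flow looks like the hedged sketch below. It uses the released `LoKrConfig`/`get_peft_model` API as a stand-in for the in-progress lokrv2 classes in this branch, so treat the class names and arguments as assumptions rather than this commit's API.

```py
# Hedged end-to-end sketch of driving a LoKr-style tuner in PEFT. It relies on
# the released LoKrConfig / get_peft_model API as a stand-in for the lokrv2
# classes shown in this diff, so names and arguments are assumptions.
from peft import LoKrConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

base_model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)

config = LoKrConfig(
    r=8,                       # rank of the small factors
    alpha=16,                  # scaling applied to the Kronecker delta
    target_modules=["query", "value"],
    rank_dropout=0.0,
    module_dropout=0.0,
    decompose_both=False,      # factorize only the second factor by default
    decompose_factor=-1,       # -1: factor dimensions near their square root
)

model = get_peft_model(base_model, config)
model.print_trainable_parameters()

# ... train the adapter as usual ...

# Merge the Kronecker deltas into the base weights so the model can be saved
# and used without the PEFT wrapper, as merge_and_unload's docstring describes.
merged = model.merge_and_unload(safe_merge=True)
merged.save_pretrained("lokr-merged-model")
```

Passing `safe_merge=True` performs the merge on a copy of the weights and checks for NaNs before committing, which is what the `safe_merge` parameter in the docstrings above refers to.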
