previous_dtype is now inferred from F.linear's result output type. #1010

Merged (16 commits) on Feb 19, 2024
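For context, the change replaces the cast back to the *input* dtype with a cast to the dtype the base layer actually produced, so LoRA layers keep the same output dtype as their base layers under torch.autocast. Below is a minimal standalone sketch of the difference (an illustration only, not the PEFT source; it assumes a CUDA device is available):

```python
import torch
import torch.nn as nn

# Standalone illustration of the old vs. new cast; not the PEFT code itself.
linear = nn.Linear(768, 256).cuda()
x = torch.randn(2, 768, device="cuda")  # float32 input

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    result = linear(x)  # autocast runs the matmul in bfloat16

    # Old behaviour: cast back to the input dtype, silently undoing autocast.
    previous_dtype = x.dtype               # torch.float32
    old_output = result.to(previous_dtype)

    # New behaviour: keep whatever dtype the base layer actually produced.
    torch_result_dtype = result.dtype      # torch.bfloat16 under autocast
    new_output = result.to(torch_result_dtype)

print(old_output.dtype)  # torch.float32 -- what LoRA layers used to return
print(new_output.dtype)  # torch.bfloat16 -- matches a plain nn.Linear under autocast
```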
16 changes: 9 additions & 7 deletions src/peft/tuners/lora/layer.py
@@ -251,7 +251,8 @@ def __init__(
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
- fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+ fan_in_fan_out: bool = False,
+ # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
Member: Can we undo this change, or at least move the comment above line 254, where it belongs?

is_target_conv_1d_layer: bool = False,
init_lora_weights: Union[bool, str] = True,
**kwargs,
@@ -352,8 +353,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
return output_tensor

def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
- previous_dtype = x.dtype
-
if self.disable_adapters:
if self.merged:
self.unmerge()
@@ -362,6 +361,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
+ torch_result_dtype = result.dtype
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
@@ -372,7 +372,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
x = x.to(lora_A.weight.dtype)
result += lora_B(lora_A(dropout(x))) * scaling

- result = result.to(previous_dtype)
+ result = result.to(torch_result_dtype)
return result

def __repr__(self) -> str:
@@ -507,6 +507,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
+ torch_result_dtype = result.dtype
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_embedding_A:
continue
@@ -515,6 +516,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
scaling = self.scaling[active_adapter]
after_A = self._embed(x, embedding_A)
result += (after_A @ embedding_B) * scaling
+ result = result.to(torch_result_dtype)

return result

@@ -642,8 +644,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
return output_tensor

def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
- previous_dtype = x.dtype
-
if self.disable_adapters:
if self.merged:
self.unmerge()
@@ -652,6 +652,8 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
+ torch_result_dtype = result.dtype
+
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
@@ -662,7 +664,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
x = x.to(lora_A.weight.dtype)
result += lora_B(lora_A(dropout(x))) * scaling

- result = result.to(previous_dtype)
+ result = result.to(torch_result_dtype)
return result

def __repr__(self) -> str:
156 changes: 156 additions & 0 deletions tests/test_autocast_torchcompatibility_lora.py
@@ -0,0 +1,156 @@
import unittest
Member: Could you please add the comment header that we have in all our files? Also, please move this test to test_gpu_examples.py, otherwise it won't be run on CI (since our normal CI has no GPUs and only selected test files are run with GPU in our nightly tests).


import torch
import torch.nn as nn

from peft.tuners.lora import Conv2d as LoraConv2d
from peft.tuners.lora import Embedding as LoraEmbedding
from peft.tuners.lora import Linear as LoraLinear


class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
Member: Nit:

Suggested change:
- super(SimpleModel, self).__init__()
+ super().__init__()

Same for other classes.


self.embedding_layer = nn.Embedding(1000, 768)
self.layer_norm = nn.LayerNorm(768)
self.linear_transform = nn.Linear(768, 256)

def forward(self, input_ids):
embedded_output = self.embedding_layer(input_ids)
norm_output = self.layer_norm(embedded_output)
linear_output = self.linear_transform(norm_output)

return linear_output


class SimpleConv2DModel(nn.Module):
def __init__(self):
super(SimpleConv2DModel, self).__init__()

self.embedding_layer = nn.Embedding(1000, 768)
self.layer_norm = nn.LayerNorm(768)
self.conv2d_transform = nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

def forward(self, input_ids):
# Additional layers for your custom model
embedded_output = self.embedding_layer(input_ids)
norm_output = self.layer_norm(embedded_output)

# Reshape for Conv2d input (add a channel dimension)
norm_output = norm_output.unsqueeze(1)
conv_output = self.conv2d_transform(norm_output)

# Squeeze the channel dimension back out (only has an effect if it is size 1)
conv_output = conv_output.squeeze(1)

return conv_output


class SimpleLorALinearModel(nn.Module):
"""Same as SimpleModel but wraps Linear in Lora layer"""

def __init__(self):
super(SimpleLorALinearModel, self).__init__()

self.embedding_layer = nn.Embedding(1000, 768)
self.layer_norm = nn.LayerNorm(768)
self.linear_transform_base = nn.Linear(768, 256)
self.linear_transform = LoraLinear(
Member: Instead of wrapping the layer explicitly, let's create a LoraConfig and call get_peft_model with SimpleModel as input. test_simple_lora_linear_model could be responsible for initializing the class and passing the instance to _test_model. This way, we can be 100% sure that this generates a model the same way that users typically do (see the sketch after this class).

self.linear_transform_base, adapter_name="test_linear", r=8, lora_alpha=16, lora_dropout=0.05
)

def forward(self, input_ids):
embedded_output = self.embedding_layer(input_ids)
norm_output = self.layer_norm(embedded_output)
linear_output = self.linear_transform(norm_output)

return linear_output
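
For reference, the reviewer's suggestion amounts to roughly the following shape (a sketch only; the target_modules value and the helper name are illustrative assumptions, not part of this PR):

```python
# Sketch of the reviewer's suggestion: build the LoRA-wrapped model through the
# public API instead of instantiating LoraLinear by hand.
from peft import LoraConfig, get_peft_model

def build_lora_linear_model():
    base_model = SimpleModel()
    config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["linear_transform"],  # the nn.Linear defined in SimpleModel
    )
    # get_peft_model injects LoRA layers the same way end users obtain them.
    return get_peft_model(base_model, config)
```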


class SimpleLorAEmbeddingModel(nn.Module):
"""Same as SimpleModel but wraps Embedding in Lora layer"""

def __init__(self):
super(SimpleLorAEmbeddingModel, self).__init__()

self.embedding_layer_base = nn.Embedding(1000, 768)
self.embedding_layer = LoraEmbedding(
Member: Same argument as above.

self.embedding_layer_base, adapter_name="test_embedding", r=8, lora_alpha=16, lora_dropout=0.05
)
self.layer_norm = nn.LayerNorm(768)
self.linear_transform = nn.Linear(768, 256)

def forward(self, input_ids):
embedded_output = self.embedding_layer(input_ids)
norm_output = self.layer_norm(embedded_output)
linear_output = self.linear_transform(norm_output)

return linear_output


class SimpleLorAConv2DModel(nn.Module):
"""Same as SimpleModel but wraps Conv2D in Lora layer"""

def __init__(self):
super(SimpleLorAConv2DModel, self).__init__()

self.embedding_layer = nn.Embedding(1000, 768)
self.layer_norm = nn.LayerNorm(768)
self.conv2d_transform_base = nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.conv2d_transform = LoraConv2d(
Member: Same argument as above.

self.conv2d_transform_base, adapter_name="test_conv2d", r=8, lora_alpha=16, lora_dropout=0.05
)

def forward(self, input_ids):
# Additional layers for your custom model
embedded_output = self.embedding_layer(input_ids)
norm_output = self.layer_norm(embedded_output)

# Reshape for Conv2d input (add a channel dimension)
norm_output = norm_output.unsqueeze(1)
conv_output = self.conv2d_transform(norm_output)

# Squeeze the channel dimension back out (only has an effect if it is size 1)
conv_output = conv_output.squeeze(1)

return conv_output


class TestAutoCast(unittest.TestCase):
Contributor: This test seems to require a GPU; can you add the require_torch_gpu decorator here?

    def require_torch_gpu(test_case):

Contributor Author: Done.

def test_simple_model(self):
Member: WDYT about parametrizing the test (using parameterized, see other tests) over the dtype? That way, we can run a single test per test case, which is usually preferable. (A sketch of this shape follows at the end of the file.)

self._test_model(SimpleModel)

def test_simple_conv2d_model(self):
self._test_model(SimpleConv2DModel)

def test_simple_lora_linear_model(self):
self._test_model(SimpleLorALinearModel)

def test_simple_lora_embedding_model(self):
self._test_model(SimpleLorAEmbeddingModel)

def test_simple_lora_conv2d_model(self):
self._test_model(SimpleLorAConv2DModel)

def _test_model(self, model_class):
# Instantiate the model
model = model_class().cuda()

# Prepare dummy inputs
input_ids = torch.randint(0, 1000, (2, 10)).cuda()

# Forward pass with torch.bfloat16
Member: For the bf16 case, can we please run self.skipTest if not torch.cuda.is_bf16_supported()?

with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
outputs = model(input_ids)
self.assertEqual(outputs.dtype, torch.bfloat16)

# Forward pass with torch.float32
with torch.autocast(enabled=True, dtype=torch.float32, device_type="cuda"):
outputs = model(input_ids)
self.assertEqual(outputs.dtype, torch.float32)

# Forward pass with torch.float16
with torch.autocast(enabled=True, dtype=torch.float16, device_type="cuda"):
outputs = model(input_ids)
self.assertEqual(outputs.dtype, torch.float16)
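
Putting the review comments together, the test could take roughly the following shape (a sketch only; the import path for require_torch_gpu is an assumption based on where the repo's test helpers usually live, and parameterized is the package the other tests already use):

```python
import unittest

import torch
from parameterized import parameterized

# Assumed import path; in the peft repo the decorator lives with the other
# test helpers (e.g. tests/testing_utils.py).
from testing_utils import require_torch_gpu


@require_torch_gpu
class TestAutoCast(unittest.TestCase):
    @parameterized.expand([torch.bfloat16, torch.float32, torch.float16])
    def test_simple_model(self, dtype):
        self._test_model(SimpleModel, dtype)

    def _test_model(self, model_class, dtype):
        # Skip bf16 on GPUs that do not support it, as requested in review.
        if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
            self.skipTest("bfloat16 is not supported on this GPU")

        model = model_class().cuda()
        input_ids = torch.randint(0, 1000, (2, 10)).cuda()

        # One forward pass per parametrized dtype instead of three in one test.
        with torch.autocast(enabled=True, dtype=dtype, device_type="cuda"):
            outputs = model(input_ids)
            self.assertEqual(outputs.dtype, dtype)
```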