previous_dtype is now inferred from F.linear's result output type. #1010
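For context on the title: in the LoRA layers, previous_dtype is (roughly) the dtype the layer output gets cast back to at the end of forward. The PR makes this dtype come from what F.linear actually returns, which under torch.autocast may be bf16 or fp16, instead of a value captured beforehand. A rough, illustrative sketch of the idea, not the actual diff (function name, arguments, and surrounding logic are assumptions):

import torch.nn.functional as F

def lora_linear_forward_sketch(x, weight, bias, lora_A, lora_B, scaling):
    result = F.linear(x, weight, bias=bias)
    # Change described by the title: infer previous_dtype from the F.linear result,
    # which reflects the autocast dtype, rather than from a dtype recorded up front.
    previous_dtype = result.dtype
    result = result + lora_B(lora_A(x)) * scaling
    return result.to(previous_dtype)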
@@ -0,0 +1,156 @@
import unittest
Review comment: Could you please add the comment header that we have in all our files? Also, please move this test to …
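The requested header is not reproduced here; PEFT source files open with the standard HuggingFace copyright and Apache 2.0 license block, so the addition would presumably look like this (year and exact wording are assumptions):

# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.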
import torch
import torch.nn as nn

from peft.tuners.lora import Conv2d as LoraConv2d
from peft.tuners.lora import Embedding as LoraEmbedding
from peft.tuners.lora import Linear as LoraLinear


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
Review comment: Nit, suggested change: … Same for the other classes.
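The suggestion block itself is not shown; judging from the flagged line, it is presumably the zero-argument Python 3 form of super() (an assumption):

        # presumed suggested change (assumption): replaces super(SimpleModel, self).__init__()
        super().__init__()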
        self.embedding_layer = nn.Embedding(1000, 768)
        self.layer_norm = nn.LayerNorm(768)
        self.linear_transform = nn.Linear(768, 256)

    def forward(self, input_ids):
        embedded_output = self.embedding_layer(input_ids)
        norm_output = self.layer_norm(embedded_output)
        linear_output = self.linear_transform(norm_output)

        return linear_output


class SimpleConv2DModel(nn.Module):
    def __init__(self):
        super(SimpleConv2DModel, self).__init__()

        self.embedding_layer = nn.Embedding(1000, 768)
        self.layer_norm = nn.LayerNorm(768)
        self.conv2d_transform = nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self, input_ids):
        # Embed and normalize the input ids
        embedded_output = self.embedding_layer(input_ids)
        norm_output = self.layer_norm(embedded_output)

        # Reshape for Conv2d input (add a channel dimension)
        norm_output = norm_output.unsqueeze(1)
        conv_output = self.conv2d_transform(norm_output)

        # Squeeze the channel dimension back out (a no-op here, since the conv outputs 256 channels)
        conv_output = conv_output.squeeze(1)

        return conv_output


class SimpleLorALinearModel(nn.Module):
    """Same as SimpleModel but wraps Linear in a LoRA layer"""

    def __init__(self):
        super(SimpleLorALinearModel, self).__init__()

        self.embedding_layer = nn.Embedding(1000, 768)
        self.layer_norm = nn.LayerNorm(768)
        self.linear_transform_base = nn.Linear(768, 256)
        self.linear_transform = LoraLinear(
            self.linear_transform_base, adapter_name="test_linear", r=8, lora_alpha=16, lora_dropout=0.05
        )

Review comment: Instead of wrapping the layer explicitly, let's create a …
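The comment is cut off; a plausible reading is that the reviewer wants the LoRA layer applied through a LoraConfig and get_peft_model rather than by constructing LoraLinear by hand. A rough sketch under that assumption (the target module name is taken from SimpleModel above and chosen for illustration):

from peft import LoraConfig, get_peft_model

base_model = SimpleModel()
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["linear_transform"],  # assumed target: the nn.Linear inside SimpleModel
)
peft_model = get_peft_model(base_model, lora_config)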
    def forward(self, input_ids):
        embedded_output = self.embedding_layer(input_ids)
        norm_output = self.layer_norm(embedded_output)
        linear_output = self.linear_transform(norm_output)

        return linear_output


class SimpleLorAEmbeddingModel(nn.Module):
    """Same as SimpleModel but wraps Embedding in a LoRA layer"""

    def __init__(self):
        super(SimpleLorAEmbeddingModel, self).__init__()

        self.embedding_layer_base = nn.Embedding(1000, 768)
        self.embedding_layer = LoraEmbedding(
            self.embedding_layer_base, adapter_name="test_embedding", r=8, lora_alpha=16, lora_dropout=0.05
        )

Review comment: Same argument as above.
        self.layer_norm = nn.LayerNorm(768)
        self.linear_transform = nn.Linear(768, 256)

    def forward(self, input_ids):
        embedded_output = self.embedding_layer(input_ids)
        norm_output = self.layer_norm(embedded_output)
        linear_output = self.linear_transform(norm_output)

        return linear_output


class SimpleLorAConv2DModel(nn.Module):
    """Same as SimpleModel but wraps Conv2d in a LoRA layer"""

    def __init__(self):
        super(SimpleLorAConv2DModel, self).__init__()

        self.embedding_layer = nn.Embedding(1000, 768)
        self.layer_norm = nn.LayerNorm(768)
        self.conv2d_transform_base = nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv2d_transform = LoraConv2d(
            self.conv2d_transform_base, adapter_name="test_conv2d", r=8, lora_alpha=16, lora_dropout=0.05
        )

Review comment: Same argument as above.
    def forward(self, input_ids):
        # Embed and normalize the input ids
        embedded_output = self.embedding_layer(input_ids)
        norm_output = self.layer_norm(embedded_output)

        # Reshape for Conv2d input (add a channel dimension)
        norm_output = norm_output.unsqueeze(1)
        conv_output = self.conv2d_transform(norm_output)

        # Squeeze the channel dimension back out (a no-op here, since the conv outputs 256 channels)
        conv_output = conv_output.squeeze(1)

        return conv_output


class TestAutoCast(unittest.TestCase):
Review comment: This test seems to require a GPU, can you add the … ? (the comment links to line 25 in 3708793)
Author reply: done.
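The linked helper is not named in this capture; PEFT's test suite keeps its GPU-gating helpers in the test utilities, so the request is presumably to mark the class (or the GPU-dependent tests) along these lines (decorator name and import path are assumptions):

from .testing_utils import require_torch_gpu  # assumed name/location of the GPU-gating decorator


@require_torch_gpu
class TestAutoCast(unittest.TestCase):
    ...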
    def test_simple_model(self):
        self._test_model(SimpleModel)

Review comment: WDYT about parametrizing the test (using …)?
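The parenthetical is cut off; presumably it suggests pytest-style parametrization so the five near-identical methods collapse into one test. A sketch of that idea (written as a plain pytest function, since pytest.mark.parametrize does not apply to unittest.TestCase methods; parametrizing over dtype as well is my own extension, not part of the suggestion):

import pytest
import torch


@pytest.mark.parametrize(
    "model_class",
    [SimpleModel, SimpleConv2DModel, SimpleLorALinearModel, SimpleLorAEmbeddingModel, SimpleLorAConv2DModel],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32, torch.float16])
def test_autocast_output_dtype(model_class, dtype):
    model = model_class().cuda()
    input_ids = torch.randint(0, 1000, (2, 10)).cuda()
    with torch.autocast(enabled=True, dtype=dtype, device_type="cuda"):
        outputs = model(input_ids)
    assert outputs.dtype == dtype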
    def test_simple_conv2d_model(self):
        self._test_model(SimpleConv2DModel)

    def test_simple_lora_linear_model(self):
        self._test_model(SimpleLorALinearModel)

    def test_simple_lora_embedding_model(self):
        self._test_model(SimpleLorAEmbeddingModel)

    def test_simple_lora_conv2d_model(self):
        self._test_model(SimpleLorAConv2DModel)

    def _test_model(self, model_class):
        # Instantiate the model
        model = model_class().cuda()

        # Prepare dummy inputs
        input_ids = torch.randint(0, 1000, (2, 10)).cuda()
        # Forward pass with torch.bfloat16
        with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
            outputs = model(input_ids)
            self.assertEqual(outputs.dtype, torch.bfloat16)

Review comment: For the bf16 case, can we please run …
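This comment is also cut off, so the exact request is unclear; one plausible reading is to run the bf16 pass only when the hardware supports it, for example (an assumption, not the reviewer's confirmed ask):

        if torch.cuda.is_bf16_supported():  # guard the bf16 forward pass on GPUs without bf16 support
            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                outputs = model(input_ids)
                self.assertEqual(outputs.dtype, torch.bfloat16)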
        # Forward pass with torch.float32
        with torch.autocast(enabled=True, dtype=torch.float32, device_type="cuda"):
            outputs = model(input_ids)
            self.assertEqual(outputs.dtype, torch.float32)

        # Forward pass with torch.float16
        with torch.autocast(enabled=True, dtype=torch.float16, device_type="cuda"):
            outputs = model(input_ids)
            self.assertEqual(outputs.dtype, torch.float16)
Review comment: Can we undo this change, or at least move the comment above 254, where it belongs?