
previous_dtype is now inferred from F.linear's result output type. #1010

Merged
16 commits merged on Feb 19, 2024
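Context for the change (no PR description is captured here): the LoRA forward passes used to cast their result back to the dtype of the input tensor. Under torch.autocast that is wrong, because autocast can hand a float32 input to the layer yet produce a float16/bfloat16 output. The diff below instead records the dtype of the base layer's output and casts the final result back to that. A minimal sketch of the before/after pattern, with illustrative names only (forward_old, forward_new, lora_delta are not PEFT APIs):

# Before: cast back to the input dtype; breaks under torch.autocast,
# where a float32 input legitimately yields a float16/bfloat16 output.
def forward_old(base_layer, lora_delta, x):
    previous_dtype = x.dtype
    result = base_layer(x)
    result = result + lora_delta(x)
    return result.to(previous_dtype)

# After: remember the dtype the base layer produced and cast back to it,
# so the LoRA path keeps whatever dtype autocast (or the model) chose.
def forward_new(base_layer, lora_delta, x):
    result = base_layer(x)
    torch_result_dtype = result.dtype
    result = result + lora_delta(x)
    return result.to(torch_result_dtype)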
13 changes: 7 additions & 6 deletions src/peft/tuners/lora/layer.py
@@ -297,8 +297,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
return output_tensor

def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype

if self.disable_adapters:
if self.merged:
self.unmerge()
@@ -307,6 +305,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
torch_result_dtype = result.dtype
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
@@ -317,7 +316,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
x = x.to(lora_A.weight.dtype)
result += lora_B(lora_A(dropout(x))) * scaling

result = result.to(previous_dtype)
result = result.to(torch_result_dtype)
return result

def __repr__(self) -> str:
@@ -483,6 +482,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
torch_result_dtype = result.dtype
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_embedding_A:
continue
@@ -491,6 +491,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
scaling = self.scaling[active_adapter]
after_A = self._embed(x, embedding_A)
result += (after_A @ embedding_B) * scaling
result = result.to(torch_result_dtype)

return result

@@ -650,8 +651,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
return output_tensor

def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
previous_dtype = x.dtype

if self.disable_adapters:
if self.merged:
self.unmerge()
@@ -660,6 +659,8 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
torch_result_dtype = result.dtype

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
@@ -670,7 +671,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
x = x.to(lora_A.weight.dtype)
result += lora_B(lora_A(dropout(x))) * scaling

result = result.to(previous_dtype)
result = result.to(torch_result_dtype)
return result

def __repr__(self) -> str:
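Why the input dtype is the wrong cast target: under torch.autocast, autocast-eligible ops such as F.linear return a reduced-precision result even when their inputs are float32, which is the behaviour the PR title refers to. A small standalone check, assuming a CUDA device is available:

import torch
import torch.nn.functional as F

# float32 input and weight, but autocast makes F.linear return float16
weight = torch.randn(8, 8, device="cuda")
x = torch.randn(2, 8, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = F.linear(x, weight)
print(x.dtype, y.dtype)  # torch.float32 torch.float16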
109 changes: 109 additions & 0 deletions tests/test_gpu_examples.py
@@ -1524,3 +1524,112 @@ def test_causal_lm_training_multi_gpu(self):

# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None


PRECISIONS = [(torch.float32), (torch.float16), (torch.bfloat16)]

LORA_PARAMS = {
"r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
}


class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()

self.embedding_layer = torch.nn.Embedding(1000, 768)
self.layer_norm = torch.nn.LayerNorm(768)
self.linear_transform = torch.nn.Linear(768, 256)

def forward(self, input_ids):
embedded_output = self.embedding_layer(input_ids)
norm_output = self.layer_norm(embedded_output)
linear_output = self.linear_transform(norm_output)

return linear_output


class SimpleConv2DModel(torch.nn.Module):
def __init__(self):
super().__init__()

self.embedding_layer = torch.nn.Embedding(1000, 768)
self.layer_norm = torch.nn.LayerNorm(768)
self.conv2d_transform = torch.nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

def forward(self, input_ids):
# Embed and normalize the input ids
embedded_output = self.embedding_layer(input_ids)
norm_output = self.layer_norm(embedded_output)

# Reshape for Conv2d input (add a channel dimension)
norm_output = norm_output.unsqueeze(1)
conv_output = self.conv2d_transform(norm_output)

# Squeeze out the channel dimension if it is singleton
conv_output = conv_output.squeeze(1)

return conv_output


@require_torch_gpu
class TestAutoCast(unittest.TestCase):
# This test makes sure that LoRA dtypes are consistent with the dtypes
# inferred by torch.autocast under the tested PRECISIONS
@parameterized.expand(PRECISIONS)
def test_simple_model(self, *args, **kwargs):
self._test_model(SimpleModel(), *args, **kwargs)

@parameterized.expand(PRECISIONS)
def test_simple_lora_linear_model(self, *args, **kwargs):
simple_model = SimpleModel()
config = LoraConfig(
**LORA_PARAMS,
target_modules=["linear_transform"],
)

lora_model = get_peft_model(simple_model, config)

self._test_model(lora_model, *args, **kwargs)

@parameterized.expand(PRECISIONS)
def test_simple_lora_embedding_model(self, *args, **kwargs):
simple_model = SimpleModel()
config = LoraConfig(
**LORA_PARAMS,
target_modules=["embedding_layer"],
)
lora_model = get_peft_model(simple_model, config)

self._test_model(lora_model, *args, **kwargs)

@parameterized.expand(PRECISIONS)
def test_simple_conv2d_model(self, *args, **kwargs):
self._test_model(SimpleConv2DModel(), *args, **kwargs)

@parameterized.expand(PRECISIONS)
def test_simple_lora_conv2d_model(self, *args, **kwargs):
simple_model = SimpleConv2DModel()
config = LoraConfig(
**LORA_PARAMS,
target_modules=["conv2d_transform"],
)
lora_model = get_peft_model(simple_model, config)
self._test_model(lora_model, *args, **kwargs)

def _test_model(self, model, precision):
# Move model to GPU
model = model.cuda()

# Prepare dummy inputs
input_ids = torch.randint(0, 1000, (2, 10)).cuda()
if precision == torch.bfloat16:
if not torch.cuda.is_bf16_supported():
self.skipTest("Bfloat16 not supported on this device")

# Forward pass with test precision
with torch.autocast(enabled=True, dtype=precision, device_type="cuda"):
outputs = model(input_ids)
assert outputs.dtype == precision
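As a usage sketch, the same check can be reproduced outside the test harness. This is a hypothetical standalone snippet, not part of the PR; it reuses SimpleModel and the LoRA hyperparameters from the tests above and assumes a CUDA device with float16 autocast support:

import torch
from peft import LoraConfig, get_peft_model

model = SimpleModel().cuda()  # SimpleModel as defined in the test file above
config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, target_modules=["linear_transform"])
lora_model = get_peft_model(model, config)

input_ids = torch.randint(0, 1000, (2, 10)).cuda()
with torch.autocast(enabled=True, dtype=torch.float16, device_type="cuda"):
    outputs = lora_model(input_ids)

# Before this PR, the LoRA result was cast back to the float32 input dtype;
# with the change, the output keeps the dtype chosen by autocast.
assert outputs.dtype == torch.float16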