From cbafd6663bbb443a71c243be4e9543bac5e69fab Mon Sep 17 00:00:00 2001
From: Martin Fajčík
Date: Mon, 19 Feb 2024 15:33:49 +0100
Subject: [PATCH] Better respect result dtype in LoRA layers (#1010)

---
 src/peft/tuners/lora/layer.py |  13 ++--
 tests/test_gpu_examples.py    | 109 ++++++++++++++++++++++++++++++++++
 2 files changed, 116 insertions(+), 6 deletions(-)

diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 5c920be1aa..5daef719a2 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -297,8 +297,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         return output_tensor
 
     def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
-        previous_dtype = x.dtype
-
         if self.disable_adapters:
             if self.merged:
                 self.unmerge()
@@ -307,6 +305,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             result = self.base_layer(x, *args, **kwargs)
         else:
             result = self.base_layer(x, *args, **kwargs)
+            torch_result_dtype = result.dtype
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
@@ -317,7 +316,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 x = x.to(lora_A.weight.dtype)
                 result += lora_B(lora_A(dropout(x))) * scaling
 
-        result = result.to(previous_dtype)
+            result = result.to(torch_result_dtype)
         return result
 
     def __repr__(self) -> str:
@@ -483,6 +482,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             result = self.base_layer(x, *args, **kwargs)
         else:
             result = self.base_layer(x, *args, **kwargs)
+            torch_result_dtype = result.dtype
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_embedding_A:
                     continue
@@ -491,6 +491,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 scaling = self.scaling[active_adapter]
                 after_A = self._embed(x, embedding_A)
                 result += (after_A @ embedding_B) * scaling
+            result = result.to(torch_result_dtype)
 
         return result
 
@@ -650,8 +651,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         return output_tensor
 
     def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
-        previous_dtype = x.dtype
-
         if self.disable_adapters:
             if self.merged:
                 self.unmerge()
@@ -660,6 +659,8 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
             result = self.base_layer(x, *args, **kwargs)
         else:
             result = self.base_layer(x, *args, **kwargs)
+            torch_result_dtype = result.dtype
+
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
@@ -670,7 +671,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 x = x.to(lora_A.weight.dtype)
                 result += lora_B(lora_A(dropout(x))) * scaling
 
-        result = result.to(previous_dtype)
+            result = result.to(torch_result_dtype)
         return result
 
     def __repr__(self) -> str:
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index f98d15cddf..acaca27028 100644
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -1524,3 +1524,112 @@ def test_causal_lm_training_multi_gpu(self):
 
         # assert loss is not None
         assert trainer.state.log_history[-1]["train_loss"] is not None
+
+
+PRECISIONS = [(torch.float32), (torch.float16), (torch.bfloat16)]
+
+LORA_PARAMS = {
+    "r": 8,
+    "lora_alpha": 16,
+    "lora_dropout": 0.05,
+}
+
+
+class SimpleModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.embedding_layer = torch.nn.Embedding(1000, 768)
+        self.layer_norm = torch.nn.LayerNorm(768)
+        self.linear_transform = torch.nn.Linear(768, 256)
+
+    def forward(self, input_ids):
+        embedded_output = self.embedding_layer(input_ids)
+        norm_output = self.layer_norm(embedded_output)
+        linear_output = self.linear_transform(norm_output)
+
+        return linear_output
+
+
+class SimpleConv2DModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.embedding_layer = torch.nn.Embedding(1000, 768)
+        self.layer_norm = torch.nn.LayerNorm(768)
+        self.conv2d_transform = torch.nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+
+    def forward(self, input_ids):
+        # Embed and normalize the input ids
+        embedded_output = self.embedding_layer(input_ids)
+        norm_output = self.layer_norm(embedded_output)
+
+        # Reshape for Conv2d input (add a channel dimension)
+        norm_output = norm_output.unsqueeze(1)
+        conv_output = self.conv2d_transform(norm_output)
+
+        # Squeeze dim 1 again (a no-op here, since it has size 256)
+        conv_output = conv_output.squeeze(1)
+
+        return conv_output
+
+
+@require_torch_gpu
+class TestAutoCast(unittest.TestCase):
+    # These tests make sure that LoRA dtypes are consistent with the dtypes
+    # inferred by torch.autocast under the tested PRECISIONS
+    @parameterized.expand(PRECISIONS)
+    def test_simple_model(self, *args, **kwargs):
+        self._test_model(SimpleModel(), *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_lora_linear_model(self, *args, **kwargs):
+        simple_model = SimpleModel()
+        config = LoraConfig(
+            **LORA_PARAMS,
+            target_modules=["linear_transform"],
+        )
+
+        lora_model = get_peft_model(simple_model, config)
+
+        self._test_model(lora_model, *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_lora_embedding_model(self, *args, **kwargs):
+        simple_model = SimpleModel()
+        config = LoraConfig(
+            **LORA_PARAMS,
+            target_modules=["embedding_layer"],
+        )
+        lora_model = get_peft_model(simple_model, config)
+
+        self._test_model(lora_model, *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_conv2d_model(self, *args, **kwargs):
+        self._test_model(SimpleConv2DModel(), *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_lora_conv2d_model(self, *args, **kwargs):
+        simple_model = SimpleConv2DModel()
+        config = LoraConfig(
+            **LORA_PARAMS,
+            target_modules=["conv2d_transform"],
+        )
+        lora_model = get_peft_model(simple_model, config)
+        self._test_model(lora_model, *args, **kwargs)
+
+    def _test_model(self, model, precision):
+        # Move model to GPU
+        model = model.cuda()
+
+        # Prepare dummy inputs
+        input_ids = torch.randint(0, 1000, (2, 10)).cuda()
+        if precision == torch.bfloat16:
+            if not torch.cuda.is_bf16_supported():
+                self.skipTest("Bfloat16 not supported on this device")
+
+        # Forward pass with test precision
+        with torch.autocast(enabled=True, dtype=precision, device_type="cuda"):
+            outputs = model(input_ids)
+            assert outputs.dtype == precision
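
Context, not part of the patch: the reason restoring `previous_dtype = x.dtype` was wrong is that under torch.autocast a base layer can legitimately return a different dtype than its input. A minimal self-contained sketch of that behavior, assuming only PyTorch and a CUDA device:

    import torch

    linear = torch.nn.Linear(8, 8).cuda()
    x = torch.randn(2, 8, device="cuda")  # float32 input

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        result = linear(x)

    print(x.dtype)       # torch.float32
    print(result.dtype)  # torch.float16 -- the dtype the patch now preserves

Casting the LoRA result back to `x.dtype` (the old behavior) silently upcasts `result` to float32 here, whereas casting to `torch_result_dtype` keeps whatever dtype the base layer actually produced.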
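
The new tests can also be reproduced by hand. A hedged usage sketch of the failure mode the Linear change fixes, reusing `SimpleModel` from the test diff above (assumes this patch is applied, a CUDA device, and `SimpleModel` in scope):

    import torch
    from peft import LoraConfig, get_peft_model

    model = get_peft_model(SimpleModel(), LoraConfig(r=8, lora_alpha=16, target_modules=["linear_transform"]))
    model = model.cuda()
    input_ids = torch.randint(0, 1000, (2, 10)).cuda()

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(input_ids)

    # autocast runs LayerNorm in float32, so the input reaching the LoRA Linear
    # is float32; before this patch the layer restored that input dtype and the
    # assert below failed under float16/bfloat16 autocast.
    assert outputs.dtype == torch.float16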