Better respect result dtype in LoRA layers (huggingface#1010)
MFajcik authored and BenjaminBossan committed Mar 14, 2024
1 parent 33be8ea commit cbafd66
Showing 2 changed files with 116 additions and 6 deletions.
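Why the cast target matters: the LoRA forward methods previously saved `previous_dtype = x.dtype` and cast their result back to the input's dtype. Under `torch.autocast`, however, a layer's output dtype is chosen by autocast rather than by the input, so casting back to the input dtype silently undoes mixed precision. A minimal illustration of the mismatch (not part of the commit; assumes a CUDA device):

import torch

linear = torch.nn.Linear(8, 8).cuda()
x = torch.randn(2, 8, device="cuda")  # float32 input
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = linear(x)  # autocast runs the matmul in float16
print(x.dtype, y.dtype)  # torch.float32 torch.float16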
13 changes: 7 additions & 6 deletions src/peft/tuners/lora/layer.py
@@ -297,8 +297,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         return output_tensor
 
     def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
-        previous_dtype = x.dtype
-
         if self.disable_adapters:
             if self.merged:
                 self.unmerge()
@@ -307,6 +305,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             result = self.base_layer(x, *args, **kwargs)
         else:
             result = self.base_layer(x, *args, **kwargs)
+            torch_result_dtype = result.dtype
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
@@ -317,7 +316,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 x = x.to(lora_A.weight.dtype)
                 result += lora_B(lora_A(dropout(x))) * scaling
 
-        result = result.to(previous_dtype)
+            result = result.to(torch_result_dtype)
         return result
 
     def __repr__(self) -> str:
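Assembled from the three hunks above, `Linear.forward` after this change reads roughly as follows. The lines elided by the diff (the `elif self.merged` branch and the adapter-lookup lines inside the loop) are reconstructed here, so treat this as a sketch rather than the verbatim source:

def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        result = self.base_layer(x, *args, **kwargs)
    elif self.merged:
        result = self.base_layer(x, *args, **kwargs)
    else:
        result = self.base_layer(x, *args, **kwargs)
        torch_result_dtype = result.dtype  # dtype picked by the base layer (and autocast)
        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]
            x = x.to(lora_A.weight.dtype)
            result += lora_B(lora_A(dropout(x))) * scaling
        result = result.to(torch_result_dtype)  # cast to the base result's dtype, not x.dtype
    return result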
@@ -483,6 +482,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             result = self.base_layer(x, *args, **kwargs)
         else:
             result = self.base_layer(x, *args, **kwargs)
+            torch_result_dtype = result.dtype
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_embedding_A:
                     continue
@@ -491,6 +491,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 scaling = self.scaling[active_adapter]
                 after_A = self._embed(x, embedding_A)
                 result += (after_A @ embedding_B) * scaling
+            result = result.to(torch_result_dtype)
 
         return result

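For `Embedding`, the risk is that the base lookup and the LoRA delta disagree on dtype: under autocast, an embedding lookup keeps the weight's dtype, while the `after_A @ embedding_B` matmul is autocast to the low-precision dtype. A small check of that divergence (illustration only; not part of the commit, assumes a CUDA device):

import torch

emb = torch.nn.Embedding(10, 4).cuda()
proj = torch.randn(4, 4, device="cuda")
ids = torch.randint(0, 10, (2, 3), device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    base = emb(ids)      # embedding lookups are not autocast: stays float32
    delta = base @ proj  # matmul is autocast to float16
print(base.dtype, delta.dtype)  # torch.float32 torch.float16

The added `result.to(torch_result_dtype)` pins the merged output to the base embedding's dtype regardless of what the adapter matmul produced.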
@@ -650,8 +651,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         return output_tensor
 
     def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
-        previous_dtype = x.dtype
-
         if self.disable_adapters:
             if self.merged:
                 self.unmerge()
@@ -660,6 +659,8 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
             result = self.base_layer(x, *args, **kwargs)
         else:
             result = self.base_layer(x, *args, **kwargs)
+            torch_result_dtype = result.dtype
+
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
@@ -670,7 +671,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 x = x.to(lora_A.weight.dtype)
                 result += lora_B(lora_A(dropout(x))) * scaling
 
-        result = result.to(previous_dtype)
+            result = result.to(torch_result_dtype)
         return result
 
     def __repr__(self) -> str:
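End to end, the effect of the patch is that a LoRA-wrapped layer now follows autocast the same way its base layer does. A hedged sketch (assumes `peft` is installed and a CUDA device; the toy model and target name are made up for illustration):

import torch
from peft import LoraConfig, get_peft_model

base = torch.nn.Sequential(torch.nn.Linear(16, 16)).cuda()
model = get_peft_model(base, LoraConfig(r=8, target_modules=["0"]))
x = torch.randn(2, 16, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = model(x)
print(out.dtype)  # torch.float16; before this fix the result was cast back to torch.float32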
109 changes: 109 additions & 0 deletions tests/test_gpu_examples.py
@@ -1524,3 +1524,112 @@ def test_causal_lm_training_multi_gpu(self):

        # assert loss is not None
        assert trainer.state.log_history[-1]["train_loss"] is not None


PRECISIONS = [(torch.float32), (torch.float16), (torch.bfloat16)]

LORA_PARAMS = {
    "r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
}


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.embedding_layer = torch.nn.Embedding(1000, 768)
        self.layer_norm = torch.nn.LayerNorm(768)
        self.linear_transform = torch.nn.Linear(768, 256)

    def forward(self, input_ids):
        embedded_output = self.embedding_layer(input_ids)
        norm_output = self.layer_norm(embedded_output)
        linear_output = self.linear_transform(norm_output)

        return linear_output


class SimpleConv2DModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.embedding_layer = torch.nn.Embedding(1000, 768)
        self.layer_norm = torch.nn.LayerNorm(768)
        self.conv2d_transform = torch.nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self, input_ids):
        # Embed and normalize the input ids
        embedded_output = self.embedding_layer(input_ids)
        norm_output = self.layer_norm(embedded_output)

        # Reshape for Conv2d input (add a channel dimension)
        norm_output = norm_output.unsqueeze(1)
        conv_output = self.conv2d_transform(norm_output)

        # Squeeze dim 1 (a no-op here, since the conv outputs 256 channels)
        conv_output = conv_output.squeeze(1)

        return conv_output


@require_torch_gpu
class TestAutoCast(unittest.TestCase):
    # This test makes sure that LoRA dtypes are consistent with the dtypes
    # inferred by torch.autocast under the tested PRECISIONS
    @parameterized.expand(PRECISIONS)
    def test_simple_model(self, *args, **kwargs):
        self._test_model(SimpleModel(), *args, **kwargs)

    @parameterized.expand(PRECISIONS)
    def test_simple_lora_linear_model(self, *args, **kwargs):
        simple_model = SimpleModel()
        config = LoraConfig(
            **LORA_PARAMS,
            target_modules=["linear_transform"],
        )

        lora_model = get_peft_model(simple_model, config)

        self._test_model(lora_model, *args, **kwargs)

    @parameterized.expand(PRECISIONS)
    def test_simple_lora_embedding_model(self, *args, **kwargs):
        simple_model = SimpleModel()
        config = LoraConfig(
            **LORA_PARAMS,
            target_modules=["embedding_layer"],
        )
        lora_model = get_peft_model(simple_model, config)

        self._test_model(lora_model, *args, **kwargs)

    @parameterized.expand(PRECISIONS)
    def test_simple_conv2d_model(self, *args, **kwargs):
        self._test_model(SimpleConv2DModel(), *args, **kwargs)

    @parameterized.expand(PRECISIONS)
    def test_simple_lora_conv2d_model(self, *args, **kwargs):
        simple_model = SimpleConv2DModel()
        config = LoraConfig(
            **LORA_PARAMS,
            target_modules=["conv2d_transform"],
        )
        lora_model = get_peft_model(simple_model, config)
        self._test_model(lora_model, *args, **kwargs)

    def _test_model(self, model, precision):
        # Move model to GPU
        model = model.cuda()

        # Prepare dummy inputs
        input_ids = torch.randint(0, 1000, (2, 10)).cuda()
        if precision == torch.bfloat16:
            if not torch.cuda.is_bf16_supported():
                self.skipTest("Bfloat16 not supported on this device")

        # Forward pass with test precision
        with torch.autocast(enabled=True, dtype=precision, device_type="cuda"):
            outputs = model(input_ids)
            assert outputs.dtype == precision