[Misc][LoRA] Clean up the function interface of Punica #10917

Merged on Dec 5, 2024 · 27 commits
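The call sites repeated throughout this diff show the cleaned-up wrapper interface: `add_shrink` and `add_expand` now take tuples of stacked LoRA tensors covering every output slice at once, replacing the per-slice variants (`add_expand_packed_nslice`, `add_expand_slice`) removed below. A minimal sketch of the call shape, copied from the call sites in this diff (exact signatures in the Punica wrapper itself may differ):

# Sketch of the consolidated call sites as they appear in this diff.
# lora_a_stacked, lora_b_stacked and lora_bias_stacked are tuples with one
# stacked tensor per output slice; output_slices gives each slice's width.
layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0)
layer.punica_wrapper.add_expand(output,
                                buffers,
                                layer.lora_b_stacked,
                                layer.lora_bias_stacked,
                                layer.output_slices,
                                offset_start=0,
                                add_input=True)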
42 changes: 32 additions & 10 deletions tests/lora/test_layers.py
@@ -565,15 +565,18 @@ def _pretest():
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_replicated(dist_init, num_loras, device, stage,
bias_enabled) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
lora_dtype=torch.float16,
bias_enabled=bias_enabled)

def create_random_linear_replicated_layer():

@@ -585,7 +588,12 @@ def create_random_linear_replicated_layer():
lora_linear = ReplicatedLinearWithLoRA(linear)

lora_linear.create_lora_weights(max_loras, lora_config)

assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
lora_linear.lora_b_stacked) == 1)
if bias_enabled:
assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
else:
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear

for i in range(10):
@@ -669,8 +677,9 @@ def create_random_linear_replicated_layer():
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device, stage) -> None:
device, stage, bias_enabled) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
@@ -679,7 +688,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)
lora_dtype=torch.float16,
bias_enabled=bias_enabled)

def create_random_linear_parallel_layer():
if orientation == "row":
@@ -700,7 +710,12 @@ def create_random_linear_parallel_layer():
if not fully_shard else
ColumnParallelLinearWithShardedLoRA(linear))
lora_linear.create_lora_weights(max_loras, lora_config)

assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
lora_linear.lora_b_stacked) == 1)
if bias_enabled:
assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
else:
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear

for i in range(10):
@@ -784,8 +799,9 @@ def create_random_linear_parallel_layer():
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device, stage) -> None:
device, stage, bias_enabled) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
@@ -794,7 +810,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)
lora_dtype=torch.float16,
bias_enabled=bias_enabled)

def create_column_parallel_packed_layer():
if repeats == 2:
@@ -832,10 +849,16 @@ class FakeConfig:
num_key_value_heads = 32
num_attention_heads = 32

n_slices = repeats
lora_linear.create_lora_weights(max_loras,
lora_config,
model_config=FakeConfig())

assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
lora_linear.lora_b_stacked) == n_slices)
if bias_enabled:
assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
else:
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear

for i in range(10):
@@ -911,7 +934,6 @@ class FakeConfig:
512,
lora_config.lora_extra_vocab_size,
)
# lora_linear.set_mapping(*mapping_info)

lora_result = lora_linear(torch.cat(inputs))[0]
expected_result = linear(torch.cat(inputs))[0]
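All three updated tests assert the same invariant on the newly tuple-typed stacks. A small helper capturing it (hypothetical, not part of the PR; names taken from the assertions in this diff):

def assert_lora_stack_shapes(lora_linear, bias_enabled, expected_n_slices=1):
    # Each output slice owns one stacked lora_a and one stacked lora_b tensor.
    assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) ==
            len(lora_linear.lora_b_stacked) == expected_n_slices)
    # The bias stack is only allocated when bias support is enabled.
    if bias_enabled:
        assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
    else:
        assert lora_linear.lora_bias_stacked is None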
175 changes: 75 additions & 100 deletions vllm/lora/fully_sharded_layers.py
@@ -1,5 +1,5 @@
# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast

import torch
import torch.nn as nn
@@ -32,6 +32,44 @@ def dec(*args, **kwargs):
return dec


def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
"""
`ColumnParallelLinearWithLoRA` and the classes that inherit from it share
the same `apply` logic.
"""
assert (layer.n_slices == len(layer.lora_a_stacked) == len(
layer.lora_b_stacked) == len(layer.output_slices))
if layer.lora_bias_stacked is not None:
assert layer.n_slices == len(layer.lora_bias_stacked)

output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape

# Since communication is needed, the buffer is initialized directly as a
# tensor rather than a tuple of tensors.
buffers = torch.zeros(
(layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)

layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0)
buffers = tensor_model_parallel_all_gather(buffers)
layer.punica_wrapper.add_expand(output,
buffers,
layer.lora_b_stacked,
layer.lora_bias_stacked,
layer.output_slices,
offset_start=0,
add_input=True)

output = output.view(*out_orig_shape)
# now have column partitioned and packed output
return output


# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
@@ -51,34 +89,15 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
# gather operation.
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2]
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a

def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros(
(x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device,
)
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
self.bias_stacked,
add_input=True)
# now have column partitioned output

output = output.view(*out_orig_shape)
return output
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)

@classmethod
@_fully_sharded_can_replace
@@ -99,46 +118,6 @@ def can_replace_layer(
)


def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
"""
MergedColumnParallelLinearWithShardedLoRA and
MergedQKVParallelLinearWithShardedLora share the same
LoRa weight application method.

The main difference is the step by shard_size for lora_b which can
vary for MergedQKVParallelLinearWithShardedLora but is constant for
MergedColumnParallelLinearWithShardedLoRA.
"""
# expecting 2 for column parallel and 3 for qkv
n = len(layer.lora_a_stacked)
output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
buffers = torch.zeros(
(n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
for idx in range(n):
layer.punica_wrapper.add_shrink(buffers[idx], x,
layer.lora_a_stacked[idx], 1.0)

buffers = tensor_model_parallel_all_gather(buffers)
layer.punica_wrapper.add_expand_packed_nslice(
output,
buffers,
layer.lora_b_stacked,
layer.bias_stacked,
1.0,
layer.output_slices,
)

output = output.view(*out_orig_shape)
# now have column partitioned and packed output
return output


class MergedColumnParallelLinearWithShardedLoRA(
MergedColumnParallelLinearWithLoRA):
"""
@@ -162,8 +141,9 @@ def slice_lora_a(
]
return lora_a

def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)

@classmethod
@@ -195,31 +175,15 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):

def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2]
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a

def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
self.bias_stacked,
add_input=True)
# now have column partitioned output
output = output.view(*out_orig_shape)
return output
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)

@classmethod
@_fully_sharded_can_replace
@@ -260,8 +224,9 @@ def slice_lora_a(
]
return lora_a

def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)

@classmethod
@@ -294,7 +259,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
"""

def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_b_stacked.shape[2]
shard_size = self.lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
@@ -303,20 +268,24 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
if bias is None:
return bias
shard_size = self.bias_stacked.shape[2]
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked)
shard_size = self.lora_bias_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias

def apply(self, x: torch.Tensor) -> torch.Tensor:
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)

x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros(
(x.shape[0], self.lora_a_stacked.shape[2]),
(self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
@@ -330,12 +299,18 @@ def apply(self, x: torch.Tensor) -> torch.Tensor:
# remains is a standard all_reduce. User should be aware though that
# the output is not the same as a normal row_parallel, it should be
# reduced before being used
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
self.punica_wrapper.add_expand_slice(output, buffer,
self.lora_b_stacked,
self.bias_stacked, start_idx,
shard_size)
# NOTE: offsets are based on the rank.
shard_size = self.lora_b_stacked[0].shape[2]
offset_start = self.tp_rank * shard_size
self.punica_wrapper.add_expand(
output,
buffer,
self.lora_b_stacked,
self.lora_bias_stacked,
self.output_slices,
offset_start=offset_start,
add_input=True,
)
output = output.view(*out_orig_shape)
return output

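The comment in `_mcp_apply` about initializing the shrink buffer as one tensor rather than a tuple is driven by the all-gather: a single contiguous (n_slices, num_tokens, rank) slab lets one collective move every slice's shrink result at once. A minimal sketch with assumed shapes:

import torch

# Assumed, illustrative shapes (not taken from the PR).
n_slices, num_tokens, lora_rank = 3, 4, 16
# One contiguous float32 slab holding the shrink result of every slice, so a
# single all_gather across the TP group covers all slices in one collective.
buffers = torch.zeros(n_slices, num_tokens, lora_rank, dtype=torch.float32)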