[Hardware][Intel] Add LoRA adapter support for CPU backend #4830

Closed
wants to merge 32 commits
Changes from 20 commits
Commits (32)
b8072c1
init cpu lora support
Isotr0py May 13, 2024
3e59331
try lora
Isotr0py May 13, 2024
8410abb
add lora cpu support
Isotr0py May 15, 2024
46d97f5
Merge remote-tracking branch 'upstream/main' into lora
Isotr0py May 15, 2024
35d8b8d
make warning less noisy
Isotr0py May 16, 2024
e4e76c2
make ruff happy
Isotr0py May 16, 2024
f779eec
remove a useless comment
Isotr0py May 16, 2024
8129eab
Merge branch 'main' into lora
Isotr0py May 19, 2024
adccac2
Merge remote-tracking branch 'upstream/main' into lora
Isotr0py Jun 13, 2024
ec51691
revert cpu model runner
Isotr0py Jun 13, 2024
5f3f640
rebase lora support
Isotr0py Jun 13, 2024
68a1434
format code
Isotr0py Jun 13, 2024
b4366e3
add lora cpu test
Isotr0py Jun 13, 2024
7539ef0
fix lora cpu test
Isotr0py Jun 13, 2024
4da6a10
fix cpu lora test CI
Isotr0py Jun 15, 2024
80603e0
Merge branch 'main' into lora
Isotr0py Jun 15, 2024
936e2ee
fix cpu test CI typo
Isotr0py Jun 15, 2024
49a2b42
rollback cpu test CI
Isotr0py Jun 15, 2024
c6e638d
fix cpu lora test CI
Isotr0py Jun 15, 2024
8882a69
remove gemma lora test from cpu test
Isotr0py Jun 15, 2024
e63df6c
revert cuda empty_cache
Isotr0py Jun 17, 2024
1cec47d
Merge branch 'vllm-project:main' into lora
Isotr0py Jun 17, 2024
61f02a6
Merge branch 'vllm-project:main' into lora
Isotr0py Jun 18, 2024
f150300
optimize cpu lora support
Isotr0py Jun 19, 2024
fc74eb5
Merge branch 'main' into lora
Isotr0py Jun 19, 2024
9f133ac
format code
Isotr0py Jun 19, 2024
eab4dc0
re-add ray to run-cpu-test
Isotr0py Jun 19, 2024
1cc83e6
fix typos
Isotr0py Jun 19, 2024
97d0115
handle native lora kernel for old gpu
Isotr0py Jun 19, 2024
bbb7ed8
Merge branch 'vllm-project:main' into lora
Isotr0py Jun 19, 2024
cbd20a8
fix warning message
Isotr0py Jun 21, 2024
855523c
format code
Isotr0py Jun 21, 2024
5 changes: 3 additions & 2 deletions .buildkite/run-cpu-test.sh
@@ -18,6 +18,7 @@ docker exec cpu-test bash -c "python3 examples/offline_inference.py"

# Run basic model test
docker exec cpu-test bash -c "cd tests;
pip install pytest Pillow protobuf
pip install pytest Pillow protobuf ray
cd ../
pytest -v -s tests/models -m \"not llava\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"
pytest -v -s tests/models -m \"not llava\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py
pytest -v -s tests/lora/test_baichuan.py tests/lora/test_chatglm3.py tests/lora/test_phi.py"
3 changes: 2 additions & 1 deletion tests/lora/conftest.py
@@ -54,7 +54,8 @@ def cleanup():
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
gc.collect()
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
ray.shutdown()


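A minimal sketch of the guarded-cleanup pattern introduced above, assuming a test helper of this shape (the function name is illustrative, not part of the diff):

```python
import gc

import torch


def guarded_cleanup() -> None:
    # Collect Python-level garbage unconditionally, but only touch the CUDA
    # caching allocator when a CUDA runtime is actually available, so the
    # same fixture works on CPU-only and GPU builds.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```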
1 change: 0 additions & 1 deletion vllm/executor/cpu_executor.py
@@ -18,7 +18,6 @@ class CPUExecutor(ExecutorBase):

def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu"
assert self.lora_config is None, "cpu backend doesn't support LoRA"
self.model_config = _verify_and_get_model_config(self.model_config)
self.cache_config = _verify_and_get_cache_config(self.cache_config)
self.scheduler_config = _verify_and_get_scheduler_config(
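With the assertion above removed, the CPU executor no longer rejects a LoRA config at startup. A minimal usage sketch, assuming a CPU build of vLLM; the model name and adapter path below are placeholders, not part of the diff:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder model and adapter path; any LoRA-compatible model works.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0.0, max_tokens=16),
    lora_request=LoRARequest("example-adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)
```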
24 changes: 15 additions & 9 deletions vllm/lora/models.py
@@ -81,13 +81,14 @@ def convert_mapping(
embeddings_indices, long_lora_indices). If long_lora doesn't
exist, it only contains first 4 entries.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device="cuda",
device=device,
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
@@ -100,6 +101,7 @@ def convert_mapping(
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx

if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
@@ -112,9 +114,10 @@
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
indices = torch.tensor(indices_list, dtype=torch.long, device=device)

prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
device=device,
dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
@@ -127,7 +130,7 @@
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
0, len(sampler_indices_padded), device=device, dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
@@ -386,26 +389,29 @@ def __init__(
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size

device = "cuda" if torch.cuda.is_available() else "cpu"
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.base_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
device=device)
self.sampler_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
device=device)
self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
device=device)
self.embeddings_indices = torch.empty(2,
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
device=device)
self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
device=device)
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}

# 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
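The change above replaces the hard-coded `device="cuda"` with a runtime choice, so the LoRA index tensors are allocated wherever the build can actually run. A small sketch of the pattern (sizes and names are illustrative):

```python
import torch

# Choose the device once and reuse it for all LoRA bookkeeping tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"

max_num_batched_tokens = 256
base_indices = torch.empty(max_num_batched_tokens,
                           dtype=torch.long,
                           device=device)
sampler_indices = torch.empty(max_num_batched_tokens,
                              dtype=torch.long,
                              device=device)
# On a CPU-only build both tensors report device "cpu"; on a GPU build, "cuda:0".
```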
63 changes: 63 additions & 0 deletions vllm/lora/native_kernels.py
@@ -0,0 +1,63 @@
# torch implementation of LoRA kernels.
import torch


def dispatch_bgmv(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)

Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
matrices.
indicies: Shape: `[B]`. Indices of the weight matrices.
layer_idx: Layer index of the weight matrices.
scale: Scaling factor.
"""
y += (x.unsqueeze(1) @ w_t_all[indicies, layer_idx, :, :].transpose(
-1, -2).to(x.dtype) * scale).squeeze(1)


def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, indicies: torch.LongTensor,
layer_idx: int, scale: float, h_in: int,
h_out: int, y_offset: int):
"""
Same as `bgmv` but you can operate on slices of y.
Pass whole y, define y_offset and y_slice_size.

Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)

Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
all of the transposed LoRA matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
y[:, y_offset:y_offset + h_out] += (
x[:, :h_in].unsqueeze(1)
@ w_t_all[indicies, layer_idx, :, :].transpose(-1, -2).to(x.dtype) *
scale).squeeze(1)
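
A small self-check of the pure-torch `dispatch_bgmv` semantics documented above (shapes and sizes are arbitrary; the per-row reference loop is only for illustration):

```python
import torch

from vllm.lora.native_kernels import dispatch_bgmv

B, H1, H2, L, num_loras = 4, 16, 8, 2, 3
x = torch.randn(B, H1)
w_t_all = torch.randn(num_loras, L, H2, H1)      # transposed LoRA weights
indicies = torch.randint(0, num_loras, (B,))     # which LoRA each row uses
y = torch.zeros(B, H2)

dispatch_bgmv(y, x, w_t_all, indicies, layer_idx=0, scale=1.0)

# Row-by-row reference: y[i] += x[i] @ w_t_all[indicies[i], 0].T * scale
ref = torch.stack([x[i] @ w_t_all[indicies[i], 0].T for i in range(B)])
assert torch.allclose(y, ref, atol=1e-5)
```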
36 changes: 24 additions & 12 deletions vllm/lora/punica.py
@@ -5,17 +5,29 @@
import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.lora import native_kernels
from vllm.utils import is_cpu

logger = init_logger(__name__)

if is_cpu():
logger.warning("punica LoRA kernels require a GPU to run. "
"But you are using the CPU version vLLM")


def _check_punica_support():
if is_cpu():
return native_kernels

if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
return
return ops

if torch.cuda.get_device_capability() < (8, 0):
raise ImportError(
"punica LoRA kernels require compute capability >= 8.0")
else:
raise ImportError(
logger.warning(
"punica LoRA kernels could not be imported. If you built vLLM "
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
"was set.")
@@ -46,9 +58,9 @@ def bgmv(
layer_idx: Layer index of the weight matrices.
scale: Scaling factor.
"""
_check_punica_support()
lora_ops = _check_punica_support()

ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
lora_ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)


def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
@@ -77,9 +89,9 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
_check_punica_support()
lora_ops = _check_punica_support()

ops.dispatch_bgmv_low_level(
lora_ops.dispatch_bgmv_low_level(
y,
x,
w_t_all,
@@ -122,7 +134,7 @@ def add_lora(y: torch.Tensor,
scale: Scaling factor.
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
"""
_check_punica_support()
lora_ops = _check_punica_support()

r = wb_t_all.size(-1)
if buffer is None:
@@ -132,8 +144,8 @@
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale)
lora_ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
lora_ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale)


def add_lora_slice(y: torch.Tensor,
@@ -172,7 +184,7 @@ def add_lora_slice(y: torch.Tensor,
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
_check_punica_support()
lora_ops = _check_punica_support()

r = wb_t_all.size(-1)
if buffer is None:
@@ -182,7 +194,7 @@
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
ops.dispatch_bgmv_low_level(
lora_ops.dispatch_bgmv_low_level(
buffer,
x,
wa_t_all,
Expand All @@ -193,7 +205,7 @@ def add_lora_slice(y: torch.Tensor,
buffer.size(1),
0,
)
ops.dispatch_bgmv_low_level(
lora_ops.dispatch_bgmv_low_level(
y,
buffer,
wb_t_all,
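The central change in punica.py is that `_check_punica_support` now returns a kernel module (the compiled punica ops or the pure-torch `native_kernels`) instead of only validating, so every call site dispatches through whatever it returns. A simplified sketch of that pattern; the branch for GPUs without the compiled kernels is truncated in the diff above, so the final return here is an assumption:

```python
import torch

from vllm import _custom_ops as ops
from vllm.lora import native_kernels
from vllm.utils import is_cpu


def _select_lora_backend():
    # Simplified restatement of _check_punica_support: CPU builds always use
    # the torch fallback; GPU builds prefer the compiled punica kernels.
    if is_cpu():
        return native_kernels
    if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
        return ops
    # Assumed fallback for GPUs without the compiled kernels (not shown in
    # the truncated diff above).
    return native_kernels


def bgmv_sketch(y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
                indicies: torch.LongTensor, layer_idx: int,
                scale: float) -> None:
    # Both backends expose the same dispatch_bgmv signature, so call sites
    # like bgmv()/add_lora() do not care which module was selected.
    lora_ops = _select_lora_backend()
    lora_ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
```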