[LoRA, Performance] Speedup multi-LoRA serving - Step 1 #1587

Merged: 3 commits, Oct 6, 2024
benchmark/lora/launch_server.py (12 changes: 3 additions & 9 deletions)
@@ -1,7 +1,7 @@
import argparse
import os

NUM_LORAS = 128
NUM_LORAS = 8
LORA_PATH = {
"base": "mistralai/Mistral-7B-Instruct-v0.3",
"lora": "/home/ying/test_lora",
@@ -11,12 +11,11 @@
def launch_server(args):
base_path = LORA_PATH["base"]
lora_path = LORA_PATH["lora"]
max_loras_per_batch = 4

if args.base_only:
cmd = f"python -m sglang.launch_server --model {base_path} "
cmd = f"python3 -m sglang.launch_server --model {base_path} "
else:
cmd = f"python -m sglang.launch_server --model {base_path} --lora-paths "
cmd = f"python3 -m sglang.launch_server --model {base_path} --lora-paths "
for i in range(NUM_LORAS):
lora_name = f"lora{i}"
cmd += f"{lora_name}={lora_path} "
@@ -29,11 +28,6 @@ def launch_server(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-loras",
type=int,
default=128,
)
parser.add_argument(
"--base-only",
action="store_true",
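For context, a sketch (not part of the diff): the script now registers 8 adapters (NUM_LORAS = 8) unless `--base-only` is passed, and the multi-LoRA branch assembles a launch command roughly like the one below. The model name and adapter path are the placeholders hard-coded in the benchmark script.

```python
# Sketch of the command built by the updated script (NUM_LORAS = 8).
# "/home/ying/test_lora" is the script's placeholder adapter path, not a real checkpoint.
base_path = "mistralai/Mistral-7B-Instruct-v0.3"
lora_path = "/home/ying/test_lora"

cmd = f"python3 -m sglang.launch_server --model {base_path} --lora-paths "
for i in range(8):  # NUM_LORAS
    cmd += f"lora{i}={lora_path} "
print(cmd)
```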
python/sglang/srt/lora/lora.py (27 changes: 13 additions & 14 deletions)
@@ -101,12 +101,12 @@ def __init__(
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)

def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seq_lens = seq_lens
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices

def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -115,11 +115,10 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor
weights=self.A_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
# FIXME
assert lora_a_output.shape[-1] == self.lora_rank * 2
lora_output = torch.empty_like(base_output)
output_dim = lora_output.shape[-1] // 2
for i in range(2):
@@ -132,7 +131,7 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor
weights=self.B_buffer[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
return base_output + lora_output * self.scaling
@@ -145,14 +144,14 @@ def __init__(
super().__init__(base_layer, segment_gemm, lora_rank, scaling)

def set_lora_info(
self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
self.B_buffer_q = B_buffer_q
self.B_buffer_kv = B_buffer_kv
self.bs = bs
self.seq_lens = seq_lens
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices

def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -161,7 +160,7 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor
weights=self.A_buffer_qkv,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
# FIXME parallelize qkv
@@ -173,7 +172,7 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor
weights=self.B_buffer_q,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
# kv
@@ -189,7 +188,7 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor
weights=self.B_buffer_kv[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
)
@@ -202,12 +201,12 @@ def __init__(
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)

def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seq_lens = seq_lens
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices

def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -216,15 +215,15 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor
weights=self.A_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
lora_output = self.segment_gemm.run(
x=lora_output,
weights=self.B_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
return base_output + lora_output * self.scaling
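A note on the signature change above (a sketch, not code from the PR): `seg_indptr` is the CSR-style index pointer, i.e. the running sum of the old `seg_lens` with a leading zero, so the per-request segment boundaries can be built once per batch by the LoRA manager (next file) and shared by every layer's segment GEMM call.

```python
# Minimal sketch of how seg_indptr relates to seg_lens (CPU tensors for
# illustration; the manager builds them on CUDA).
import torch

seg_lens = torch.tensor([5, 3, 7], dtype=torch.int32)  # tokens per request
bs = seg_lens.numel()

seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

print(seg_indptr)  # -> [0, 5, 8, 15]
# Rows seg_indptr[i]:seg_indptr[i+1] of x belong to request i and are multiplied
# with the LoRA weights selected by weight_indices[i].
```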
python/sglang/srt/lora/lora_manager.py (27 changes: 18 additions & 9 deletions)
@@ -274,18 +274,24 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
i = 0
j = len(self.active_uids)
evictable_uids = list(self.active_uids)
for uid in cur_uids:
if uid not in self.active_uids:
while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
i += 1
if i < len(evictable_uids):
if j < self.max_loras_per_batch:
index = j
j += 1
else:
while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
i += 1
assert i < len(evictable_uids)
self.active_uids.remove(evictable_uids[i])
self.buffer_id.pop(evictable_uids[i])
self.load_lora(uid, i)
index = i
i += 1
self.load_lora(uid, index)
self.active_uids.add(uid)
self.buffer_id[uid] = i
i += 1
self.buffer_id[uid] = index

if cur_uids == set([None]):
return
@@ -295,8 +301,11 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs)
else torch.ones(bs, device="cuda")
)
# FIXME: reuse the data rather than recompute
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.buffer_id[lora_path]
@@ -310,7 +319,7 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
self.A_buffer[weight_name][layer_id],
self.B_buffer[weight_name][layer_id],
bs,
seg_lens,
seg_indptr,
weight_indices,
)
else:
@@ -319,6 +328,6 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
self.B_buffer["q_proj"][layer_id],
self.B_buffer["kv_proj"][layer_id],
bs,
seg_lens,
seg_indptr,
weight_indices,
)
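To make the new slot-assignment logic above easier to follow, here is a standalone sketch (simplified; `load_lora` is stubbed with a print and the names mirror the PR): free buffer slots are used first, and an adapter not needed by the current batch is evicted only once every slot is occupied.

```python
# Simplified sketch of the slot assignment in prepare_lora_batch.
def assign_buffer_slots(cur_uids, active_uids, buffer_id, max_loras_per_batch):
    i = 0                            # cursor over eviction candidates
    j = len(active_uids)             # next free slot, if any remain
    evictable_uids = list(active_uids)
    for uid in cur_uids:
        if uid in active_uids:
            continue
        if j < max_loras_per_batch:  # a slot is still free: no eviction needed
            index = j
            j += 1
        else:
            # Buffer full: evict an adapter the current batch does not use.
            while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
                i += 1
            assert i < len(evictable_uids)
            active_uids.remove(evictable_uids[i])
            buffer_id.pop(evictable_uids[i])
            index = i
            i += 1
        print(f"load {uid} into slot {index}")  # stands in for self.load_lora(uid, index)
        active_uids.add(uid)
        buffer_id[uid] = index

# Example: 4 slots, two adapters already resident; the batch reuses lora1 and
# brings in two new adapters, which land in the free slots without any eviction.
active, slots = {"lora0", "lora1"}, {"lora0": 0, "lora1": 1}
assign_buffer_slots({"lora1", "lora2", "lora3"}, active, slots, max_loras_per_batch=4)
print(slots)  # e.g. {'lora0': 0, 'lora1': 1, 'lora2': 2, 'lora3': 3}
```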