[Lora] Support long context lora #4787

Merged: 29 commits, May 18, 2024
16 changes: 15 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -119,9 +119,23 @@ steps:

- label: LoRA Test %N
#mirror_hardwares: [amd]
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
parallelism: 4

- label: LoRA Long Context (Distributed)
#mirror_hardwares: [amd]
num_gpus: 4
# This test runs llama 13B, so it is required to run on 4 GPUs.
commands:
# Temporarily run this way because we cannot clean up GPU mem usage
# for multi GPU tests.
# TODO(sang): Fix it.
- pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
- pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
- pytest -v -s lora/test_long_context.py::test_self_consistency
- pytest -v -s lora/test_long_context.py::test_quality
- pytest -v -s lora/test_long_context.py::test_max_len

- label: Tensorizer Test
#mirror_hardwares: [amd]
command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
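The Buildkite step above launches each long-context test in its own `pytest` process because, as its inline comment notes, GPU memory used by multi-GPU tests cannot yet be cleaned up within a single session. As a purely illustrative sketch (not part of this PR), the same one-process-per-test pattern could be expressed in Python; the test names are copied verbatim from the commands above:

import subprocess

LONG_CONTEXT_TESTS = [
    "test_rotary_emb_replaced",
    "test_batched_rope_kernel",
    "test_self_consistency",
    "test_quality",
    "test_max_len",
]

for name in LONG_CONTEXT_TESTS:
    # Each test gets its own process so CUDA state is torn down on exit,
    # mirroring the per-test pytest commands in the pipeline step.
    subprocess.run(
        ["pytest", "-v", "-s", f"lora/test_long_context.py::{name}"],
        check=True,
    )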
5 changes: 2 additions & 3 deletions format.sh
@@ -112,7 +112,7 @@ mypy vllm/model_executor --config-file pyproject.toml


CODESPELL_EXCLUDES=(
'--skip' '*docs/source/_build/**'
'--skip' '*docs/source/_build/**,./tests/lora/data'
)

# check spelling of specified files
@@ -133,10 +133,9 @@ spell_check_changed() {
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"

if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
codespell "${CODESPELL_EXCLUDES[@]}"
codespell "${CODESPELL_EXCLUDES[@]}"
fi
}

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -60,7 +60,7 @@ exclude = [

[tool.codespell]
ignore-words-list = "dout, te, indicies"
skip = "./tests/prompts,./benchmarks/sonnet.txt"
skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data"

[tool.isort]
use_parentheses = true
50 changes: 50 additions & 0 deletions tests/lora/conftest.py
@@ -21,6 +21,17 @@
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model

LONG_LORA_INFOS = [{
"lora_id": 1,
"context_length": "16k",
}, {
"lora_id": 2,
"context_length": "16k",
}, {
"lora_id": 3,
"context_length": "32k",
}]


def cleanup():
destroy_model_parallel()
@@ -154,6 +165,45 @@ def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")


@pytest.fixture(scope="session")
def long_context_lora_files_16k_2():
return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_2")


@pytest.fixture(scope="session")
def long_context_lora_files_32k():
return snapshot_download(repo_id="SangBinCho/long_context_32k_testing")


# SANG-TODO Download long lora files.
@pytest.fixture(scope="session")
def long_context_infos(long_context_lora_files_16k_1,
long_context_lora_files_16k_2,
long_context_lora_files_32k):
cleanup()
infos = {}
for lora_checkpoint_info in LONG_LORA_INFOS:
lora_id = lora_checkpoint_info["lora_id"]
if lora_id == 1:
lora = long_context_lora_files_16k_1
elif lora_id == 2:
lora = long_context_lora_files_16k_2
elif lora_id == 3:
lora = long_context_lora_files_32k
else:
raise AssertionError("Unknown lora id")
infos[lora_id] = {
"context_length": lora_checkpoint_info["context_length"],
"lora": lora,
}
return infos


@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
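As a hypothetical illustration (not code from this PR), the `long_context_infos` fixture above could be consumed by a test roughly as follows. The dict keys mirror what the fixture returns; `LoRARequest` is assumed to take `lora_name`, `lora_int_id`, and `lora_local_path` as in vLLM at the time of this change, and the 4k Llama-2 base context is an assumption used to derive the linear scaling factor:

from vllm.lora.request import LoRARequest

BASE_CONTEXT_LEN = 4096  # assumed Llama-2 base context window

def lora_requests_from_infos(long_context_infos):
    """Build (scaling_factor, LoRARequest) pairs from the fixture's dict."""
    requests = []
    for lora_id, info in long_context_infos.items():
        # "16k" / "32k" -> target context length in tokens.
        target_len = int(info["context_length"].rstrip("k")) * 1024
        # A 16k adapter on a 4k base model implies a linear scaling factor of 4.
        scaling_factor = target_len // BASE_CONTEXT_LEN
        requests.append((scaling_factor,
                         LoRARequest(
                             lora_name=f"long_context_{info['context_length']}",
                             lora_int_id=lora_id,
                             lora_local_path=info["lora"])))
    return requests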
Empty file added tests/lora/data/__init__.py
97 changes: 97 additions & 0 deletions tests/lora/data/long_context_test_data.py

Large diffs are not rendered by default.

100 changes: 98 additions & 2 deletions tests/lora/test_layers.py
@@ -15,20 +15,22 @@
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LogitsProcessorWithLoRA, LoRAMapping,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights,
convert_mapping)
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
PackedLoRALayerWeights, convert_mapping)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.utils import set_random_seed
@@ -771,3 +773,97 @@ class FakeConfig:
expected_result,
rtol=rtol,
atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 8])
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0),
(6.0, 1.0)])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [None, 32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
def test_rotary_embedding_long_context(dist_init, num_loras, device,
scaling_factors, max_position,
is_neox_style, rotary_dim, head_size,
seq_len) -> None:
dtype = torch.float16
seed = 0
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)

max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
long_lora_scaling_factors=scaling_factors,
lora_dtype=dtype)

if rotary_dim is None:
rotary_dim = head_size
base = 10000
batch_size = 5 * num_loras
num_heads = 7

# Verify lora is equivalent to linear scaling rotary embedding.
rope = get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
)
lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
lora_rope.create_lora_weights(max_loras, lora_config)
linear_rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, {
"type": "linear",
"factor": scaling_factors
})
linear_rope = linear_rope.to(dtype=dtype)
id_to_index = get_random_id_to_index(num_loras, max_loras)
_, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
num_inputs=batch_size,
input_size=(1, max_position),
input_range=(0, lora_config.lora_extra_vocab_size),
input_type=torch.float16,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
long_lora_context = LongContextLoRAContext(list(scaling_factors),
rotary_dim)

next_expected_offset = 0
# Make sure the offset is correct.
scaling_factor_to_offset = lora_rope.scaling_factor_to_offset
for scaling_factor, offset in scaling_factor_to_offset.items():
assert offset == next_expected_offset
next_expected_offset += scaling_factor * max_position

for i in range(len(scaling_factors)):
long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get(
scaling_factors[i], 0)
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
max_loras,
512,
lora_config.lora_extra_vocab_size,
long_lora_context=long_lora_context,
)
lora_rope.set_mapping(*mapping_info)

positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query)
ref_q, ref_k = linear_rope(positions, query, key)
actual_q, actual_k = lora_rope(positions, query, key)

assert torch.allclose(ref_q, actual_q)
assert torch.allclose(ref_k, actual_k)
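The offset check in this test captures the core bookkeeping behind long-context LoRA: with several linear-scaling factors, the rotary cos/sin cache is a concatenation of one segment per factor, and each LoRA's positions are shifted to the start of its segment. A minimal standalone sketch of that arithmetic (illustrative only, not vLLM code):

def scaling_factor_to_offset(scaling_factors, max_position):
    """Map each linear-scaling factor to the start of its cache segment."""
    offsets = {}
    next_offset = 0
    for factor in scaling_factors:
        offsets[factor] = next_offset
        # A factor of s stretches the usable range to s * max_position entries,
        # so the following segment begins that many positions later.
        next_offset += int(factor * max_position)
    return offsets

offsets = scaling_factor_to_offset((4.0, 8.0), max_position=4096)
# offsets == {4.0: 0, 8.0: 16384}: a LoRA registered with factor 8.0 has its
# positions shifted by 16384 before indexing the shared rotary cache.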