[Model] Jamba support (vllm-project#4115)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: Erez Schwartz <erezs@ai21.com>
Co-authored-by: Mor Zusman <morz@ai21.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Tomer Asida <tomera@ai21.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
7 people authored Jul 2, 2024
1 parent ee93f4f commit 9d6a8da
Showing 21 changed files with 1,192 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .buildkite/run-cpu-test.sh
@@ -23,4 +23,4 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
docker exec cpu-test bash -c "cd tests;
pip install pytest Pillow protobuf
cd ../
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
23 changes: 23 additions & 0 deletions Dockerfile
@@ -43,6 +43,10 @@ COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-cuda.txt

COPY requirements-mamba.txt requirements-mamba.txt
RUN python3 -m pip install packaging
RUN python3 -m pip install -r requirements-mamba.txt

# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
@@ -123,6 +127,21 @@ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-dev.txt

#################### DEV IMAGE ####################
#################### MAMBA Build IMAGE ####################
FROM dev as mamba-builder
# max jobs used for build
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}

WORKDIR /usr/src/mamba

COPY requirements-mamba.txt requirements-mamba.txt

# Download the wheel or build it if a pre-compiled release doesn't exist
RUN pip --verbose wheel -r requirements-mamba.txt \
--no-build-isolation --no-deps --no-cache-dir

#################### MAMBA Build IMAGE ####################

#################### vLLM installation IMAGE ####################
# image with vLLM installed
@@ -143,6 +162,10 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose

RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
#################### vLLM installation IMAGE ####################


4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
@@ -87,6 +87,10 @@ Alongside each architecture, we include some popular models that use it.
- Jais
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
-
* - :code:`JambaForCausalLM`
- Jamba
- :code:`ai21labs/Jamba-v0.1`, etc.
- ✅︎
* - :code:`LlamaForCausalLM`
- LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
- :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
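For context, the new Jamba entry above means the model loads through the standard vLLM offline API. A minimal usage sketch (the prompt and sampling settings are illustrative and not part of this commit; the full ai21labs/Jamba-v0.1 checkpoint is large, so the tiny test checkpoint ai21labs/Jamba-tiny-random used in tests/models/test_jamba.py may be more practical for a quick check):

from vllm import LLM, SamplingParams

# Jamba's Mamba layers keep per-request state internally, but the public API
# is unchanged: load and generate exactly like any other supported model.
llm = LLM(model="ai21labs/Jamba-v0.1")

params = SamplingParams(temperature=0.0, max_tokens=32)
for output in llm.generate(["The capital of France is"], params):
    print(output.outputs[0].text)
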
3 changes: 3 additions & 0 deletions requirements-mamba.txt
@@ -0,0 +1,3 @@
# Mamba dependencies
mamba-ssm>=1.2.2
causal-conv1d>=1.2.0
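Once these wheels are installed (from this requirements file or via the mamba-builder stage in the Dockerfile), a quick sanity check is to import the kernel packages. This snippet is illustrative and not part of the commit; it assumes the packages expose the import names causal_conv1d and mamba_ssm:

# Verify that the Mamba kernel packages are importable before loading Jamba.
try:
    import causal_conv1d  # provided by causal-conv1d>=1.2.0
    import mamba_ssm      # provided by mamba-ssm>=1.2.2
except ImportError as err:
    raise SystemExit(f"Mamba dependencies missing; Jamba models will not load: {err}")
print("Mamba kernels available:", getattr(mamba_ssm, "__version__", "unknown"))
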
65 changes: 65 additions & 0 deletions tests/models/test_jamba.py
@@ -0,0 +1,65 @@
import pytest

MODELS = ["ai21labs/Jamba-tiny-random"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
assert dtype == "float"

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
# This test is for verifying that the Jamba state is cleaned up between
# steps. If it is not cleaned up, an error is expected.
try:
with vllm_runner(model, dtype=dtype) as vllm_model:
for _ in range(10):
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
except ValueError:
pytest.fail("Jamba inner state wasn't cleaned up between states, "
"could be related to finished_requests_ids")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
29 changes: 28 additions & 1 deletion vllm/config.py
@@ -386,9 +386,36 @@ def get_num_attention_heads(self,
return num_heads // parallel_config.tensor_parallel_size

def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
return total_num_hidden_layers // parallel_config.pipeline_parallel_size

def contains_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> bool:
"""True for Mamba/SSM models (Jamba)"""
return self._get_num_seqlen_agnostic_layers(parallel_config) > 0

def get_layers_block_type(self,
parallel_config: "ParallelConfig") -> List[str]:
num_layers = self.get_num_layers(parallel_config)
# Transformers supports layers_block_type @property
return getattr(self.hf_config, "layers_block_type",
["attention"] * num_layers)

def get_num_attention_layers(self,
parallel_config: "ParallelConfig") -> int:
return len([
t for t in self.get_layers_block_type(parallel_config)
if t == "attention"
])

def _get_num_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> int:
return len([
t for t in self.get_layers_block_type(parallel_config)
if t != "attention"
])


class CacheConfig:
"""Configuration for the KV cache.
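The new ModelConfig helpers above classify layers by reading layers_block_type from the Hugging Face config: attention layers need KV-cache blocks, while Mamba layers carry their own fixed-size state and are counted as sequence-length agnostic. A standalone sketch of that counting logic (the layer pattern below is made up for illustration, not taken from the real Jamba config):

# Toy Jamba-like layout: mostly Mamba blocks with periodic attention blocks.
layers_block_type = ["mamba", "mamba", "mamba", "attention"] * 2

num_attention_layers = sum(1 for t in layers_block_type if t == "attention")
num_seqlen_agnostic_layers = sum(1 for t in layers_block_type if t != "attention")

# Only the attention layers consume KV-cache blocks; the Mamba layers are
# "seqlen agnostic" and are handled by the model's own state cache.
print(num_attention_layers, num_seqlen_agnostic_layers)  # -> 2 6
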
16 changes: 15 additions & 1 deletion vllm/core/scheduler.py
@@ -299,7 +299,10 @@ def __init__(
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque()

# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
self._finished_requests_ids: List[str] = list()
# Time at previous scheduling step
self.prev_time = 0.0
# Did we schedule a prompt at previous step?
@@ -373,6 +376,12 @@ def has_unfinished_seqs(self) -> bool:
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)

def get_and_reset_finished_requests_ids(self) -> List[str]:
"""Flushes the list of request ids of previously finished seq_groups."""
finished_requests_ids = self._finished_requests_ids
self._finished_requests_ids = list()
return finished_requests_ids

def _schedule_running(
self,
running_queue: deque,
@@ -1036,6 +1045,11 @@ def free_seq(self, seq: Sequence) -> None:
self.block_manager.free(seq)

def free_finished_seq_groups(self) -> None:
for queue in [self.running, self.swapped, self.waiting]:
self._finished_requests_ids += [
seq_group.request_id for seq_group in queue
if seq_group.is_finished()
]
self.running = deque(seq_group for seq_group in self.running
if not seq_group.is_finished())

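get_and_reset_finished_requests_ids gives a stateful model a per-step list of request ids whose resources can be dropped; this is what keeps Jamba's Mamba state from leaking across requests. A minimal, hypothetical consumer is sketched below (the real cache management lives in the Jamba model code, which is not shown in this excerpt):

from typing import Dict, List

class ToyMambaStateCache:
    """Illustrative per-request state store keyed by request_id."""

    def __init__(self) -> None:
        self._state: Dict[str, object] = {}

    def release_finished(self, finished_requests_ids: List[str]) -> None:
        # Called once per engine step with the ids flushed by the scheduler;
        # dropping the entries prevents stale state from being reused by
        # later requests that happen to land in the same cache slot.
        for request_id in finished_requests_ids:
            self._state.pop(request_id, None)
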
4 changes: 3 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -224,6 +224,8 @@ async def step_async(
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()

if not scheduler_outputs.is_empty():
# Execute the model.
@@ -235,7 +237,7 @@ virtual_engine=virtual_engine,
virtual_engine=virtual_engine,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
finished_requests_ids=finished_requests_ids)
output = await self.model_executor.execute_model_async(
execute_model_req)
else:
4 changes: 3 additions & 1 deletion vllm/engine/llm_engine.py
@@ -846,6 +846,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"as performance will be severely degraded otherwise.")
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()
finished_requests_ids = self.scheduler[
0].get_and_reset_finished_requests_ids()

if not scheduler_outputs.is_empty():
execute_model_req = ExecuteModelRequest(
@@ -855,7 +857,7 @@
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
finished_requests_ids=finished_requests_ids)
output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
else:
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -63,6 +63,7 @@
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM")
}

_EMBEDDING_MODELS = {
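The registry entry above maps the architecture name found in a checkpoint's config to a (module, class) pair, so jamba.py is only imported when a Jamba checkpoint is actually loaded. A simplified sketch of that lazy-import pattern (not the actual registry implementation):

import importlib

_MODELS = {"JambaForCausalLM": ("jamba", "JambaForCausalLM")}

def load_model_cls(architecture: str):
    # Resolve the architecture name to a module/class pair and import it
    # on demand, mirroring the table in vllm/model_executor/models/__init__.py.
    module_name, class_name = _MODELS[architecture]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)
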
(Diff for the remaining changed files not shown.)
