Skip to content

Commit

Permalink
Fix attn metadata broadcast (#16)
Browse files Browse the repository at this point in the history
When distributed KV cache is enabled, a chunk that is fully hit in the KV cache
skips the subsequent compute step to optimize performance.

However, skipping the compute step results in the attn metadata being set to None
on the driver side, as the attn metadata is not generated.

In tensor parallelism, the missing of attn metadata on the driver side causes
failures in building attn metadata on non-driver workers, leading to runtime errors.

To address this issue, this PR introduces a signal that allows non-driver workers
to skip building attn metadata.

Signed-off-by: Haiyang Shi <haiyang.shi@bytedance.com>
  • Loading branch information
DwyaneShi authored Dec 4, 2024
1 parent c5fa095 commit 0beb531
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions vllm/worker/model_runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def _add_attn_metadata_broadcastable_dict(
"""
if attn_metadata is not None:
tensor_dict.update(attn_metadata.asdict_zerocopy())
else:
# skip building attn_metadata
tensor_dict.update({"skip_attn_metadata": True})


def _init_attn_metadata_from_tensor_dict(
Expand All @@ -36,15 +39,19 @@ def _init_attn_metadata_from_tensor_dict(
Helper method to initialize AttentionMetadata based on an
AttentionBackend and broadcastable AttentionMetadata fields.
"""
# Extract the fields used to create AttentionMetadata.
valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
val = tensor_dict.pop(field.name, None)
if val is not None:
valid_attn_kwargs[field.name] = val

attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
tensor_dict["attn_metadata"] = attn_metadata
skip_attn_metadata = tensor_dict.pop("skip_attn_metadata", False)
if skip_attn_metadata:
tensor_dict["attn_metadata"] = None
else:
# Extract the fields used to create AttentionMetadata.
valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
val = tensor_dict.pop(field.name, None)
if val is not None:
valid_attn_kwargs[field.name] = val

attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
tensor_dict["attn_metadata"] = attn_metadata
return tensor_dict


Expand Down

0 comments on commit 0beb531

Please sign in to comment.