diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index f8fd9d801d289..326a499415bc7 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -26,6 +26,9 @@ def _add_attn_metadata_broadcastable_dict( """ if attn_metadata is not None: tensor_dict.update(attn_metadata.asdict_zerocopy()) + else: + # skip building attn_metadata + tensor_dict.update({"skip_attn_metadata": True}) def _init_attn_metadata_from_tensor_dict( @@ -36,15 +39,19 @@ def _init_attn_metadata_from_tensor_dict( Helper method to initialize AttentionMetadata based on an AttentionBackend and broadcastable AttentionMetadata fields. """ - # Extract the fields used to create AttentionMetadata. - valid_attn_kwargs = {} - for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = tensor_dict.pop(field.name, None) - if val is not None: - valid_attn_kwargs[field.name] = val - - attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) - tensor_dict["attn_metadata"] = attn_metadata + skip_attn_metadata = tensor_dict.pop("skip_attn_metadata", False) + if skip_attn_metadata: + tensor_dict["attn_metadata"] = None + else: + # Extract the fields used to create AttentionMetadata. + valid_attn_kwargs = {} + for field in dataclasses.fields(attn_backend.get_metadata_cls()): + val = tensor_dict.pop(field.name, None) + if val is not None: + valid_attn_kwargs[field.name] = val + + attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) + tensor_dict["attn_metadata"] = attn_metadata return tensor_dict