
Commit

format
afeldman-nm committed Jul 15, 2024
1 parent 3d5bb88 commit db5539a
Showing 2 changed files with 26 additions and 62 deletions.
9 changes: 6 additions & 3 deletions tests/worker/test_encoder_decoder_model_runner.py
@@ -172,10 +172,12 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch):
model_runner._prepare_encoder_model_input_tensors(
seq_group_metadata_list, decoder_only_model_input))
encoder_input_tokens = encoder_decoder_model_input.encoder_input_tokens
encoder_input_positions = encoder_decoder_model_input.encoder_input_positions
encoder_input_positions = (
encoder_decoder_model_input.encoder_input_positions)
attn_metadata = encoder_decoder_model_input.attn_metadata
cross_slot_mapping = attn_metadata.cross_slot_mapping
return_encoder_seq_lens = encoder_decoder_model_input.attn_metadata.encoder_seq_lens
return_encoder_seq_lens = (
encoder_decoder_model_input.attn_metadata.encoder_seq_lens)
assert return_encoder_seq_lens == encoder_seq_lens
assert len(cross_slot_mapping) == len(encoder_input_tokens)

@@ -350,7 +352,8 @@ def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch):
model_runner._prepare_encoder_model_input_tensors(
seq_group_metadata_list, decoder_only_model_input))
encoder_input_tokens = encoder_decoder_model_input.encoder_input_tokens
encoder_input_positions = encoder_decoder_model_input.encoder_input_positions
encoder_input_positions = (
encoder_decoder_model_input.encoder_input_positions)
attn_metadata = encoder_decoder_model_input.attn_metadata
return_encoder_seq_lens = attn_metadata.encoder_seq_lens
cross_slot_mapping = attn_metadata.cross_slot_mapping
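For orientation, the test changes above only re-wrap long assignments around the encoder input preparation call. A minimal sketch of how the test consumes the prepared encoder inputs, assuming `model_runner`, `seq_group_metadata_list`, and `decoder_only_model_input` are constructed as in the surrounding test file (not shown in this hunk):

    # Sketch only; setup objects come from the surrounding test fixtures.
    encoder_decoder_model_input = (
        model_runner._prepare_encoder_model_input_tensors(
            seq_group_metadata_list, decoder_only_model_input))

    encoder_input_tokens = encoder_decoder_model_input.encoder_input_tokens
    encoder_input_positions = (
        encoder_decoder_model_input.encoder_input_positions)
    attn_metadata = encoder_decoder_model_input.attn_metadata

    # Every encoder token should map to exactly one cross-attention KV slot.
    assert len(attn_metadata.cross_slot_mapping) == len(encoder_input_tokens)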
79 changes: 20 additions & 59 deletions vllm/worker/enc_dec_model_runner.py
@@ -1,38 +1,19 @@
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, cast

import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner import (
GPUModelRunnerBase,
ModelInputForGPU,
ModelInputForGPUWithSamplingMetadata,
LORA_WARMUP_RANK,
_BATCH_SIZES_TO_CAPTURE,
_PAD_SLOT_ID,
)
from vllm.distributed import get_pp_group
from vllm.sequence import (IntermediateTensors, SamplerOutput,
from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
SequenceGroupMetadata)

import dataclasses
import gc
import time
import warnings
from collections import defaultdict
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Type, TypeVar, Union)

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _PAD_SLOT_ID,
LORA_WARMUP_RANK, GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)

try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
@@ -45,39 +26,17 @@
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.model_executor.models.interfaces import supports_vision
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad)
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
_add_sampling_metadata_broadcastable_dict)

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@@ -632,9 +591,9 @@ def _prepare_encoder_model_input_tensors(
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)

# Prepare input tensors for flashinfer
if self.attn_backend.get_name() == "flashinfer":
assert False
# # Prepare input tensors for flashinfer
# if self.attn_backend.get_name() == "flashinfer":
# assert False

batch_size = len(input_tokens)
max_query_len = max(query_lens)
@@ -649,7 +608,8 @@ def _prepare_encoder_model_input_tensors(
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_seq_len <= self.max_seq_len_to_capture)
if use_captured_graph:
assert False
raise NotImplementedError("CUDAGraph is currently not supported "
"for encoder/decoder models.")

max_block_table_len = max(
len(block_table) for block_table in block_tables)
@@ -663,9 +623,9 @@ def _prepare_encoder_model_input_tensors(
assert (not is_prompt) or max_query_len > 0, (
"Decode-phase query_lens: {}".format(query_lens))

context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
# context_lens_tensor = torch.tensor(context_lens,
# dtype=torch.int,
# device=self.device)

seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
@@ -690,6 +650,7 @@ def _prepare_encoder_model_input_tensors(
out=query_start_loc[1:])

attn_metadata = model_input.attn_metadata
assert attn_metadata is not None

slot_mapping_tensor = torch.tensor(slot_mapping,
dtype=torch.long,
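The hunks above touch the cross-attention slot mapping, which flattens each encoder token position into a slot index in the paged KV cache via slot = block_number * block_size + block_offset. A minimal standalone sketch of that arithmetic; the block size, block table, and sequence length below are illustrative values, not taken from the commit:

    # Illustrative values; in the runner these come from the cache config
    # and the sequence group's cross-attention block table.
    block_size = 16
    block_table = [3, 7]      # physical block numbers for one sequence
    encoder_seq_len = 20

    slot_mapping = []
    for i in range(encoder_seq_len):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        # Same formula as in _prepare_encoder_model_input_tensors:
        # slot = block_number * self.block_size + block_offset
        slot_mapping.append(block_number * block_size + block_offset)

    # First token lands in block 3 at slot 48; token 16 starts block 7 at slot 112.
    assert slot_mapping[0] == 48 and slot_mapping[16] == 112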
