From 658a51ef1ab8b7060971ad3e7618305fa83804dd Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 28 Jun 2024 12:48:12 -0700 Subject: [PATCH 1/7] wip --- vllm/worker/model_runner.py | 866 +++++++++++++++++-------------- vllm/worker/model_runner_base.py | 15 + 2 files changed, 484 insertions(+), 397 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 08216603023d7..0f7f03a0d2b24 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -30,7 +30,7 @@ from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, @@ -130,11 +130,468 @@ def from_broadcasted_tensor_dict( return cls(**tensor_dict) +class ModelInputForGPUBuilder( + ModelRunnerInputBuilderBase[ModelInputForGPUWithSamplingMetadata]): + """TBA""" + _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( + ModelInputForGPUWithSamplingMetadata) + + def __init__(self, attn_backend: "AttentionBackend", + scheduler_config: SchedulerConfig, + sliding_window: Optional[int], block_size: int, + enable_lora: bool, multi_modal_input_mapper): + super().__init__() + self.attn_backend = attn_backend + self.scheduler_config = scheduler_config + self.sliding_window = sliding_window + self.block_size = block_size + self.enable_lora = enable_lora + self.multi_modal_input_mapper = multi_modal_input_mapper + self.decode_only = True + + self.chunked_prefill_enabled = ( + self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled) + if self.sliding_window is not None: + self.sliding_window_blocks = ( + self.sliding_window + self.block_size - 1) // self.block_size + self.block_aligned_sliding_window = \ + self.sliding_window_blocks * self.block_size + + # Common inputs. + self.input_tokens: List[int] = [] + self.input_positions: List[int] = [] + + # LoRA inputs. + self.lora_index_mapping: List[int] = [] + self.lora_prompt_mapping: List[int] = [] + self.lora_requests: Set[LoRARequest] = set() + + # Multi-modal inputs. + self.multi_modal_kwargs_list: Dict[ + str, List[torch.Tensor]] = defaultdict(list) + + # Attention metadata inputs. + self.slot_mapping: List[int] = [] + self.seq_lens: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.decode_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.query_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + # The following fields are only for flashinfer + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + self.paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + self.paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + self.paged_kv_last_page_len: List[int] = [] + + def _compute_slot_mapping(self, seq_len, context_len, start_idx, + block_table): + """TODO: Move to attention metadata builder.""" + if block_table is None: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + self.slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + return + + # Mask the [0, start_idx) tokens of the prompt with + # _PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + if start_idx > context_len: + self.slot_mapping.extend([_PAD_SLOT_ID] * + (start_idx - context_len)) + for i in range(start_idx, seq_len): + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + self.slot_mapping.append(slot) + + def _add_prompt_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + seq_ids: List[int]): + self.decode_only = False + computed_block_nums = seq_group_metadata.computed_block_nums + + # Check if hit prefix cache (i.e., some blocks are already computed) + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None) + if self.chunked_prefill_enabled and prefix_cache_hit: + raise RuntimeError( + "chunked prefill cannot be used with prefix caching now.") + + # TODO(comaniac): Add a proper comment. + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + seq_data = seq_group_metadata.seq_data[seq_id] + + context_len = seq_data.get_num_computed_tokens() + seq_len = min(seq_data.get_len(), + context_len + seq_group_metadata.token_chunk_size) + tokens = seq_data.get_token_ids()[context_len:seq_len] + + # Uodate context_len and tokens if prefix cache hit. + if prefix_cache_hit: + assert computed_block_nums is not None + assert self.sliding_window is None + context_len = len(computed_block_nums) * self.block_size + tokens = tokens[context_len:] + + self.input_tokens.extend(tokens) + self.input_positions.extend(list(range(context_len, seq_len))) + + ### Attention metadata. TODO: Move to attention metadata builder. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + assert self.sliding_window is None + + if self.attn_backend.get_name() == "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. + block_table = seq_group_metadata.block_tables[seq_id] + else: + block_table = computed_block_nums + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + + self.block_tables.append(block_table) + self.seq_lens.append(seq_len) + self.context_lens.append(context_len) + query_len = seq_len - context_len + self.query_lens.append(query_len) + + assert len(seq_ids) == 1 + self.num_prefills += 1 + self.num_prefill_tokens += len(tokens) + self.prefill_seq_lens.append(seq_len) + + # Compute the slot mapping. + block_table = None + if not _is_block_tables_empty(seq_group_metadata.block_tables): + block_table = seq_group_metadata.block_tables[seq_id] + + start_idx = 0 + if self.sliding_window is not None: + assert self.scheduler_config.use_v2_block_manager \ + or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + self._compute_slot_mapping(seq_len, context_len, start_idx, + block_table) + + def _add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + seq_ids: List[int]): + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + + ### Prepare context length, sequence length and tokens. + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_data.get_len() - 1 + seq_len = min(seq_data.get_len(), + context_len + seq_group_metadata.token_chunk_size) + # Avoid using .get_token_ids because it copies all tokens. + tokens = [seq_data.get_last_token_id()] + + # These are seq_len/context_len capped to the sliding window. + # They are passed to decode kernel. + # We still need original seq_len/context_len to compute slot + # mapping (and input position) below. + curr_sliding_window_blocks = None + sliding_seq_len = seq_len + sliding_context_len = context_len + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if self.sliding_window is not None: + curr_sliding_window_blocks = self.sliding_window_blocks + if self.scheduler_config.use_v2_block_manager: + # number of elements in last block + suff_len = seq_len % self.block_size + sliding_seq_len = min( + seq_len, self.block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_blocks += 1 + else: + sliding_seq_len = min(seq_len, self.sliding_window) + sliding_context_len = sliding_seq_len - 1 + + self.input_tokens.extend(tokens) + self.input_positions.extend(list(range(context_len, seq_len))) + + ### Attention metadata. TODO: Move to attention metadata builder. + if seq_group_metadata.block_tables is not None: + # chunked prefill or decode + block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + block_table = block_table[-curr_sliding_window_blocks:] + if self.attn_backend.get_name() == "flashinfer": + self.paged_kv_indices.extend(block_table) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + len(block_table)) + last_page_len = seq_data.get_len() % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + else: + # Only happens when memory profiling runs. + block_table = [] + + self.block_tables.append(block_table) + self.seq_lens.append(sliding_seq_len) + self.context_lens.append(sliding_context_len) + query_len = sliding_seq_len - sliding_context_len + self.query_lens.append(query_len) + + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.decode_seq_lens.append(sliding_seq_len) + + # Compute the slot mapping. + block_table = None + if not _is_block_tables_empty(seq_group_metadata.block_tables): + block_table = seq_group_metadata.block_tables[seq_id] + self._compute_slot_mapping(seq_len, context_len, 0, block_table) + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + seq_ids = list(seq_group_metadata.seq_data.keys()) + n_seq = len(seq_ids) + if seq_group_metadata.is_prompt: + self._add_prompt_seq_group(seq_group_metadata, seq_ids) + else: + self._add_decode_seq_group(seq_group_metadata, seq_ids) + query_lens = self.query_lens[-n_seq:] + + if self.enable_lora: + lora_id = seq_group_metadata.lora_int_id + if lora_id > 0: + self.lora_requests.add(seq_group_metadata.lora_request) + + for query_len in query_lens: + if self.enable_lora: + self.lora_index_mapping += [lora_id] * query_len + self.lora_prompt_mapping.extend( + [lora_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + is not None else 1)) + + mm_data = seq_group_metadata.multi_modal_data + if mm_data is not None: + # Process multi-modal data + mm_kwargs = self.multi_modal_input_mapper(mm_data) + for k, v in mm_kwargs.items(): + self.multi_modal_kwargs_list[k].append(v) + + def build(self, model_config: ModelConfig, parallel_config: ParallelConfig, + kv_cache_dtype: Optional[str], max_seq_len_to_capture: int, + graph_block_tables: np.ndarray, + device: torch.device) -> ModelInputForGPUWithSamplingMetadata: + + if not self.input_tokens: + return self._model_input_cls() + + #### Attention metadata + batch_size = len(self.input_tokens) + max_query_len = max(self.query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.decode_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. + use_captured_graph = (self.decode_only + and not model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= max_seq_len_to_capture) + if use_captured_graph: + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + pad_size = graph_batch_size - batch_size + self.input_tokens.extend([0] * pad_size) + self.input_positions.extend([0] * pad_size) + self.slot_mapping.extend([_PAD_SLOT_ID] * pad_size) + self.seq_lens.extend([1] * pad_size) + self.block_tables.extend([] * pad_size) + self.lora_index_mapping.extend([0] * pad_size) + + batch_size = graph_batch_size + num_decode_tokens = batch_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + else: + max_block_table_len = max( + len(block_table) for block_table in self.block_tables) + block_tables = make_tensor_with_pad( + self.block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(self.query_lens)) + + seq_lens_tensor = torch.tensor(self.seq_lens, + dtype=torch.int, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + input_tokens_tensor = torch.tensor(self.input_tokens, + dtype=torch.long, + device=device) + input_positions_tensor = torch.tensor(self.input_positions, + dtype=torch.long, + device=device) + slot_mapping_tensor = torch.tensor(self.slot_mapping, + dtype=torch.long, + device=device) + + if self.attn_backend.get_name() == "flashinfer": + if not hasattr(self, "flashinfer_workspace_buffer"): + # Allocate 16MB workspace buffer + # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html + self.flashinfer_workspace_buffer = torch.empty( + 16 * 1024 * 1024, dtype=torch.uint8, device=device) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + dtype=torch.int, + device=device) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + dtype=torch.int, + device=device) + paged_kv_last_page_len_tensor = torch.tensor( + self.paged_kv_last_page_len, dtype=torch.int, device=device) + kv_cache_dtype = get_kv_cache_torch_dtype(kv_cache_dtype, + model_config.dtype) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + use_cuda_graph=False, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, + workspace_buffer=self.flashinfer_workspace_buffer, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + num_qo_heads=model_config.get_num_attention_heads( + parallel_config), + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_dim=model_config.get_head_size(), + page_size=16, + seq_start_loc=seq_start_loc, + data_type=kv_cache_dtype) + else: + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(self.query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + attn_metadata = self.attn_backend.make_metadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=self.seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + # Others + if self.enable_lora: + lora_mapping = LoRAMapping( + self.lora_index_mapping, + self.lora_prompt_mapping, + ) + else: + lora_mapping = None + + multi_modal_kwargs = { + k: torch.cat(v, dim=0).to(device) + for k, v in self.multi_modal_kwargs_list.items() + } + + return self._model_input_cls( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + attn_metadata=attn_metadata, + seq_lens=self.seq_lens, + query_lens=self.query_lens, + lora_mapping=lora_mapping, + lora_requests=self.lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + ) + + class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): """ Helper class for shared methods between GPU model runners. """ - _model_input_cls: Type[TModelInputForGPU] def __init__( self, @@ -309,400 +766,17 @@ def _prepare_model_input_tensors( If cuda graph is required, this API automatically pads inputs. """ - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - lora_index_mapping: List[int] = [] - lora_prompt_mapping: List[int] = [] - lora_requests: Set[LoRARequest] = set() - - seq_lens: List[int] = [] - prefill_seq_lens: List[int] = [] - decode_seq_lens: List[int] = [] - context_lens: List[int] = [] - query_lens: List[int] = [] - block_tables: List[List[int]] = [] - multi_modal_kwargs_list: Dict[str, - List[torch.Tensor]] = defaultdict(list) - decode_only = True - num_prefills = 0 - num_prefill_tokens = 0 - num_decode_tokens = 0 - - # The following fields are only for flashinfer - # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout - # for the precise definition of the following fields. - # An example: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - paged_kv_indices: List[int] = [] - # 0 at the beginning of paged_kv_indptr indicates the start of the - # first request’s page indices in the paged_kv_indices list. - paged_kv_indptr: List[int] = [0] - # paged_kv_last_page_len is the length of the last page of each request - paged_kv_last_page_len: List[int] = [] - - if len(seq_group_metadata_list) == 0: - return self._model_input_cls() - - if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window + self.block_size - - 1) // self.block_size - block_aligned_sliding_window = \ - sliding_window_blocks * self.block_size - + builder = ModelInputForGPUBuilder(self.attn_backend, + self.scheduler_config, + self.sliding_window, self.block_size, + self.lora_config is not None, + self.multi_modal_input_mapper) for seq_group_metadata in seq_group_metadata_list: - seq_ids = list(seq_group_metadata.seq_data.keys()) - is_prompt = seq_group_metadata.is_prompt - - for seq_id in seq_ids: - computed_block_nums = seq_group_metadata.computed_block_nums - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") - - seq_data = seq_group_metadata.seq_data[seq_id] - if is_prompt: - context_len = seq_data.get_num_computed_tokens() - else: - # get_num_computed_tokens is incorrect for spec decoding. - # So, we should have a special logic here. - # TODO(sang): Fix it. - context_len = seq_data.get_len() - 1 - - seq_len = min( - seq_data.get_len(), - context_len + seq_group_metadata.token_chunk_size) - if is_prompt: - tokens = seq_data.get_token_ids()[context_len:seq_len] - else: - # Optimization. get_token_ids requires the entire copy of - # tokens. - tokens = [seq_data.get_last_token_id()] - - # Prefix cache was hit. - # Prefix is not supported with sliding_window - prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None - and is_prompt) - - # These are seq_len/context_len capped to the sliding window. - # They are passed to decode kernel. - # We still need original seq_len/context_len to compute slot - # mapping (and input position) below. - curr_sliding_window_blocks = None - sliding_seq_len = seq_len - sliding_context_len = context_len - - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. - if (self.sliding_window is not None and not is_prompt): - curr_sliding_window_blocks = sliding_window_blocks - if self.scheduler_config.use_v2_block_manager: - # number of elements in last block - suff_len = seq_len % self.block_size - sliding_seq_len = min( - seq_len, block_aligned_sliding_window + suff_len) - if suff_len > 0: - curr_sliding_window_blocks += 1 - else: - sliding_seq_len = min(seq_len, self.sliding_window) - sliding_context_len = sliding_seq_len - 1 - - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - if prefix_cache_hit: - assert computed_block_nums is not None - context_len = len(computed_block_nums) * self.block_size - tokens = tokens[context_len:] - - # need to think what to set it to when we have both sliding - # window and prefix caching... - assert self.sliding_window is None, \ - "Prefix caching is not supported with sliding window" - sliding_context_len = context_len - - if self.attn_backend.get_name() == "flash-attn": - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - # TODO(woosuk): This is a temporary fix. We should - # provide a unified interface for different backends. - block_table = seq_group_metadata.block_tables[seq_id] - else: - block_table = computed_block_nums - elif (self.scheduler_config.chunked_prefill_enabled - or not is_prompt): - if seq_group_metadata.block_tables is not None: - # chunked prefill or decode - block_table = seq_group_metadata.block_tables[seq_id] - if curr_sliding_window_blocks is not None: - block_table = block_table[ - -curr_sliding_window_blocks:] - if self.attn_backend.get_name() == "flashinfer": - paged_kv_indices.extend(block_table) - paged_kv_indptr.append(paged_kv_indptr[-1] + - len(block_table)) - last_page_len = seq_data.get_len( - ) % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - paged_kv_last_page_len.append(last_page_len) - else: - # Only happens when memory profiling runs. - block_table = [] - else: - # Prefill without chunked prefill or memory profiling. - block_table = [] - block_tables.append(block_table) - - seq_lens.append(sliding_seq_len) - context_lens.append(sliding_context_len) - query_len = sliding_seq_len - sliding_context_len - query_lens.append(query_len) - input_tokens.extend(tokens) - input_positions.extend(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id - - if is_prompt: - assert len(seq_ids) == 1 - num_prefills += 1 - num_prefill_tokens += len(tokens) - decode_only = False - prefill_seq_lens.append(seq_len) - else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) - num_decode_tokens += query_len - decode_seq_lens.append(sliding_seq_len) - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * query_len - lora_prompt_mapping.extend( - [lora_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - is not None else 1)) - - mm_data = seq_group_metadata.multi_modal_data - if mm_data is not None: - # Process multi-modal data - mm_kwargs = self.multi_modal_input_mapper(mm_data) - for k, v in mm_kwargs.items(): - multi_modal_kwargs_list[k].append(v) - - if _is_block_tables_empty(seq_group_metadata.block_tables): - # During memory profiling, the block tables are not - # initialized yet. In this case, we just use a dummy - # slot mapping. - # In embeddings, the block tables are {seq_id: None}. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - - # Mask the [0, start_idx) tokens of the prompt with - # _PAD_SLOT_ID, where start_idx is max(0, seq_len - - # sliding_window). For example, if the prompt len is 10, - # sliding window is 8, and block size is 4, the first two - # tokens are masked and the slot mapping will be - # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - if is_prompt: - assert self.scheduler_config.use_v2_block_manager \ - or context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention in V1 block manager") - # It is an optimization. When it is decoding, it is always - # 0. When prefill, we use it to not write slots to kv cache - # to save memory. - start_idx = max(0, query_len - self.sliding_window) - - for i in range(context_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - batch_size = len(input_tokens) - max_query_len = max(query_lens) - max_prefill_seq_len = max(prefill_seq_lens, default=0) - max_decode_seq_len = max(decode_seq_lens, default=0) - - # If cuda graph can be used, pad tensors accordingly. - # See `capture_model` API for more details. - # vLLM uses cuda graph only for decoding requests. - use_captured_graph = ( - decode_only and not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_decode_seq_len <= self.max_seq_len_to_capture) - if use_captured_graph: - graph_batch_size = _get_graph_batch_size(batch_size) - assert graph_batch_size >= batch_size - for _ in range(graph_batch_size - batch_size): - input_tokens.append(0) - input_positions.append(0) - slot_mapping.append(_PAD_SLOT_ID) - seq_lens.append(1) - block_tables.append([]) - lora_index_mapping.append(0) - batch_size = graph_batch_size - num_decode_tokens = batch_size - - if use_captured_graph: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - input_block_tables = self.graph_block_tables[:batch_size] - for i, block_table in enumerate(block_tables): - if block_table: - input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=self.device) - else: - max_block_table_len = max( - len(block_table) for block_table in block_tables) - block_tables = make_tensor_with_pad( - block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions_tensor = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping_tensor = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - if self.attn_backend.get_name() == "flashinfer": - if not hasattr(self, "flashinfer_workspace_buffer"): - # Allocate 16MB workspace buffer - # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html - self.flashinfer_workspace_buffer = torch.empty( - 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) - paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, - dtype=torch.int, - device=self.device) - paged_kv_indices_tensor = torch.tensor(paged_kv_indices, - dtype=torch.int, - device=self.device) - paged_kv_last_page_len_tensor = torch.tensor( - paged_kv_last_page_len, dtype=torch.int, device=self.device) - kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, - self.model_config.dtype) - attn_metadata = self.attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - use_cuda_graph=False, - max_prefill_seq_len=max_prefill_seq_len, - block_tables=block_tables, - workspace_buffer=self.flashinfer_workspace_buffer, - paged_kv_indptr=paged_kv_indptr_tensor, - paged_kv_indices=paged_kv_indices_tensor, - paged_kv_last_page_len=paged_kv_last_page_len_tensor, - num_qo_heads=self.model_config.get_num_attention_heads( - self.parallel_config), - num_kv_heads=self.model_config.get_num_kv_heads( - self.parallel_config), - head_dim=self.model_config.get_head_size(), - page_size=16, - seq_start_loc=seq_start_loc, - data_type=kv_cache_dtype) - else: - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) - - attn_metadata = self.attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) - - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - multi_modal_kwargs = { - k: torch.cat(v, dim=0).to(self.device) - for k, v in multi_modal_kwargs_list.items() - } - - return self._model_input_cls( - input_tokens=input_tokens_tensor, - input_positions=input_positions_tensor, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_mapping=lora_mapping, - lora_requests=lora_requests, - multi_modal_kwargs=multi_modal_kwargs, - ) + builder.add_seq_group(seq_group_metadata) + return builder.build(self.model_config, self.parallel_config, + self.kv_cache_dtype, self.max_seq_len_to_capture, + self.graph_block_tables, + self.device) # type: ignore @torch.inference_mode() def profile_run(self) -> None: @@ -911,8 +985,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): """ GPU model runner with sampling step. """ - _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( - ModelInputForGPUWithSamplingMetadata) def make_model_input_from_broadcasted_tensor_dict( self, diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 959cfc0b9cac5..78175fd7402d3 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -112,6 +112,21 @@ def from_broadcasted_tensor_dict( raise NotImplementedError +class ModelRunnerInputBuilderBase(ABC, Generic[T]): + """A builder to create ModelRunnerInputBase objects. + """ + + @abstractmethod + def add_seq_group(self, seq_group_metadata): + """TBA""" + raise NotImplementedError + + @abstractmethod + def build(self, *args, **kwargs) -> T: + """Build metadata with on-device tensors.""" + raise NotImplementedError + + class ModelRunnerBase(ABC, Generic[T]): """ Model runner interface that abstracts a particular hardware and/or type of From 406c6b72f919b1c94d8421e958789a09ffd0a9cc Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 28 Jun 2024 14:23:50 -0700 Subject: [PATCH 2/7] fix slot_mapping --- vllm/worker/model_runner.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0f7f03a0d2b24..3de210ca12930 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -218,10 +218,9 @@ def _compute_slot_mapping(self, seq_len, context_len, start_idx, # sliding window is 8, and block size is 4, the first two # tokens are masked and the slot mapping will be # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - if start_idx > context_len: - self.slot_mapping.extend([_PAD_SLOT_ID] * - (start_idx - context_len)) - for i in range(start_idx, seq_len): + self.slot_mapping.extend([_PAD_SLOT_ID] * + max(0, start_idx - context_len)) + for i in range(max(start_idx, context_len), seq_len): block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset From a6f66ab0f1b8f771b20e6511bf1ff37042e048a2 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 28 Jun 2024 14:50:05 -0700 Subject: [PATCH 3/7] isolate moreattention --- vllm/worker/model_runner.py | 71 +++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3de210ca12930..35a0c2dad988f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -424,34 +424,58 @@ def build(self, model_config: ModelConfig, parallel_config: ParallelConfig, if not self.input_tokens: return self._model_input_cls() - #### Attention metadata batch_size = len(self.input_tokens) max_query_len = max(self.query_lens) max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.decode_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens - - # If cuda graph can be used, pad tensors accordingly. - # See `capture_model` API for more details. - # vLLM uses cuda graph only for decoding requests. use_captured_graph = (self.decode_only and not model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_decode_seq_len <= max_seq_len_to_capture) + + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. + cuda_graph_pad_size = 0 if use_captured_graph: graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size - pad_size = graph_batch_size - batch_size - self.input_tokens.extend([0] * pad_size) - self.input_positions.extend([0] * pad_size) - self.slot_mapping.extend([_PAD_SLOT_ID] * pad_size) - self.seq_lens.extend([1] * pad_size) - self.block_tables.extend([] * pad_size) - self.lora_index_mapping.extend([0] * pad_size) - + cuda_graph_pad_size = graph_batch_size - batch_size batch_size = graph_batch_size num_decode_tokens = batch_size + #### Tokens and positions. + self.input_tokens.extend([0] * cuda_graph_pad_size) + self.input_positions.extend([0] * cuda_graph_pad_size) + input_tokens_tensor = torch.tensor(self.input_tokens, + dtype=torch.long, + device=device) + input_positions_tensor = torch.tensor(self.input_positions, + dtype=torch.long, + device=device) + + #### LoRA and multi-modal data. + if self.enable_lora: + self.lora_index_mapping.extend([0] * cuda_graph_pad_size) + lora_mapping = LoRAMapping( + self.lora_index_mapping, + self.lora_prompt_mapping, + ) + else: + lora_mapping = None + + multi_modal_kwargs = { + k: torch.cat(v, dim=0).to(device) + for k, v in self.multi_modal_kwargs_list.items() + } + + #### Attention metadata. TODO: Move to attention metadata builder. + if use_captured_graph: + self.slot_mapping.extend([_PAD_SLOT_ID] * cuda_graph_pad_size) + self.seq_lens.extend([1] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = graph_block_tables[:batch_size] @@ -477,18 +501,11 @@ def build(self, model_config: ModelConfig, parallel_config: ParallelConfig, seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) - torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - input_tokens_tensor = torch.tensor(self.input_tokens, - dtype=torch.long, - device=device) - input_positions_tensor = torch.tensor(self.input_positions, - dtype=torch.long, - device=device) slot_mapping_tensor = torch.tensor(self.slot_mapping, dtype=torch.long, device=device) @@ -561,20 +578,6 @@ def build(self, model_config: ModelConfig, parallel_config: ParallelConfig, use_cuda_graph=use_captured_graph, ) - # Others - if self.enable_lora: - lora_mapping = LoRAMapping( - self.lora_index_mapping, - self.lora_prompt_mapping, - ) - else: - lora_mapping = None - - multi_modal_kwargs = { - k: torch.cat(v, dim=0).to(device) - for k, v in self.multi_modal_kwargs_list.items() - } - return self._model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, From 1db982ff07f4d51ff4cea17f655ee29014ee66e5 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 28 Jun 2024 16:43:35 -0700 Subject: [PATCH 4/7] a --- vllm/worker/model_runner.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1dc098d67d662..2fece29b1e11d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -237,7 +237,10 @@ def _compute_slot_mapping(self, seq_len, context_len, start_idx, slot = block_number * self.block_size + block_offset self.slot_mapping.append(slot) - def _add_seq_group_for_flashinfer(self, seq_data, block_table): + def _add_seq_group_for_flashinfer(self, seq_data, block_table): + if block_table is None: + return + seq_len = seq_data.get_len() # Get the number of valid blocks based on sequence length. # If seq_len = 16, block_size = 16, @@ -249,7 +252,8 @@ def _add_seq_group_for_flashinfer(self, seq_data, block_table): else seq_len // self.block_size self.paged_kv_indices.extend(block_table[:block_table_bound]) - self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + block_table_bound) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) last_page_len = seq_len % self.block_size if last_page_len == 0: @@ -322,7 +326,8 @@ def _add_prompt_seq_group(self, seq_group_metadata: SequenceGroupMetadata, # Compute the block table for slot mapping and flashinfer. block_table = None - is_profile_run = _is_block_tables_empty(seq_group_metadata.block_tables) + is_profile_run = _is_block_tables_empty( + seq_group_metadata.block_tables) if not is_profile_run: block_table = seq_group_metadata.block_tables[seq_id] @@ -341,7 +346,6 @@ def _add_prompt_seq_group(self, seq_group_metadata: SequenceGroupMetadata, if self.attn_backend.get_name() == "flashinfer": self._add_seq_group_for_flashinfer(seq_data, block_table) - def _add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata, seq_ids: List[int]): for seq_id in seq_ids: @@ -408,7 +412,8 @@ def _add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata, # Compute the slot mapping. block_table = None - is_profile_run = _is_block_tables_empty(seq_group_metadata.block_tables) + is_profile_run = _is_block_tables_empty( + seq_group_metadata.block_tables) if not is_profile_run: block_table = seq_group_metadata.block_tables[seq_id] self._compute_slot_mapping(seq_len, context_len, 0, block_table) @@ -515,7 +520,8 @@ def build(self, model_config: ModelConfig, parallel_config: ParallelConfig, if self.attn_backend.get_name() == "flashinfer": last_paged_kv_indptr = self.paged_kv_indptr[-1] - self.paged_kv_indptr.extend([last_paged_kv_indptr] * cuda_graph_pad_size) + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) else: max_block_table_len = max( @@ -531,19 +537,19 @@ def build(self, model_config: ModelConfig, parallel_config: ParallelConfig, context_lens_tensor = torch.tensor(self.context_lens, dtype=torch.int, - device=self.device) + device=device) seq_lens_tensor = torch.tensor(self.seq_lens, dtype=torch.int, - device=self.device) + device=device) query_lens_tensor = torch.tensor(self.query_lens, dtype=torch.long, - device=self.device) + device=device) query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, - device=self.device) + device=device) seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, - device=self.device) + device=device) torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, From f7e0e48c1f69354ec41d911f058a2dc24b5f2f93 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 1 Jul 2024 10:19:29 -0700 Subject: [PATCH 5/7] fix bug --- vllm/worker/model_runner.py | 56 +++++++++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2fece29b1e11d..f12b0eb07bfda 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -172,6 +172,8 @@ def __init__(self, attn_backend: "AttentionBackend", # Common inputs. self.input_tokens: List[int] = [] self.input_positions: List[int] = [] + self.seq_lens: List[int] = [] + self.query_lens: List[int] = [] # LoRA inputs. self.lora_index_mapping: List[int] = [] @@ -184,11 +186,9 @@ def __init__(self, attn_backend: "AttentionBackend", # Attention metadata inputs. self.slot_mapping: List[int] = [] - self.seq_lens: List[int] = [] self.prefill_seq_lens: List[int] = [] self.decode_seq_lens: List[int] = [] self.context_lens: List[int] = [] - self.query_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.num_prefills = 0 self.num_prefill_tokens = 0 @@ -212,6 +212,28 @@ def __init__(self, attn_backend: "AttentionBackend", # paged_kv_last_page_len is the length of the last page of each request self.paged_kv_last_page_len: List[int] = [] + def _compute_for_sliding_window(self, seq_len, context_len): + curr_sliding_window_blocks = None + sliding_seq_len = seq_len + sliding_context_len = context_len + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if self.sliding_window is not None: + curr_sliding_window_blocks = self.sliding_window_blocks + if self.scheduler_config.use_v2_block_manager: + # number of elements in last block + suff_len = seq_len % self.block_size + sliding_seq_len = min( + seq_len, self.block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_blocks += 1 + else: + sliding_seq_len = min(seq_len, self.sliding_window) + sliding_context_len = sliding_seq_len - 1 + return curr_sliding_window_blocks, sliding_seq_len, sliding_context_len + def _compute_slot_mapping(self, seq_len, context_len, start_idx, block_table): """TODO: Move to attention metadata builder.""" @@ -283,6 +305,9 @@ def _add_prompt_seq_group(self, seq_group_metadata: SequenceGroupMetadata, context_len + seq_group_metadata.token_chunk_size) tokens = seq_data.get_token_ids()[context_len:seq_len] + (curr_sliding_window_blocks, _, + _) = self._compute_for_sliding_window(seq_len, context_len) + # Uodate context_len and tokens if prefix cache hit. if prefix_cache_hit: assert computed_block_nums is not None @@ -292,6 +317,9 @@ def _add_prompt_seq_group(self, seq_group_metadata: SequenceGroupMetadata, self.input_tokens.extend(tokens) self.input_positions.extend(list(range(context_len, seq_len))) + self.seq_lens.append(seq_len) + query_len = seq_len - context_len + self.query_lens.append(query_len) ### Attention metadata. TODO: Move to attention metadata builder. # TODO(sang): Combine chunked prefill and prefix caching by @@ -309,15 +337,17 @@ def _add_prompt_seq_group(self, seq_group_metadata: SequenceGroupMetadata, block_table = seq_group_metadata.block_tables[seq_id] else: block_table = computed_block_nums + elif (self.scheduler_config.chunked_prefill_enabled + and seq_group_metadata.block_tables is not None): + block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + block_table = block_table[-curr_sliding_window_blocks:] else: # Prefill without chunked prefill or memory profiling. block_table = [] self.block_tables.append(block_table) - self.seq_lens.append(seq_len) self.context_lens.append(context_len) - query_len = seq_len - context_len - self.query_lens.append(query_len) assert len(seq_ids) == 1 self.num_prefills += 1 @@ -361,13 +391,9 @@ def _add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata, # Avoid using .get_token_ids because it copies all tokens. tokens = [seq_data.get_last_token_id()] - # These are seq_len/context_len capped to the sliding window. - # They are passed to decode kernel. - # We still need original seq_len/context_len to compute slot - # mapping (and input position) below. - curr_sliding_window_blocks = None - sliding_seq_len = seq_len - sliding_context_len = context_len + (curr_sliding_window_blocks, sliding_seq_len, + sliding_context_len) = self._compute_for_sliding_window( + seq_len, context_len) # TODO(sang): This is a hack to make sliding window work with # paged attn. We can remove it if we make paged attn kernel @@ -387,6 +413,9 @@ def _add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata, self.input_tokens.extend(tokens) self.input_positions.extend(list(range(context_len, seq_len))) + self.seq_lens.append(sliding_seq_len) + query_len = sliding_seq_len - sliding_context_len + self.query_lens.append(query_len) ### Attention metadata. TODO: Move to attention metadata builder. if seq_group_metadata.block_tables is not None: @@ -399,10 +428,7 @@ def _add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata, block_table = [] self.block_tables.append(block_table) - self.seq_lens.append(sliding_seq_len) self.context_lens.append(sliding_context_len) - query_len = sliding_seq_len - sliding_context_len - self.query_lens.append(query_len) assert query_len == 1, ( "seq_len: {}, context_len: {}, query_len: {}".format( From d15fb1c0b646d52791378ffa5211192aff0718d2 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 1 Jul 2024 14:54:12 -0700 Subject: [PATCH 6/7] flash_attn / flashinfer --- vllm/attention/backends/abstract.py | 36 ++ vllm/attention/backends/flash_attn.py | 190 +++++++- vllm/attention/backends/flashinfer.py | 254 ++++++++++- vllm/attention/backends/utils.py | 40 ++ vllm/worker/model_runner.py | 598 +++++++++++++------------- 5 files changed, 820 insertions(+), 298 deletions(-) create mode 100644 vllm/attention/backends/utils.py diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 40768532f59c2..a9b211494d45a 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -3,6 +3,7 @@ from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) +import numpy as np import torch @@ -28,6 +29,16 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": return cls.get_metadata_cls()(*args, **kwargs) + @staticmethod + @abstractmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + raise NotImplementedError + + @classmethod + def make_metadata_builder(cls, *args, + **kwargs) -> "AttentionMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) + @staticmethod @abstractmethod def get_kv_cache_shape( @@ -103,6 +114,31 @@ def asdict_zerocopy(self, T = TypeVar("T", bound=AttentionMetadata) +class AttentionMetadataBuilder(ABC, Generic[T]): + """Abstract class for attention metadata builders.""" + + @abstractmethod + def __init__(self, block_size: int, sliding_window: int, + use_v2_block_manager: bool) -> None: + raise NotImplementedError + + @abstractmethod + def add_prefill_seq_group(self, *args, **kwargs) -> None: + raise NotImplementedError + + @abstractmethod + def add_decode_seq_group(self, *args, **kwargs) -> None: + raise NotImplementedError + + @abstractmethod + def build(self, model_config: Any, parallel_config: Any, + kv_cache_dtype: Any, seq_lens: Any, query_lens: Any, + decode_seq_lens: Any, use_captured_graph: bool, + cuda_graph_pad_size: int, graph_block_tables: np.ndarray, + batch_size: int, device: Any) -> T: + raise NotImplementedError + + class AttentionImpl(ABC, Generic[T]): @abstractmethod diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 8cb5c3101a804..522ce8101becc 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -2,12 +2,18 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type +import numpy as np import torch from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataBuilder) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + is_block_tables_empty) +from vllm.sequence import SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad class FlashAttentionBackend(AttentionBackend): @@ -28,6 +34,10 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashAttentionMetadata + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -184,6 +194,184 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: return self._cached_decode_metadata +class FlashAttentionMetadataBuilder( + AttentionMetadataBuilder[FlashAttentionMetadata]): + + def __init__(self, block_size, sliding_window, use_v2_block_manager): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.sliding_window = sliding_window + self.block_size = block_size + self.use_v2_block_manager = use_v2_block_manager + + def add_prefill_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + tokens: List[int], seq_id: int, seq_len: int, + query_len: int, context_len: int, + prefix_cache_hit, chunked_prefill_enabled, + computed_block_nums, + curr_sliding_window_blocks) -> None: + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + assert self.sliding_window is None + block_table = seq_group_metadata.block_tables[seq_id] + elif (chunked_prefill_enabled + and seq_group_metadata.block_tables is not None): + block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + block_table = block_table[-curr_sliding_window_blocks:] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + + self.block_tables.append(block_table) + self.context_lens.append(context_len) + + self.num_prefills += 1 + self.num_prefill_tokens += len(tokens) + self.prefill_seq_lens.append(seq_len) + + # Compute slot mapping. + block_table = None + is_profile_run = is_block_tables_empty(seq_group_metadata.block_tables) + if not is_profile_run: + block_table = seq_group_metadata.block_tables[seq_id] + + start_idx = 0 + if self.sliding_window is not None: + assert self.use_v2_block_manager \ + or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + compute_slot_mapping(self.slot_mapping, seq_len, context_len, + start_idx, self.block_size, block_table) + + def add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + seq_id, seq_len, query_len, context_len, + curr_sliding_window_blocks, sliding_seq_len, + sliding_context_len): + + # Compute block table. + if seq_group_metadata.block_tables is not None: + block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_table[-curr_sliding_window_blocks:] + else: + # Only happens when memory profiling runs. + block_table = [] + + self.block_tables.append(block_table) + self.context_lens.append(sliding_context_len) + + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + + # Compute the slot mapping. + block_table = None + is_profile_run = is_block_tables_empty(seq_group_metadata.block_tables) + if not is_profile_run: + block_table = seq_group_metadata.block_tables[seq_id] + + compute_slot_mapping(self.slot_mapping, seq_len, context_len, 0, + self.block_size, block_table) + + def build(self, model_config, parallel_config, kv_cache_dtype, seq_lens, + query_lens, decode_seq_lens, use_captured_graph: bool, + cuda_graph_pad_size: int, graph_block_tables: np.ndarray, + batch_size: int, device): + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(decode_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + cuda_graph_pad_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + else: + max_block_table_len = max( + len(block_table) for block_table in self.block_tables) + block_tables = make_tensor_with_pad( + self.block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + slot_mapping_tensor = torch.tensor(self.slot_mapping, + dtype=torch.long, + device=device) + + return FlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + class FlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 4ecac7379c7f6..4540b3c04c1ae 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Type +import numpy as np + try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper @@ -14,7 +16,12 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataBuilder) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + is_block_tables_empty) +from vllm.sequence import SequenceGroupMetadata +from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad class FlashInferBackend(AttentionBackend): @@ -31,6 +38,10 @@ def get_impl_cls() -> Type["FlashInferImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashInferMetadata + @staticmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + return FlashInferMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -126,6 +137,7 @@ def begin_forward(self): self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.device) + self.prefill_wrapper.end_forward() self.prefill_wrapper.begin_forward( self.query_start_loc, self.paged_kv_indptr, self.paged_kv_indices, self.paged_kv_last_page_len, @@ -142,6 +154,7 @@ def begin_forward(self): self.device) assert self.decode_wrapper is not None + self.decode_wrapper.end_forward() self.decode_wrapper.begin_forward( self.paged_kv_indptr, self.paged_kv_indices, @@ -184,6 +197,245 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]: return self +class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): + + def __init__(self, block_size, sliding_window, use_v2_block_manager): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.sliding_window = sliding_window + self.block_size = block_size + self.use_v2_block_manager = use_v2_block_manager + + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + self.paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + self.paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + self.paged_kv_last_page_len: List[int] = [] + + def _update_paged_kv_metadata(self, seq_data, block_table): + if block_table is None: + return + + seq_len = seq_data.get_len() + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + + def add_prefill_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + tokens: List[int], seq_id: int, seq_len: int, + query_len: int, context_len: int, + prefix_cache_hit, chunked_prefill_enabled, + computed_block_nums, + curr_sliding_window_blocks) -> None: + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + assert self.sliding_window is None + block_table = computed_block_nums + elif (chunked_prefill_enabled + and seq_group_metadata.block_tables is not None): + block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + block_table = block_table[-curr_sliding_window_blocks:] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + + self.block_tables.append(block_table) + + self.num_prefills += 1 + self.num_prefill_tokens += len(tokens) + self.prefill_seq_lens.append(seq_len) + + # Compute slot mapping. + block_table = None + is_profile_run = is_block_tables_empty(seq_group_metadata.block_tables) + if not is_profile_run: + block_table = seq_group_metadata.block_tables[seq_id] + + start_idx = 0 + if self.sliding_window is not None: + assert self.use_v2_block_manager \ + or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + compute_slot_mapping(self.slot_mapping, seq_len, context_len, + start_idx, self.block_size, block_table) + + # FlashInfer specific + self._update_paged_kv_metadata(seq_group_metadata.seq_data[seq_id], + block_table) + + def add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + seq_id, seq_len, query_len, context_len, + curr_sliding_window_blocks, sliding_seq_len, + sliding_context_len): + seq_data = seq_group_metadata.seq_data[seq_id] + + if seq_group_metadata.block_tables is not None: + # chunked prefill or decode + block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + block_table = block_table[-curr_sliding_window_blocks:] + else: + # Only happens when memory profiling runs. + block_table = [] + + self.block_tables.append(block_table) + + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + + # Compute the slot mapping. + block_table = None + is_profile_run = is_block_tables_empty(seq_group_metadata.block_tables) + if not is_profile_run: + block_table = seq_group_metadata.block_tables[seq_id] + + compute_slot_mapping(self.slot_mapping, seq_len, context_len, 0, + self.block_size, block_table) + + self._update_paged_kv_metadata(seq_data, block_table) + + def build(self, model_config, parallel_config, kv_cache_dtype, seq_lens, + query_lens, decode_seq_lens, use_captured_graph: bool, + cuda_graph_pad_size: int, graph_block_tables: np.ndarray, + batch_size: int, device): + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + cuda_graph_pad_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) + else: + max_block_table_len = max( + len(block_table) for block_table in self.block_tables) + block_tables = make_tensor_with_pad( + self.block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + slot_mapping_tensor = torch.tensor(self.slot_mapping, + dtype=torch.long, + device=device) + + if len(self.paged_kv_indptr) > 0: + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device="cpu", + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device="cpu", + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + + kv_cache_dtype = get_kv_cache_torch_dtype(kv_cache_dtype, + model_config.dtype) + + return FlashInferMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + num_qo_heads=model_config.get_num_attention_heads(parallel_config), + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_dim=model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc, + device=device, + data_type=kv_cache_dtype, + use_cuda_graph=use_captured_graph) + + class FlashInferImpl(AttentionImpl): def __init__( diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py new file mode 100644 index 0000000000000..b541792c0a803 --- /dev/null +++ b/vllm/attention/backends/utils.py @@ -0,0 +1,40 @@ +from typing import Dict, Union + +PAD_SLOT_ID = -1 + + +def is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. + """ + if block_tables is None: + return True + if isinstance(block_tables, dict) and all( + value is None for value in block_tables.values()): + return True + return False + + +def compute_slot_mapping(slot_mapping, seq_len, context_len, start_idx, + block_size, block_table): + """TBA.""" + if block_table is None: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([PAD_SLOT_ID] * seq_len) + return + + # Mask the [0, start_idx) tokens of the prompt with + # PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len)) + for i in range(max(start_idx, context_len), seq_len): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f12b0eb07bfda..df07de9a48e3a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -39,7 +39,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, - is_pin_memory_available, make_tensor_with_pad) + is_pin_memory_available) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -160,20 +160,12 @@ def __init__(self, attn_backend: "AttentionBackend", self.multi_modal_input_mapper = multi_modal_input_mapper self.decode_only = True - self.chunked_prefill_enabled = ( - self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled) - if self.sliding_window is not None: - self.sliding_window_blocks = ( - self.sliding_window + self.block_size - 1) // self.block_size - self.block_aligned_sliding_window = \ - self.sliding_window_blocks * self.block_size - # Common inputs. self.input_tokens: List[int] = [] self.input_positions: List[int] = [] self.seq_lens: List[int] = [] self.query_lens: List[int] = [] + self.decode_seq_lens: List[int] = [] # LoRA inputs. self.lora_index_mapping: List[int] = [] @@ -185,32 +177,44 @@ def __init__(self, attn_backend: "AttentionBackend", str, List[torch.Tensor]] = defaultdict(list) # Attention metadata inputs. - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.decode_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - - # The following fields are only for flashinfer - # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout - # for the precise definition of the following fields. - # An example: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - self.paged_kv_indices: List[int] = [] - # 0 at the beginning of paged_kv_indptr indicates the start of the - # first request’s page indices in the paged_kv_indices list. - self.paged_kv_indptr: List[int] = [0] - # paged_kv_last_page_len is the length of the last page of each request - self.paged_kv_last_page_len: List[int] = [] + self.attn_metadata_builder = self.attn_backend.make_metadata_builder( + self.block_size, self.sliding_window, + self.scheduler_config.use_v2_block_manager) + + self.chunked_prefill_enabled = ( + self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled) + if self.sliding_window is not None: + self.sliding_window_blocks = ( + self.sliding_window + self.block_size - 1) // self.block_size + self.block_aligned_sliding_window = \ + self.sliding_window_blocks * self.block_size + + # self.slot_mapping: List[int] = [] + # self.prefill_seq_lens: List[int] = [] + # self.context_lens: List[int] = [] + # self.block_tables: List[List[int]] = [] + # self.num_prefills = 0 + # self.num_prefill_tokens = 0 + # self.num_decode_tokens = 0 + + # # The following fields are only for flashinfer + # # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # # for the precise definition of the following fields. + # # An example: + # # request 1, page indices [0, 5, 8] + # # request 2, page indices [1, 6, 7] + # # request 3, page indices [3, 4] + # # paged_kv_indices is a concatenation of page indices of all requests: + # # [0, 5, 8, 1, 6, 7, 3, 4] + # # paged_kv_indptr is used to index into paged_kv_indices: + # # [0, 3, 6, 8] + # self.paged_kv_indices: List[int] = [] + # # 0 at the beginning of paged_kv_indptr indicates the start of the + # # first request’s page indices in the paged_kv_indices list. + # self.paged_kv_indptr: List[int] = [0] + # # paged_kv_last_page_len is the length of the last page of each request # noqa: E501 + # self.paged_kv_last_page_len: List[int] = [] def _compute_for_sliding_window(self, seq_len, context_len): curr_sliding_window_blocks = None @@ -234,53 +238,53 @@ def _compute_for_sliding_window(self, seq_len, context_len): sliding_context_len = sliding_seq_len - 1 return curr_sliding_window_blocks, sliding_seq_len, sliding_context_len - def _compute_slot_mapping(self, seq_len, context_len, start_idx, - block_table): - """TODO: Move to attention metadata builder.""" - if block_table is None: - # During memory profiling, the block tables are not - # initialized yet. In this case, we just use a dummy - # slot mapping. - # In embeddings, the block tables are {seq_id: None}. - self.slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - return - - # Mask the [0, start_idx) tokens of the prompt with - # _PAD_SLOT_ID, where start_idx is max(0, seq_len - - # sliding_window). For example, if the prompt len is 10, - # sliding window is 8, and block size is 4, the first two - # tokens are masked and the slot mapping will be - # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - self.slot_mapping.extend([_PAD_SLOT_ID] * - max(0, start_idx - context_len)) - for i in range(max(start_idx, context_len), seq_len): - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - self.slot_mapping.append(slot) - - def _add_seq_group_for_flashinfer(self, seq_data, block_table): - if block_table is None: - return - - seq_len = seq_data.get_len() - # Get the number of valid blocks based on sequence length. - # If seq_len = 16, block_size = 16, - # block_table_bound is 1 with 1 valid block. - # If seq_len = 15, block_size = 16, - # block_table_bound is 0 + 1 with 1 valid block. - block_table_bound = seq_len // self.block_size + 1 \ - if seq_len % self.block_size != 0 \ - else seq_len // self.block_size - - self.paged_kv_indices.extend(block_table[:block_table_bound]) - self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + - block_table_bound) - - last_page_len = seq_len % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - self.paged_kv_last_page_len.append(last_page_len) + # def _compute_slot_mapping(self, seq_len, context_len, start_idx, + # block_table): + # """TODO: Move to attention metadata builder.""" + # if block_table is None: + # # During memory profiling, the block tables are not + # # initialized yet. In this case, we just use a dummy + # # slot mapping. + # # In embeddings, the block tables are {seq_id: None}. + # self.slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + # return + + # # Mask the [0, start_idx) tokens of the prompt with + # # _PAD_SLOT_ID, where start_idx is max(0, seq_len - + # # sliding_window). For example, if the prompt len is 10, + # # sliding window is 8, and block size is 4, the first two + # # tokens are masked and the slot mapping will be + # # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + # self.slot_mapping.extend([_PAD_SLOT_ID] * + # max(0, start_idx - context_len)) + # for i in range(max(start_idx, context_len), seq_len): + # block_number = block_table[i // self.block_size] + # block_offset = i % self.block_size + # slot = block_number * self.block_size + block_offset + # self.slot_mapping.append(slot) + + # def _add_seq_group_for_flashinfer(self, seq_data, block_table): + # if block_table is None: + # return + + # seq_len = seq_data.get_len() + # # Get the number of valid blocks based on sequence length. + # # If seq_len = 16, block_size = 16, + # # block_table_bound is 1 with 1 valid block. + # # If seq_len = 15, block_size = 16, + # # block_table_bound is 0 + 1 with 1 valid block. + # block_table_bound = seq_len // self.block_size + 1 \ + # if seq_len % self.block_size != 0 \ + # else seq_len // self.block_size + + # self.paged_kv_indices.extend(block_table[:block_table_bound]) + # self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + # block_table_bound) + + # last_page_len = seq_len % self.block_size + # if last_page_len == 0: + # last_page_len = self.block_size + # self.paged_kv_last_page_len.append(last_page_len) def _add_prompt_seq_group(self, seq_group_metadata: SequenceGroupMetadata, seq_ids: List[int]): @@ -322,59 +326,64 @@ def _add_prompt_seq_group(self, seq_group_metadata: SequenceGroupMetadata, self.query_lens.append(query_len) ### Attention metadata. TODO: Move to attention metadata builder. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - if prefix_cache_hit: - assert computed_block_nums is not None - assert self.sliding_window is None - - if self.attn_backend.get_name() == "flash-attn": - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - # TODO(woosuk): This is a temporary fix. We should - # provide a unified interface for different backends. - block_table = seq_group_metadata.block_tables[seq_id] - else: - block_table = computed_block_nums - elif (self.scheduler_config.chunked_prefill_enabled - and seq_group_metadata.block_tables is not None): - block_table = seq_group_metadata.block_tables[seq_id] - if curr_sliding_window_blocks is not None: - block_table = block_table[-curr_sliding_window_blocks:] - else: - # Prefill without chunked prefill or memory profiling. - block_table = [] - - self.block_tables.append(block_table) - self.context_lens.append(context_len) - - assert len(seq_ids) == 1 - self.num_prefills += 1 - self.num_prefill_tokens += len(tokens) - self.prefill_seq_lens.append(seq_len) - - # Compute the block table for slot mapping and flashinfer. - block_table = None - is_profile_run = _is_block_tables_empty( - seq_group_metadata.block_tables) - if not is_profile_run: - block_table = seq_group_metadata.block_tables[seq_id] - - start_idx = 0 - if self.sliding_window is not None: - assert self.scheduler_config.use_v2_block_manager \ - or context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention in V1 block manager") - # When prefill, we use it to not write slots to kv cache - # to save memory. - start_idx = max(0, query_len - self.sliding_window) - - self._compute_slot_mapping(seq_len, context_len, start_idx, - block_table) - if self.attn_backend.get_name() == "flashinfer": - self._add_seq_group_for_flashinfer(seq_data, block_table) + self.attn_metadata_builder.add_prefill_seq_group( + seq_group_metadata, tokens, seq_id, seq_len, query_len, + context_len, prefix_cache_hit, self.chunked_prefill_enabled, + computed_block_nums, curr_sliding_window_blocks) + + # # TODO(sang): Combine chunked prefill and prefix caching by + # # only allowing multiple of block_size chunk size. + # # NOTE: This only works for oooooooxxx style attention. + # if prefix_cache_hit: + # assert computed_block_nums is not None + # assert self.sliding_window is None + + # if self.attn_backend.get_name() == "flash-attn": + # # NOTE(woosuk): For flash-attn, the block table should + # # include the entries for the incoming prefill tokens. + # # TODO(woosuk): This is a temporary fix. We should + # # provide a unified interface for different backends. + # block_table = seq_group_metadata.block_tables[seq_id] + # else: + # block_table = computed_block_nums + # elif (self.scheduler_config.chunked_prefill_enabled + # and seq_group_metadata.block_tables is not None): + # block_table = seq_group_metadata.block_tables[seq_id] + # if curr_sliding_window_blocks is not None: + # block_table = block_table[-curr_sliding_window_blocks:] + # else: + # # Prefill without chunked prefill or memory profiling. + # block_table = [] + + # self.block_tables.append(block_table) + # self.context_lens.append(context_len) + + # assert len(seq_ids) == 1 + # self.num_prefills += 1 + # self.num_prefill_tokens += len(tokens) + # self.prefill_seq_lens.append(seq_len) + + # # Compute the block table for slot mapping and flashinfer. + # block_table = None + # is_profile_run = _is_block_tables_empty( + # seq_group_metadata.block_tables) + # if not is_profile_run: + # block_table = seq_group_metadata.block_tables[seq_id] + + # start_idx = 0 + # if self.sliding_window is not None: + # assert self.scheduler_config.use_v2_block_manager \ + # or context_len == 0, ( + # "Prefix caching is currently not supported with " + # "sliding window attention in V1 block manager") + # # When prefill, we use it to not write slots to kv cache + # # to save memory. + # start_idx = max(0, query_len - self.sliding_window) + + # self._compute_slot_mapping(seq_len, context_len, start_idx, + # block_table) + # if self.attn_backend.get_name() == "flashinfer": + # self._add_seq_group_for_flashinfer(seq_data, block_table) def _add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata, seq_ids: List[int]): @@ -395,56 +404,45 @@ def _add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata, sliding_context_len) = self._compute_for_sliding_window( seq_len, context_len) - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. - if self.sliding_window is not None: - curr_sliding_window_blocks = self.sliding_window_blocks - if self.scheduler_config.use_v2_block_manager: - # number of elements in last block - suff_len = seq_len % self.block_size - sliding_seq_len = min( - seq_len, self.block_aligned_sliding_window + suff_len) - if suff_len > 0: - curr_sliding_window_blocks += 1 - else: - sliding_seq_len = min(seq_len, self.sliding_window) - sliding_context_len = sliding_seq_len - 1 - self.input_tokens.extend(tokens) self.input_positions.extend(list(range(context_len, seq_len))) self.seq_lens.append(sliding_seq_len) query_len = sliding_seq_len - sliding_context_len self.query_lens.append(query_len) - - ### Attention metadata. TODO: Move to attention metadata builder. - if seq_group_metadata.block_tables is not None: - # chunked prefill or decode - block_table = seq_group_metadata.block_tables[seq_id] - if curr_sliding_window_blocks is not None: - block_table = block_table[-curr_sliding_window_blocks:] - else: - # Only happens when memory profiling runs. - block_table = [] - - self.block_tables.append(block_table) - self.context_lens.append(sliding_context_len) - - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) - self.num_decode_tokens += query_len self.decode_seq_lens.append(sliding_seq_len) - # Compute the slot mapping. - block_table = None - is_profile_run = _is_block_tables_empty( - seq_group_metadata.block_tables) - if not is_profile_run: - block_table = seq_group_metadata.block_tables[seq_id] - self._compute_slot_mapping(seq_len, context_len, 0, block_table) - if self.attn_backend.get_name() == "flashinfer": - self._add_seq_group_for_flashinfer(seq_data, block_table) + ### Attention metadata. TODO: Move to attention metadata builder. + self.attn_metadata_builder.add_decode_seq_group( + seq_group_metadata, seq_id, seq_len, query_len, context_len, + curr_sliding_window_blocks, sliding_seq_len, + sliding_context_len) + + # if seq_group_metadata.block_tables is not None: + # # chunked prefill or decode + # block_table = seq_group_metadata.block_tables[seq_id] + # if curr_sliding_window_blocks is not None: + # block_table = block_table[-curr_sliding_window_blocks:] + # else: + # # Only happens when memory profiling runs. + # block_table = [] + + # self.block_tables.append(block_table) + # self.context_lens.append(sliding_context_len) + + # assert query_len == 1, ( + # "seq_len: {}, context_len: {}, query_len: {}".format( + # seq_len, context_len, query_len)) + # self.num_decode_tokens += query_len + + # # Compute the slot mapping. + # block_table = None + # is_profile_run = _is_block_tables_empty( + # seq_group_metadata.block_tables) + # if not is_profile_run: + # block_table = seq_group_metadata.block_tables[seq_id] + # self._compute_slot_mapping(seq_len, context_len, 0, block_table) + # if self.attn_backend.get_name() == "flashinfer": + # self._add_seq_group_for_flashinfer(seq_data, block_table) def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -485,10 +483,10 @@ def build(self, model_config: ModelConfig, parallel_config: ParallelConfig, return self._model_input_cls() batch_size = len(self.input_tokens) - max_query_len = max(self.query_lens) - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + # max_query_len = max(self.query_lens) + # max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.decode_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens + # num_decode_tokens = self.num_decode_tokens use_captured_graph = (self.decode_only and not model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] @@ -503,7 +501,7 @@ def build(self, model_config: ModelConfig, parallel_config: ParallelConfig, assert graph_batch_size >= batch_size cuda_graph_pad_size = graph_batch_size - batch_size batch_size = graph_batch_size - num_decode_tokens = batch_size + # num_decode_tokens = batch_size #### Tokens and positions. self.input_tokens.extend([0] * cuda_graph_pad_size) @@ -515,6 +513,10 @@ def build(self, model_config: ModelConfig, parallel_config: ParallelConfig, dtype=torch.long, device=device) + #### Sequence and query lenghes. + if use_captured_graph: + self.seq_lens.extend([1] * cuda_graph_pad_size) + #### LoRA and multi-modal data. if self.enable_lora: self.lora_index_mapping.extend([0] * cuda_graph_pad_size) @@ -531,119 +533,123 @@ def build(self, model_config: ModelConfig, parallel_config: ParallelConfig, } #### Attention metadata. TODO: Move to attention metadata builder. - if use_captured_graph: - self.slot_mapping.extend([_PAD_SLOT_ID] * cuda_graph_pad_size) - self.seq_lens.extend([1] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - input_block_tables = graph_block_tables[:batch_size] - for i, block_table in enumerate(self.block_tables): - if block_table: - input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=device) - - if self.attn_backend.get_name() == "flashinfer": - last_paged_kv_indptr = self.paged_kv_indptr[-1] - self.paged_kv_indptr.extend([last_paged_kv_indptr] * - cuda_graph_pad_size) - self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) - else: - max_block_table_len = max( - len(block_table) for block_table in self.block_tables) - block_tables = make_tensor_with_pad( - self.block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(self.query_lens)) - - context_lens_tensor = torch.tensor(self.context_lens, - dtype=torch.int, - device=device) - seq_lens_tensor = torch.tensor(self.seq_lens, - dtype=torch.int, - device=device) - query_lens_tensor = torch.tensor(self.query_lens, - dtype=torch.long, - device=device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) - - slot_mapping_tensor = torch.tensor(self.slot_mapping, - dtype=torch.long, - device=device) - - if self.attn_backend.get_name() == "flashinfer": - if len(self.paged_kv_indptr) > 0: - paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, - device="cpu", - dtype=torch.int) - paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, - device="cpu", - dtype=torch.int) - paged_kv_last_page_len_tensor = torch.tensor( - self.paged_kv_last_page_len, device="cpu", dtype=torch.int) - else: - paged_kv_indices_tensor = None - paged_kv_indptr_tensor = None - paged_kv_last_page_len_tensor = None - - kv_cache_dtype = get_kv_cache_torch_dtype(kv_cache_dtype, - model_config.dtype) - - attn_metadata = self.attn_backend.make_metadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - max_prefill_seq_len=max_prefill_seq_len, - block_tables=block_tables, - paged_kv_indptr=paged_kv_indptr_tensor, - paged_kv_indices=paged_kv_indices_tensor, - paged_kv_last_page_len=paged_kv_last_page_len_tensor, - num_qo_heads=model_config.get_num_attention_heads( - parallel_config), - num_kv_heads=model_config.get_num_kv_heads(parallel_config), - head_dim=model_config.get_head_size(), - page_size=self.block_size, - seq_start_loc=seq_start_loc, - query_start_loc=query_start_loc, - device=device, - data_type=kv_cache_dtype, - use_cuda_graph=use_captured_graph) - else: - attn_metadata = self.attn_backend.make_metadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=self.seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) + attn_metadata = self.attn_metadata_builder.build( + model_config, parallel_config, kv_cache_dtype, self.seq_lens, + self.query_lens, self.decode_seq_lens, use_captured_graph, + cuda_graph_pad_size, graph_block_tables, batch_size, device) + + # if use_captured_graph: + # self.slot_mapping.extend([_PAD_SLOT_ID] * cuda_graph_pad_size) + # self.block_tables.extend([] * cuda_graph_pad_size) + + # # The shape of graph_block_tables is + # # [max batch size, max context len // block size]. + # input_block_tables = graph_block_tables[:batch_size] + # for i, block_table in enumerate(self.block_tables): + # if block_table: + # input_block_tables[i, :len(block_table)] = block_table + # block_tables = torch.tensor(input_block_tables, device=device) + + # if self.attn_backend.get_name() == "flashinfer": + # last_paged_kv_indptr = self.paged_kv_indptr[-1] + # self.paged_kv_indptr.extend([last_paged_kv_indptr] * + # cuda_graph_pad_size) + # self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) + # else: + # max_block_table_len = max( + # len(block_table) for block_table in self.block_tables) + # block_tables = make_tensor_with_pad( + # self.block_tables, + # max_len=max_block_table_len, + # pad=0, + # dtype=torch.int, + # device=device, + # ) + # assert max_query_len > 0, ("query_lens: {}".format(self.query_lens)) + + # context_lens_tensor = torch.tensor(self.context_lens, + # dtype=torch.int, + # device=device) + # seq_lens_tensor = torch.tensor(self.seq_lens, + # dtype=torch.int, + # device=device) + # query_lens_tensor = torch.tensor(self.query_lens, + # dtype=torch.long, + # device=device) + # query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + # dtype=torch.int32, + # device=device) + # seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + # dtype=torch.int32, + # device=device) + # torch.cumsum(seq_lens_tensor, + # dim=0, + # dtype=seq_start_loc.dtype, + # out=seq_start_loc[1:]) + # torch.cumsum(query_lens_tensor, + # dim=0, + # dtype=query_start_loc.dtype, + # out=query_start_loc[1:]) + + # slot_mapping_tensor = torch.tensor(self.slot_mapping, + # dtype=torch.long, + # device=device) + + # if self.attn_backend.get_name() == "flashinfer": + # if len(self.paged_kv_indptr) > 0: + # paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + # device="cpu", + # dtype=torch.int) + # paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + # device="cpu", + # dtype=torch.int) + # paged_kv_last_page_len_tensor = torch.tensor( + # self.paged_kv_last_page_len, device="cpu", dtype=torch.int) # noqa: E501 + # else: + # paged_kv_indices_tensor = None + # paged_kv_indptr_tensor = None + # paged_kv_last_page_len_tensor = None + + # kv_cache_dtype = get_kv_cache_torch_dtype(kv_cache_dtype, + # model_config.dtype) + + # attn_metadata = self.attn_backend.make_metadata( + # num_prefills=self.num_prefills, + # slot_mapping=slot_mapping_tensor, + # num_prefill_tokens=self.num_prefill_tokens, + # num_decode_tokens=num_decode_tokens, + # max_prefill_seq_len=max_prefill_seq_len, + # block_tables=block_tables, + # paged_kv_indptr=paged_kv_indptr_tensor, + # paged_kv_indices=paged_kv_indices_tensor, + # paged_kv_last_page_len=paged_kv_last_page_len_tensor, + # num_qo_heads=model_config.get_num_attention_heads( + # parallel_config), + # num_kv_heads=model_config.get_num_kv_heads(parallel_config), + # head_dim=model_config.get_head_size(), + # page_size=self.block_size, + # seq_start_loc=seq_start_loc, + # query_start_loc=query_start_loc, + # device=device, + # data_type=kv_cache_dtype, + # use_cuda_graph=use_captured_graph) + # else: + # attn_metadata = self.attn_backend.make_metadata( + # num_prefills=self.num_prefills, + # slot_mapping=slot_mapping_tensor, + # num_prefill_tokens=self.num_prefill_tokens, + # num_decode_tokens=num_decode_tokens, + # seq_lens=self.seq_lens, + # seq_lens_tensor=seq_lens_tensor, + # max_query_len=max_query_len, + # max_prefill_seq_len=max_prefill_seq_len, + # max_decode_seq_len=max_decode_seq_len, + # query_start_loc=query_start_loc, + # seq_start_loc=seq_start_loc, + # context_lens_tensor=context_lens_tensor, + # block_tables=block_tables, + # use_cuda_graph=use_captured_graph, + # ) return self._model_input_cls( input_tokens=input_tokens_tensor, From 592b61b5155599b33b4b3c6a15598db446a23523 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 2 Jul 2024 09:22:01 -0700 Subject: [PATCH 7/7] rocm / xformers --- vllm/attention/backends/abstract.py | 3 +- vllm/attention/backends/flash_attn.py | 16 +- vllm/attention/backends/flashinfer.py | 14 +- vllm/attention/backends/rocm_flash_attn.py | 196 ++++++++++++++++++++- vllm/attention/backends/xformers.py | 195 +++++++++++++++++++- 5 files changed, 408 insertions(+), 16 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a9b211494d45a..3d5ae7ec4ee1a 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -118,8 +118,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]): """Abstract class for attention metadata builders.""" @abstractmethod - def __init__(self, block_size: int, sliding_window: int, - use_v2_block_manager: bool) -> None: + def __init__(self, input_builder) -> None: raise NotImplementedError @abstractmethod diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 522ce8101becc..52f8dddcccf3c 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,6 +1,6 @@ """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import numpy as np import torch @@ -15,6 +15,9 @@ from vllm.sequence import SequenceGroupMetadata from vllm.utils import make_tensor_with_pad +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + class FlashAttentionBackend(AttentionBackend): @@ -197,7 +200,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): - def __init__(self, block_size, sliding_window, use_v2_block_manager): + def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.slot_mapping: List[int] = [] self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] @@ -206,9 +209,10 @@ def __init__(self, block_size, sliding_window, use_v2_block_manager): self.num_prefill_tokens = 0 self.num_decode_tokens = 0 - self.sliding_window = sliding_window - self.block_size = block_size - self.use_v2_block_manager = use_v2_block_manager + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) def add_prefill_seq_group(self, seq_group_metadata: SequenceGroupMetadata, tokens: List[int], seq_id: int, seq_len: int, @@ -224,6 +228,8 @@ def add_prefill_seq_group(self, seq_group_metadata: SequenceGroupMetadata, if prefix_cache_hit: assert computed_block_nums is not None assert self.sliding_window is None + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. block_table = seq_group_metadata.block_tables[seq_id] elif (chunked_prefill_enabled and seq_group_metadata.block_tables is not None): diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 4540b3c04c1ae..79295a2046d1d 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Set, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type import numpy as np @@ -23,6 +23,9 @@ from vllm.sequence import SequenceGroupMetadata from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + class FlashInferBackend(AttentionBackend): @@ -199,7 +202,7 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]: class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - def __init__(self, block_size, sliding_window, use_v2_block_manager): + def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.slot_mapping: List[int] = [] self.prefill_seq_lens: List[int] = [] self.block_tables: List[List[int]] = [] @@ -207,9 +210,10 @@ def __init__(self, block_size, sliding_window, use_v2_block_manager): self.num_prefill_tokens = 0 self.num_decode_tokens = 0 - self.sliding_window = sliding_window - self.block_size = block_size - self.use_v2_block_manager = use_v2_block_manager + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout # for the precise definition of the following fields. diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 81fabdbdfc83c..00677fc945794 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,18 +1,27 @@ """Attention layer ROCm GPUs.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +import numpy as np import torch import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataBuilder) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + is_block_tables_empty) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger +from vllm.sequence import SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad logger = init_logger(__name__) +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + class ROCmFlashAttentionBackend(AttentionBackend): @@ -28,6 +37,10 @@ def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return ROCmFlashAttentionMetadata + @staticmethod + def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]: + return ROCmFlashAttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -166,6 +179,185 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: return self._cached_decode_metadata +class ROCmFlashAttentionMetadataBuilder( + AttentionMetadataBuilder[ROCmFlashAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + def add_prefill_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + tokens: List[int], seq_id: int, seq_len: int, + query_len: int, context_len: int, + prefix_cache_hit, chunked_prefill_enabled, + computed_block_nums, + curr_sliding_window_blocks) -> None: + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + assert self.sliding_window is None + block_table = computed_block_nums + elif (chunked_prefill_enabled + and seq_group_metadata.block_tables is not None): + block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + block_table = block_table[-curr_sliding_window_blocks:] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + + self.block_tables.append(block_table) + self.context_lens.append(context_len) + + self.num_prefills += 1 + self.num_prefill_tokens += len(tokens) + self.prefill_seq_lens.append(seq_len) + + # Compute slot mapping. + block_table = None + is_profile_run = is_block_tables_empty(seq_group_metadata.block_tables) + if not is_profile_run: + block_table = seq_group_metadata.block_tables[seq_id] + + start_idx = 0 + if self.sliding_window is not None: + assert self.use_v2_block_manager \ + or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + compute_slot_mapping(self.slot_mapping, seq_len, context_len, + start_idx, self.block_size, block_table) + + def add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + seq_id, seq_len, query_len, context_len, + curr_sliding_window_blocks, sliding_seq_len, + sliding_context_len): + + # Compute block table. + if seq_group_metadata.block_tables is not None: + block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_table[-curr_sliding_window_blocks:] + else: + # Only happens when memory profiling runs. + block_table = [] + + self.block_tables.append(block_table) + self.context_lens.append(sliding_context_len) + + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + + # Compute the slot mapping. + block_table = None + is_profile_run = is_block_tables_empty(seq_group_metadata.block_tables) + if not is_profile_run: + block_table = seq_group_metadata.block_tables[seq_id] + + compute_slot_mapping(self.slot_mapping, seq_len, context_len, 0, + self.block_size, block_table) + + def build(self, model_config, parallel_config, kv_cache_dtype, seq_lens, + query_lens, decode_seq_lens, use_captured_graph: bool, + cuda_graph_pad_size: int, graph_block_tables: np.ndarray, + batch_size: int, device): + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(decode_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + cuda_graph_pad_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + else: + max_block_table_len = max( + len(block_table) for block_table in self.block_tables) + block_tables = make_tensor_with_pad( + self.block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + slot_mapping_tensor = torch.tensor(self.slot_mapping, + dtype=torch.long, + device=device) + + return ROCmFlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + class ROCmFlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index ff449c3ff74f8..5235b5637f769 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -1,7 +1,8 @@ """Attention layer with xFormers and PagedAttention.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +import numpy as np import torch from xformers import ops as xops from xformers.ops.fmha.attn_bias import (AttentionBias, @@ -9,13 +10,21 @@ LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataBuilder) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + is_block_tables_empty) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger +from vllm.sequence import SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad logger = init_logger(__name__) +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + class XFormersBackend(AttentionBackend): @@ -31,6 +40,10 @@ def get_impl_cls() -> Type["XFormersImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return XFormersMetadata + @staticmethod + def get_builder_cls() -> Type["XFormersMetadataBuilder"]: + return XFormersMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -177,6 +190,184 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: return self._cached_decode_metadata +class XFormersMetadataBuilder(AttentionMetadataBuilder[XFormersMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + def add_prefill_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + tokens: List[int], seq_id: int, seq_len: int, + query_len: int, context_len: int, + prefix_cache_hit, chunked_prefill_enabled, + computed_block_nums, + curr_sliding_window_blocks) -> None: + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + assert self.sliding_window is None + block_table = computed_block_nums + elif (chunked_prefill_enabled + and seq_group_metadata.block_tables is not None): + block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + block_table = block_table[-curr_sliding_window_blocks:] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + + self.block_tables.append(block_table) + self.context_lens.append(context_len) + + self.num_prefills += 1 + self.num_prefill_tokens += len(tokens) + self.prefill_seq_lens.append(seq_len) + + # Compute slot mapping. + block_table = None + is_profile_run = is_block_tables_empty(seq_group_metadata.block_tables) + if not is_profile_run: + block_table = seq_group_metadata.block_tables[seq_id] + + start_idx = 0 + if self.sliding_window is not None: + assert self.use_v2_block_manager \ + or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + compute_slot_mapping(self.slot_mapping, seq_len, context_len, + start_idx, self.block_size, block_table) + + def add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + seq_id, seq_len, query_len, context_len, + curr_sliding_window_blocks, sliding_seq_len, + sliding_context_len): + + # Compute block table. + if seq_group_metadata.block_tables is not None: + block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_table[-curr_sliding_window_blocks:] + else: + # Only happens when memory profiling runs. + block_table = [] + + self.block_tables.append(block_table) + self.context_lens.append(sliding_context_len) + + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + + # Compute the slot mapping. + block_table = None + is_profile_run = is_block_tables_empty(seq_group_metadata.block_tables) + if not is_profile_run: + block_table = seq_group_metadata.block_tables[seq_id] + + compute_slot_mapping(self.slot_mapping, seq_len, context_len, 0, + self.block_size, block_table) + + def build(self, model_config, parallel_config, kv_cache_dtype, seq_lens, + query_lens, decode_seq_lens, use_captured_graph: bool, + cuda_graph_pad_size: int, graph_block_tables: np.ndarray, + batch_size: int, device): + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(decode_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + cuda_graph_pad_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + else: + max_block_table_len = max( + len(block_table) for block_table in self.block_tables) + block_tables = make_tensor_with_pad( + self.block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + slot_mapping_tensor = torch.tensor(self.slot_mapping, + dtype=torch.long, + device=device) + + return XFormersMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + class XFormersImpl(AttentionImpl[XFormersMetadata]): """ If the input tensors contain prompt tokens, the layout is as follows: