Run yapf and ruff
Swapnil Parekh committed Jun 3, 2024
1 parent a368116 commit 23f741b
Showing 32 changed files with 306 additions and 272 deletions.
6 changes: 4 additions & 2 deletions format.sh
@@ -26,7 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')
CODESPELL_VERSION=$(codespell --version)
ISORT_VERSION=$(isort --vn)
CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')
# CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}')

# # params: tool name, tool version, required version
tool_version_check() {
@@ -41,7 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt |
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)"
# tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)"

YAPF_FLAGS=(
'--recursive'
@@ -112,6 +112,8 @@ mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/prompt_adapter --config-file pyproject.toml



# If git diff returns a file that is in the skip list, the file may be checked anyway:
4 changes: 2 additions & 2 deletions tests/worker/test_model_runner.py
@@ -351,8 +351,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
seq_group_metadata_list.append(seq_group_metadata)
decode_metadata_list.append(seq_group_metadata)

(input_tokens, input_positions, attn_metadata, _, _, _,
_, _, _) = model_runner.prepare_input_tensors(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, _, _, _, _, _,
_) = model_runner.prepare_input_tensors(seq_group_metadata_list)

prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata
40 changes: 31 additions & 9 deletions vllm/adapter_commons/models.py
@@ -1,13 +1,16 @@
import os
import json
from typing import Any, Callable, Dict, Hashable, List, Optional, Set, Type, TypeVar
import torch
from abc import abstractproperty
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar

from torch import nn
from vllm.utils import LRUCache

from vllm.logger import init_logger
from vllm.utils import LRUCache

logger = init_logger(__name__)


class AdapterModel:

def __init__(self, model_id=None):
self.id = model_id

@@ -17,9 +20,14 @@ def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
# Load weights or embeddings from local checkpoint
raise NotImplementedError("Subclasses must implement this method.")


T = TypeVar('T')


class AdapterLRUCache(LRUCache[T]):
def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable], None]):

def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
None]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn

@@ -28,7 +36,9 @@ def _on_remove(self, key: Hashable, value: T):
self.deactivate_fn(key)
return super()._on_remove(key, value)


class AdapterModelManager:

def __init__(
self,
model: nn.Module,
@@ -42,10 +52,19 @@ def __init__(
# Dict instead of a Set for compatibility with LRUCache.
self._active_adapters: Dict[int, None] = {}
self.adapter_type = 'Adapter'
self._last_mapping = None

def __len__(self) -> int:
return len(self._registered_adapters)

@abstractproperty
def adapter_slots(self):
...

@abstractproperty
def capacity(self):
...

def _deactivate_adapter(self, adapter_id: int):
raise NotImplementedError("Subclasses must implement this method.")

@@ -65,7 +84,7 @@ def _add_adapter(self, adapter: Any):
def add_adapter(self, adapter: Any) -> bool:
if adapter.id not in self._registered_adapters:
if len(self._registered_adapters) >= self.capacity:
raise RuntimeError("No free "+self.adapter_type+" slots.")
raise RuntimeError("No free " + self.adapter_type + " slots.")
self._add_adapter(adapter)
return True
return False
@@ -74,7 +93,10 @@ def set_adapter_mapping(self, mapping: Any) -> None:
if self._last_mapping != mapping:
self._set_adapter_mapping(mapping)
self._last_mapping = mapping


def _set_adapter_mapping(self, mapping: Any) -> None:
raise NotImplementedError("Subclasses must implement this method.")

def remove_adapter(self, adapter_id: int) -> bool:
self.deactivate_adapter(adapter_id)
return bool(self._registered_adapters.pop(adapter_id, None))
@@ -83,4 +105,4 @@ def list_adapters(self) -> Dict[int, Any]:
return dict(self._registered_adapters)

def get_adapter(self, adapter_id: int) -> Optional[Any]:
return self._registered_adapters.get(adapter_id, None)
return self._registered_adapters.get(adapter_id, None)
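
Aside (not part of this commit): the AdapterLRUCache introduced above ties adapter deactivation to LRU eviction through deactivate_fn. A minimal, self-contained sketch of that pattern, written against collections.OrderedDict rather than vLLM's own LRUCache, might look like this:

from collections import OrderedDict
from typing import Callable, Hashable

class EvictingCache:
    """Toy LRU cache that invokes a callback on eviction (illustration only)."""

    def __init__(self, capacity: int,
                 deactivate_fn: Callable[[Hashable], None]) -> None:
        self.capacity = capacity
        self.deactivate_fn = deactivate_fn
        self._entries: "OrderedDict[Hashable, object]" = OrderedDict()

    def put(self, key: Hashable, value: object) -> None:
        if key in self._entries:
            self._entries.move_to_end(key)   # refresh recency
        self._entries[key] = value
        if len(self._entries) > self.capacity:
            old_key, _ = self._entries.popitem(last=False)  # least recently used
            self.deactivate_fn(old_key)      # analogous to _on_remove -> deactivate_fn

cache = EvictingCache(2, lambda k: print(f"deactivating adapter {k}"))
cache.put(1, "adapter-1")
cache.put(2, "adapter-2")
cache.put(3, "adapter-3")  # evicts key 1 and prints the deactivation message
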
7 changes: 4 additions & 3 deletions vllm/adapter_commons/request.py
@@ -1,19 +1,20 @@
from dataclasses import dataclass


@dataclass
class AdapterRequest:
"""
Base class for adapter requests.
"""
adapter_id: int

def __post_init__(self):
if self.adapter_id < 1:
raise ValueError(
f"id must be > 0, got {self.adapter_id}")
raise ValueError(f"id must be > 0, got {self.adapter_id}")

def __eq__(self, value: object) -> bool:
return isinstance(
value, self.__class__) and self.adapter_id == value.adapter_id

def __hash__(self) -> int:
return hash(self.adapter_id)
return hash(self.adapter_id)
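
For context only (not part of the diff): the AdapterRequest base class above keys equality and hashing on adapter_id and rejects non-positive ids. The standalone restatement below, with a hypothetical DemoAdapterRequest class, shows one consequence of that choice: duplicate requests collapse in a set.

from dataclasses import dataclass

@dataclass
class DemoAdapterRequest:
    """Standalone copy of the AdapterRequest pattern shown above (illustration only)."""
    adapter_id: int

    def __post_init__(self):
        if self.adapter_id < 1:
            raise ValueError(f"id must be > 0, got {self.adapter_id}")

    def __eq__(self, value: object) -> bool:
        return (isinstance(value, self.__class__)
                and self.adapter_id == value.adapter_id)

    def __hash__(self) -> int:
        return hash(self.adapter_id)

assert DemoAdapterRequest(7) == DemoAdapterRequest(7)            # identity is the id
assert len({DemoAdapterRequest(7), DemoAdapterRequest(7)}) == 1  # duplicates collapse
try:
    DemoAdapterRequest(0)
except ValueError as exc:
    print(exc)  # id must be > 0, got 0
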
17 changes: 14 additions & 3 deletions vllm/adapter_commons/worker_manager.py
@@ -1,10 +1,16 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, List, Optional, Set, Type, Dict
from typing import Any, Optional, Set

import torch

from vllm.adapter_commons.models import AdapterModelManager


class AbstractWorkerManager(ABC):

def __init__(self, device: torch.device):
self.device = device
self._model_manager: AdapterModelManager = None

@abstractproperty
def is_enabled(self) -> bool:
@@ -14,14 +20,19 @@ def is_enabled(self) -> bool:
def create_manager(self, model: torch.nn.Module) -> Any:
...

def set_active_adapters(self, requests: Set[Any], mapping: Optional[Any]) -> None:
def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None:
self._apply_adapters(requests)
self._model_manager.set_adapter_mapping(mapping)

@abstractmethod
def add_dummy_adapter(self, request: Any) -> bool:
...

@abstractmethod
def _load_adapter(self, request: Any) -> Any:
...

def add_adapter(self, adapter_request: Any) -> bool:
if adapter_request.adapter_id in self.list_adapters():
return False
@@ -59,4 +70,4 @@ def remove_all_adapters(self):
self._model_manager.remove_all_adapters()

def list_adapters(self) -> Set[int]:
return set(self._model_manager.list_adapters())
return set(self._model_manager.list_adapters())
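
A brief aside, not part of this commit: the worker manager's add_adapter path above reduces to "skip if already registered, otherwise load and register". A self-contained toy version of that control flow, with hypothetical names and no vLLM imports:

from typing import Any, Callable, Dict

def add_adapter_once(adapter_id: int,
                     registered: Dict[int, Any],
                     load_fn: Callable[[int], Any]) -> bool:
    """Register an adapter at most once; return False if it was already present."""
    if adapter_id in registered:
        return False
    registered[adapter_id] = load_fn(adapter_id)  # e.g. read weights from disk
    return True

registry: Dict[int, Any] = {}
print(add_adapter_once(1, registry, lambda i: f"adapter-{i}"))  # True  (loaded)
print(add_adapter_once(1, registry, lambda i: f"adapter-{i}"))  # False (duplicate)
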
1 change: 1 addition & 0 deletions vllm/config.py
@@ -1077,6 +1077,7 @@ def __post_init__(self):
if self.max_cpu_prompt_adapters is None:
self.max_cpu_prompt_adapters = self.max_prompt_adapters


@dataclass
class VisionLanguageConfig:
"""Configs the input data format and how models should run for
3 changes: 2 additions & 1 deletion vllm/core/scheduler.py
@@ -157,7 +157,8 @@ def _sort_by_lora_ids(self):
def _sort_by_prompt_adapter_ids(self):
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.seq_group.prompt_adapter_id, g.seq_group.request_id))
key=lambda g:
(g.seq_group.prompt_adapter_id, g.seq_group.request_id))

@property
def lora_requests(self) -> Set[LoRARequest]:
14 changes: 6 additions & 8 deletions vllm/engine/arg_utils.py
@@ -7,8 +7,9 @@

from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
TokenizerPoolConfig, VisionLanguageConfig, PromptAdapterConfig)
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig,
VisionLanguageConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple

@@ -495,9 +496,6 @@ def add_cli_args(
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
# parser.add_argument('--enable-prompt-adapter',
# action='store_true',
# help='If True, enable handling of PromptAdapters.')
parser.add_argument('--max-prompt-adapters',
type=int,
default=EngineArgs.max_prompt_adapters,
@@ -699,10 +697,10 @@ def create_engine_config(self, ) -> EngineConfig:
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
)

prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters) #if self.enable_prompt_adapter else None
max_prompt_adapters=self.max_prompt_adapters)

if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size):
19 changes: 11 additions & 8 deletions vllm/engine/async_llm_engine.py
@@ -17,6 +17,7 @@
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.usage.usage_lib import UsageContext
@@ -270,12 +271,13 @@ async def process_model_inputs_async(
multi_modal_data=inputs.get("multi_modal_data"))

async def add_request_async(
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
Expand All @@ -292,7 +294,7 @@ async def add_request_async(
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)
prompt_adapter_request=prompt_adapter_request)

async def check_health_async(self) -> None:
self.model_executor.check_health()
@@ -534,6 +536,7 @@ async def add_request(
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream:
if self.log_requests:
if isinstance(inputs, str):
@@ -587,7 +590,7 @@ async def add_request(
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)
prompt_adapter_request=prompt_adapter_request)

return stream
