From 609e9fbcd59c4262acbe9392e273eb8d096396fd Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 11 Oct 2024 09:28:42 -0400 Subject: [PATCH] fix for #9233 --- vllm/model_executor/models/registry.py | 36 +++++++------------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 6063146988560..fbf75c2a15412 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -159,6 +159,8 @@ class _ModelInfo: is_embedding_model: bool supports_multimodal: bool supports_pp: bool + has_inner_state: bool + is_attention_free: bool @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": @@ -167,6 +169,8 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": is_embedding_model=is_embedding_model(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), + has_inner_state=has_inner_state(model), + is_attention_free=is_attention_free(model), ) @@ -382,6 +386,12 @@ def is_pp_supported_model( ) -> bool: return self.inspect_model_cls(architectures).supports_pp + def model_has_inner_state(self, architectures: Union[str, List[str]]) -> bool: + return self.inspect_model_cls(architectures).has_inner_state + + def is_attention_free_model(self, architectures: Union[str, List[str]]) -> bool: + return self.inspect_model_cls(architectures).is_attention_free + ModelRegistry = _ModelRegistry({ model_arch: _LazyRegisteredModel( @@ -430,32 +440,6 @@ def _run() -> None: with open(output_file, "wb") as f: f.write(pickle.dumps(result)) - @staticmethod - def model_has_inner_state(architectures: Union[str, List[str]]) -> bool: - if isinstance(architectures, str): - architectures = [architectures] - if not architectures: - logger.warning("No model architectures are specified") - - has_instate = partial(ModelRegistry._check_stateless, - has_inner_state, - default=False) - - return any(has_instate(arch) for arch in architectures) - - @staticmethod - def is_attention_free_model(architectures: Union[str, List[str]]) -> bool: - if isinstance(architectures, str): - architectures = [architectures] - if not architectures: - logger.warning("No model architectures are specified") - - is_attn_free = partial(ModelRegistry._check_stateless, - is_attention_free, - default=False) - - return any(is_attn_free(arch) for arch in architectures) - if __name__ == "__main__": _run()