diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index de3d6c256..9493cc13e 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -222,7 +222,7 @@ def _torch_dtype(self): if config_dtype is None: config_dtype = torch.float32 if self.__llm_torch_dtype__ == 'auto': - if config_dtype == torch.float32 and torch.cuda.is_available(): + if config_dtype == torch.float32: torch_dtype = torch.float16 # following common practice else: torch_dtype = config_dtype