feat(llm): respect warnings environment for dtype warning (#664)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
aarnphm authored Nov 16, 2023
1 parent 4a6f13d commit 86d23fd
Showing 1 changed file with 6 additions and 10 deletions.
openllm-python/src/openllm/_llm.py: 6 additions, 10 deletions
@@ -76,6 +76,10 @@ def normalise_model_name(name: str) -> str:
 
 
 def _resolve_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
+  if not is_peft_available():
+    raise RuntimeError(
+      "LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'"
+    )
   try:
     from huggingface_hub import hf_hub_download
   except ImportError:
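
For reference, availability guards like is_peft_available are typically import probes; a minimal sketch, assuming the real helper in openllm_core.utils behaves along these lines:

import importlib.util

def is_peft_available() -> bool:
  # Probe for the 'peft' distribution without actually importing it.
  return importlib.util.find_spec('peft') is not None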
@@ -174,19 +178,11 @@ def __init__(
     backend = first_not_none(backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if is_vllm_available() else 'pt')
     torch_dtype = first_not_none(os.getenv('TORCH_DTYPE'), torch_dtype, default='auto')
     quantize = first_not_none(quantize, os.getenv('OPENLLM_QUANTIZE'), default=None)
-    # elif quantization_config is None and quantize is not None:
-    #   quantization_config, attrs = infer_quantisation_config(self, quantize, **attrs)
     attrs.update({'low_cpu_mem_usage': low_cpu_mem_usage})
 
     # parsing tokenizer and model kwargs, as the hierarchy is param pass > default
     model_attrs, tokenizer_attrs = flatten_attrs(**attrs)
 
-    if adapter_map is not None and not is_peft_available():
-      raise RuntimeError(
-        "LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'"
-      )
-    if isinstance(prompt_template, str):
-      prompt_template = PromptTemplate(prompt_template)
     if model_tag is None:
       model_tag, model_version = self._make_tag_components(model_id, model_version, backend=backend)
       if model_version:
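
The param-then-env-then-default precedence above leans on first_not_none; a minimal sketch, assuming the real helper in openllm_core.utils is equivalent:

import typing as t

T = t.TypeVar('T')

def first_not_none(*args: t.Optional[T], default: t.Optional[T] = None) -> t.Optional[T]:
  # Return the first argument that is not None, otherwise the default.
  return next((arg for arg in args if arg is not None), default)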
@@ -202,7 +198,7 @@ def __init__(
       adapter_map=_resolve_peft_config_type(adapter_map) if adapter_map is not None else None,
       serialisation=serialisation,
       local=_local,
-      prompt_template=prompt_template,
+      prompt_template=PromptTemplate(prompt_template) if isinstance(prompt_template, str) else prompt_template,
       system_message=system_message,
       LLM__model_attrs=model_attrs,
       LLM__tokenizer_attrs=tokenizer_attrs,
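
With the coercion moved inline, both call styles keep working; a usage sketch (the model id, template string, and PromptTemplate import path are illustrative, not taken from this commit):

import openllm
from openllm_core.prompts import PromptTemplate  # assumed import path

# A plain string is wrapped into a PromptTemplate internally.
llm = openllm.LLM('facebook/opt-125m', prompt_template='{system_message}\n{instruction}')

# An already-constructed PromptTemplate is passed through unchanged.
llm = openllm.LLM('facebook/opt-125m', prompt_template=PromptTemplate('{system_message}\n{instruction}'))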
@@ -243,7 +239,7 @@ def _torch_dtype(self):
     if config_dtype is None:
       config_dtype = torch.float32
     if not torch.cuda.is_available():
-      if self.__llm_torch_dtype__ in {'auto', 'half'}:
+      if self.__llm_torch_dtype__ in {'auto', 'half'} and not get_disable_warnings() and not get_quiet_mode():
         logger.warning('"auto" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
       torch_dtype = torch.float32  # we need to cast back to full precision if cuda is not available
     elif self.__llm_torch_dtype__ == 'auto':
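
This hunk is the change the commit title refers to: the CPU dtype fallback warning is now skipped when warnings are disabled or quiet mode is on. A minimal sketch of what environment-backed toggles like get_disable_warnings and get_quiet_mode could look like (the variable names OPENLLM_DISABLE_WARNING and OPENLLM_QUIET are assumptions, not confirmed by this diff):

import os

def check_bool_env(name: str, default: bool = False) -> bool:
  # Interpret common truthy spellings; fall back to the default otherwise.
  value = os.getenv(name)
  if value is None:
    return default
  return value.strip().upper() in {'1', 'TRUE', 'T', 'YES', 'ON'}

def get_disable_warnings() -> bool:
  return check_bool_env('OPENLLM_DISABLE_WARNING')  # hypothetical variable name

def get_quiet_mode() -> bool:
  return check_bool_env('OPENLLM_QUIET')  # hypothetical variable name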
