Adding logger.info about update_torch_dtype in some quantizers (#35046)
adding logger.info
MekkCyber authored Dec 23, 2024
1 parent a1780b7 commit 82fcac0
Showing 3 changed files with 5 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/transformers/quantizers/quantizer_awq.py
@@ -87,6 +87,7 @@ def validate_environment(self, device_map, **kwargs):
     def update_torch_dtype(self, torch_dtype):
         if torch_dtype is None:
             torch_dtype = torch.float16
+            logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
         elif torch_dtype != torch.float16:
             logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
         return torch_dtype
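
For context, a minimal sketch of how the new AWQ log surfaces (the GPTQ change below behaves identically). The checkpoint name is illustrative and not part of this commit:

# Minimal sketch, assuming an AWQ-quantized checkpoint is available;
# the model ID below is illustrative, not taken from this commit.
from transformers import AutoModelForCausalLM
from transformers.utils import logging

logging.set_verbosity_info()  # make logger.info messages visible

# torch_dtype is omitted, so update_torch_dtype falls back to torch.float16
# and now logs: "Loading the model in `torch.float16`. To overwrite it,
# set `torch_dtype` manually."
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ")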
1 change: 1 addition & 0 deletions src/transformers/quantizers/quantizer_gptq.py
@@ -64,6 +64,7 @@ def validate_environment(self, *args, **kwargs):
     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:
             torch_dtype = torch.float16
+            logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
         elif torch_dtype != torch.float16:
             logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.")
         return torch_dtype
3 changes: 3 additions & 0 deletions src/transformers/quantizers/quantizer_torchao.py
@@ -114,6 +114,9 @@ def update_torch_dtype(self, torch_dtype):
                 torch_dtype = torch.bfloat16
         if self.quantization_config.quant_type == "int8_dynamic_activation_int8_weight":
             if torch_dtype is None:
+                logger.info(
+                    "Setting torch_dtype to torch.float32 for int8_dynamic_activation_int8_weight quantization as no torch_dtype was specified in from_pretrained"
+                )
                 # we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
                 torch_dtype = torch.float32
         return torch_dtype
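
A similar sketch for the torchao branch; it assumes torchao is installed, and the base model is illustrative:

# Minimal sketch, assuming torchao is installed; the base model is illustrative.
from transformers import AutoModelForCausalLM, TorchAoConfig
from transformers.utils import logging

logging.set_verbosity_info()

quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
# torch_dtype is omitted in from_pretrained, so the quantizer falls back to
# torch.float32 and emits the new info log before quantizing.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=quant_config,
)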
