fix for itemsize => element_size() for torch backwards compat #30133
Conversation
src/transformers/modeling_utils.py (Outdated)
```diff
@@ -1160,7 +1160,7 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
             # used for the 4bit quantization (uint8 tensors are stored)
             if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
                 total_numel.append(
-                    param.numel() * 2 * self.hf_quantizer.quantization_config.bnb_4bit_quant_storage.itemsize
+                    param.numel() * 2 * self.hf_quantizer.quantization_config.bnb_4bit_quant_storage.element_size()
```
I meant to respond to the other one earlier, but it may be good to have an if/else to know which version to call, e.g. `param.numel() * 2 * elem_size`, where `elem_size` is defined earlier based on the torch version.
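A minimal sketch of that suggestion, assuming an illustrative helper name `storage_elem_size` and a 2.1 cutoff (neither is from the merged code):

```python
from packaging import version

import torch


def storage_elem_size(dtype: torch.dtype) -> int:
    """Return the byte size of one element of `dtype`, gated on the torch version."""
    if version.parse(torch.__version__) >= version.parse("2.1"):
        return dtype.itemsize  # torch.dtype.itemsize only exists from 2.1 onward
    # Older torch: derive the size from the finfo/iinfo bit width instead
    # (covers the float/int dtypes relevant here, not bool or complex).
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return info.bits // 8
```

The call site would then read `param.numel() * 2 * storage_elem_size(quant_storage)`.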
Correction: as @BenjaminBossan pointed out, this is a must for us to have.
Besides `itemsize` not being available in torch 2.0.1, I couldn't get `element_size()` to work either (it failed with an AttributeError, since `element_size` is not a `torch.dtype` attribute).
I put this 'helper' function at the top of the file to do it manually:
```python
def get_dtype_size(dtype):
    if dtype == torch.float32:
        return 4
    elif dtype == torch.float64:
        return 8
    elif dtype == torch.float16:
        return 2
    elif dtype == torch.uint8:
        return 1
    elif dtype == torch.int8:
        return 1
    elif dtype == torch.int16:
        return 2
    elif dtype == torch.int32:
        return 4
    elif dtype == torch.int64:
        return 8
    else:
        raise ValueError("Unsupported dtype")
```
and then used this in place of the `itemsize` code:

```python
if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
    quant_storage = self.hf_quantizer.quantization_config.bnb_4bit_quant_storage
    nb_params = get_dtype_size(quant_storage)
    total_numel.append(param.numel() * 2 * nb_params)
else:
    total_numel.append(param.numel())
```
I've only tested it with the qlora config on tiny-llama, and at least for that case it works:

```bash
accelerate launch -m axolotl.cli.train examples/tiny-llama/qlora.yml
```

I am using the image `winglian/axolotl-base:main-base-py3.11-cu121-2.2.1` (which actually uses torch 2.0.1, not 2.2.1 as the name implies), on Ubuntu 22.04 with a laptop RTX 4090 (16 GB).
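As an aside, the hand-written table could be avoided on any torch version discussed here, since `Tensor.element_size()` long predates `dtype.itemsize`; a sketch of that alternative:

```python
import torch


def get_dtype_size(dtype: torch.dtype) -> int:
    # A zero-element tensor of the given dtype reports its per-element
    # byte size, so no dtype table has to be maintained by hand.
    return torch.empty(0, dtype=dtype).element_size()


assert get_dtype_size(torch.float16) == 2
assert get_dtype_size(torch.uint8) == 1
```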
Thank you soooo much!!!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Note that this PR fails on PEFT CI.
My last change isn't correct either. I'll take a look at this tomorrow.
```python
# For compatibility with older PT version - see: https://github.com/huggingface/peft/pull/1635
nb_params = (
    quant_storage.itemsize if hasattr(quant_storage, "itemsize") else quant_storage.element_size()
)
```

and, in a later revision of the check:

```python
if hasattr(param, "element_size"):
```
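One way past the dtype-vs-tensor mismatch discussed below is to take the byte size from the parameter tensor itself, since `Tensor.element_size()` is available on all torch versions in question; an illustrative sketch, not the merged code:

```python
import torch

# Stand-in for a bnb.nn.Params4bit storage tensor (4-bit values packed into uint8).
param = torch.zeros(10, dtype=torch.uint8)

num_bytes = param.element_size()      # bytes per stored element, here 1
print(param.numel() * 2 * num_bytes)  # 20: each stored byte packs two 4-bit params
```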
Why not simply use:

```python
quant_storage = self.hf_quantizer.quantization_config.bnb_4bit_quant_storage
num_bytes = quant_storage.element_size()
```

`element_size` seems present from 1.9.1 (https://pytorch.org/docs/1.9.1/search.html?q=element_size&check_keywords=yes&area=default) through the latest docs (https://pytorch.org/docs/2.2/search.html?q=element_size&check_keywords=yes&area=default).
`self.hf_quantizer.quantization_config.bnb_4bit_quant_storage` is a `torch.dtype` instance, while only `torch.Tensor` has `element_size()`.
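A quick illustration of that distinction:

```python
import torch

t = torch.empty(0, dtype=torch.float16)
print(t.element_size())  # 2 -- element_size() is a Tensor method

# torch.float16.element_size()  # AttributeError: 'torch.dtype' has no attribute 'element_size'
```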
Thanks for taking care of the backward compatibility for previous torch versions!
Thanks for handling this fix!
Just a small nit
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This is a similar fix to huggingface/peft#1630.
`.itemsize` on a tensor is only supported on torch >= 2.1.

Fixes: #30304
Edit by @younesbelkada:
This PR makes sure the checks are compatible with earlier versions of pytorch, which I overlooked in #30162.