
fix for itemsize => element_size() for torch backwards compat #30133

Merged: 6 commits into huggingface:main on Apr 23, 2024

Conversation

@winglian (Contributor) commented Apr 9, 2024

This is a similar fix to huggingface/peft#1630.

.itemsize on a tensor is only supported on torch>=2.1

Fixes: #30304

Edit by @younesbelkada:

This PR makes sure the checks are compatible with earlier versions of PyTorch, which I overlooked in #30162.
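
For reference, a minimal standalone sketch (not code from this PR) of the compatibility gap described above: Tensor.element_size() is the long-standing API, while Tensor.itemsize only exists as an alias on torch >= 2.1.

import torch

# Standalone illustration (the float16 tensor is just an example)
t = torch.zeros(4, dtype=torch.float16)

print(t.element_size())  # 2 bytes, available on older torch releases as well
if hasattr(t, "itemsize"):
    # Tensor.itemsize is an alias for element_size() on torch >= 2.1 only
    print(t.itemsize)    # 2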

@winglian (Contributor, Author) commented Apr 9, 2024

@pacman100 @BenjaminBossan

@@ -1160,7 +1160,7 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
                 # used for the 4bit quantization (uint8 tensors are stored)
                 if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
                     total_numel.append(
-                        param.numel() * 2 * self.hf_quantizer.quantization_config.bnb_4bit_quant_storage.itemsize
+                        param.numel() * 2 * self.hf_quantizer.quantization_config.bnb_4bit_quant_storage.element_size()
Contributor:

I meant to respond to the other one earlier, but it may be good to have an if/else to know which version to call? E.g. param.numel() * 2 * elem_size, where elem_size is defined earlier based on the torch version (see the sketch below).
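
A minimal sketch of that suggestion (the quant_storage and elem_size names follow the thread; the empty-tensor fallback is an assumption, not the PR's final code):

import torch

quant_storage = torch.uint8  # stands in for bnb_4bit_quant_storage (a torch.dtype)

if hasattr(quant_storage, "itemsize"):
    # torch.dtype.itemsize exists on torch >= 2.1
    elem_size = quant_storage.itemsize
else:
    # older torch: measure an empty tensor of that dtype instead
    elem_size = torch.tensor([], dtype=quant_storage).element_size()

# the real code would then compute param.numel() * 2 * elem_size
print(elem_size)  # 1 for uint8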


Correction: as @BenjaminBossan pointed out, this is a must for us to have.


Besides itemsize not being available in torch 2.0.1, I couldn't get element_size() to work either (it fails with an error that element_size is not an attribute of the dtype).

I put this 'helper' function at the top of the file to do it manually:

def get_dtype_size(dtype):
    if dtype == torch.float32:
        return 4
    elif dtype == torch.float64:
        return 8
    elif dtype == torch.float16:
        return 2
    elif dtype == torch.uint8:
        return 1
    elif dtype == torch.int8:
        return 1
    elif dtype == torch.int16:
        return 2
    elif dtype == torch.int32:
        return 4
    elif dtype == torch.int64:
        return 8
    else:
        raise ValueError("Unsupported dtype")

and then used this in place of the itemsize code:

if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
    quant_storage = self.hf_quantizer.quantization_config.bnb_4bit_quant_storage
    nb_params = get_dtype_size(quant_storage)
    total_numel.append(param.numel() * 2 * nb_params)
else:
    total_numel.append(param.numel())

I've only tested it by using the qlora config on tiny-llama, and at least for this case it works:

accelerate launch -m axolotl.cli.train examples/tiny-llama/qlora.yml

I am using the image 'winglian/axolotl-base:main-base-py3.11-cu121-2.2.1' (which actually uses torch 2.0.1, not 2.2.1 as the name implies). Ubuntu 22.04 / laptop RTX 4090 (16 GB).


Thank you soooo much!!!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member)

Note that this PR fails on PEFT CI.

@winglian (Contributor, Author) commented Apr 9, 2024

my last change isn't correct either. I'll take a look at this tomorrow.

@younesbelkada (Contributor)

Hi @winglian, as pointed out by @hiyouga my fix did not cover all the cases. Would you be happy to rebase your PR on main so we can merge it?

# For compatibility with older PT version - see: https://github.com/huggingface/peft/pull/1635
nb_params = (
    quant_storage.itemsize if hasattr(quant_storage, "itemsize") else quant_storage.element_size()
)
if hasattr(param, "element_size"):
Contributor:

Why not simply use:

quant_storage = self.hf_quantizer.quantization_config.bnb_4bit_quant_storage
num_bytes = quant_storage.element_size()

element_size seems present from 1.9.1: https://pytorch.org/docs/1.9.1/search.html?q=element_size&check_keywords=yes&area=default to latest: https://pytorch.org/docs/2.2/search.html?q=element_size&check_keywords=yes&area=default


@hiyouga (Contributor) commented Apr 22, 2024

self.hf_quantizer.quantization_config.bnb_4bit_quant_storage is a torch.dtype instance, while only torch.Tensor has element_size().

huggingface/peft#1635
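
A small repro of that distinction (a sketch assuming the torch 2.0 behaviour reported in the linked issue; the uint8 storage dtype is illustrative):

import torch

quant_storage = torch.uint8                # a torch.dtype, like bnb_4bit_quant_storage
t = torch.tensor([], dtype=quant_storage)  # a torch.Tensor with that dtype

print(t.element_size())                    # 1 -- Tensor.element_size() works on old torch too
# quant_storage.element_size()             # AttributeError on torch 2.0: the dtype has no element_size
# quant_storage.itemsize                   # only available on torch >= 2.1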

@younesbelkada (Contributor) left a comment

Thanks for taking care of the backward compatibility for previous torch versions!

@amyeroberts (Collaborator) left a comment

Thanks for handling this fix!

Just a small nit

src/transformers/modeling_utils.py (review comment outdated, resolved)
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@younesbelkada younesbelkada merged commit 57fc00f into huggingface:main Apr 23, 2024
21 checks passed

Successfully merging this pull request may close these issues.

AttributeError: 'torch.dtype' object has no attribute 'element_size'
9 participants