diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
index eba5c5882..4c1e798b2 100644
--- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -236,6 +236,10 @@ def patch_tied_tensors_bug(model: torch.nn.Module):
     input_embed = model.get_input_embeddings()
     output_embed = model.get_output_embeddings()
 
+    if input_embed is None or output_embed is None:
+        # some models fail to properly override the abstract methods
+        return
+
     if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight):
         for module in (input_embed, output_embed):
             if not is_module_offloaded(module):