
Commit 65f6acf: "comment"

horheynm committed Nov 22, 2024
1 parent 81fa1bb commit 65f6acf
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 0 additions & 2 deletions src/compressed_tensors/compressors/base.py
@@ -125,8 +125,6 @@ def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
             return None  # module is not quantized
         quantization_scheme = module.quantization_scheme
         if not hasattr(quantization_scheme, "weights"):
-            # models that ran CompressedLinear.from_linear will
-            # run delattr(module, "weight")
             return None  # weights are not quantized

         quantization_args = quantization_scheme.weights
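For context on this first change: compress_module skips any module whose quantization scheme does not quantize weights, which is the guard the deleted comment sat next to. Below is a minimal sketch of that guard with stand-in objects; compress_module_sketch, the SimpleNamespace stand-ins, and the example schemes are illustrative assumptions, not the library's types.

from types import SimpleNamespace
from typing import Dict, Optional

def compress_module_sketch(module) -> Optional[Dict[str, object]]:
    # Mirror of the guard logic: bail out unless the module carries a scheme
    # that actually quantizes weights.
    if not hasattr(module, "quantization_scheme"):
        return None  # module is not quantized
    quantization_scheme = module.quantization_scheme
    if not hasattr(quantization_scheme, "weights"):
        return None  # weights are not quantized
    return {"quantization_args": quantization_scheme.weights}

# Toy usage: an activation-only scheme is skipped, a weight scheme is picked up.
act_only = SimpleNamespace(quantization_scheme=SimpleNamespace(input_activations={"num_bits": 8}))
weight_quant = SimpleNamespace(quantization_scheme=SimpleNamespace(weights={"num_bits": 4}))
print(compress_module_sketch(act_only))      # None
print(compress_module_sketch(weight_quant))  # {'quantization_args': {'num_bits': 4}}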
12 changes: 7 additions & 5 deletions (second changed file)
@@ -24,7 +24,6 @@
 import torch
 import transformers
 from compressed_tensors.base import (
-    COMPRESSION_CONFIG_NAME,
     COMPRESSION_VERSION_NAME,
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
@@ -103,14 +102,14 @@ def from_pretrained(
         :return: compressor for the configs, or None if model is not compressed
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None) or getattr(
-            config, QUANTIZATION_CONFIG_NAME, None
-        )
+        compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
+
         return cls.from_compression_config(compression_config)

     @classmethod
     def from_compression_config(
-        cls, compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+        cls,
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
     ):
         """
         :param compression_config:
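This hunk drops the COMPRESSION_CONFIG_NAME fallback, so from_pretrained now reads only the quantization config entry of the Hugging Face model config. A rough sketch of that lookup, assuming QUANTIZATION_CONFIG_NAME resolves to the key "quantization_config" (its value is not shown in this diff) and using a locally constructed config in place of AutoConfig.from_pretrained:

from transformers import PretrainedConfig

QUANTIZATION_CONFIG_NAME = "quantization_config"  # assumed value of the imported constant

# Stand-in for a loaded model config; the dict contents are illustrative only.
config = PretrainedConfig(
    quantization_config={"quant_method": "compressed-tensors", "format": "pack-quantized"}
)

compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
print(compression_config)  # the dict above; None would mean the model is not compressed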
@@ -267,7 +266,10 @@ def compress(
         state_dict = model.state_dict()

         compressed_state_dict = state_dict
+
+        # submodule name to q_args
         quantized_modules_to_args = map_modules_to_quant_args(model)
+
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
                 state_dict, names_to_scheme=quantized_modules_to_args
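The added comment documents that map_modules_to_quant_args produces the "submodule name to q_args" mapping later passed as names_to_scheme. A hypothetical sketch of how such a mapping could be built with plain torch; map_modules_to_quant_args_sketch and the _Scheme stand-in are assumptions, not the library's implementation.

import torch

def map_modules_to_quant_args_sketch(model: torch.nn.Module) -> dict:
    # Collect weight quantization args keyed by submodule name.
    mapping = {}
    for name, submodule in model.named_modules():
        scheme = getattr(submodule, "quantization_scheme", None)
        weights_args = getattr(scheme, "weights", None)
        if weights_args is not None:
            mapping[name] = weights_args
    return mapping

# Toy usage: tag one Linear layer with a fake scheme and collect it.
class _Scheme:
    weights = {"num_bits": 4, "symmetric": True}

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
model[0].quantization_scheme = _Scheme()
print(map_modules_to_quant_args_sketch(model))  # {'0': {'num_bits': 4, 'symmetric': True}}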
