Naive Run Compressed Support #109

Merged
merged 32 commits into from
Aug 30, 2024
Changes from 31 commits
Commits (32)
eba80e0 fix constant def (Jun 4, 2024)
3f558a3 fix constant def (Jun 4, 2024)
3ef57b2 fix config load from dict (Jun 26, 2024)
b8ade2f fix assignment of parameters on meta device (Jun 28, 2024)
e357ff6 Merge branch 'main' into rename_config (horheynm, Jun 28, 2024)
27e7d2b initial commit (Jul 2, 2024)
c55b2e0 refactor compressor (Jul 4, 2024)
99d3d2d fix circular imports (Jul 10, 2024)
d9c46a4 Merge branch 'rename_config' of github.com:neuralmagic/compressed-ten… (Jul 10, 2024)
c8d526f Merge branch 'main' into rename_config (Jul 10, 2024)
e6c1f18 Merge branch 'rename_config' into sa/naive_run_compressed (Jul 10, 2024)
be6f58c fixes for hfquantizer (Jul 10, 2024)
fc08911 fix tests (Jul 10, 2024)
afa963a update to compressed state (Jul 10, 2024)
06a2eb9 update imports (Jul 11, 2024)
d61e1c0 Merge branch 'main' into sa/naive_run_compressed (Aug 6, 2024)
7acdf2f fix rebase errors (Aug 6, 2024)
c335bb9 fixing tests (Aug 6, 2024)
7789554 fixes (Aug 7, 2024)
8dff9f5 fix input compression (Aug 7, 2024)
50357ab Merge branch 'main' into sa/naive_run_compressed (Aug 8, 2024)
67596da style (Aug 8, 2024)
d3cc494 docstrings and cleanup (Aug 8, 2024)
1920dcc Merge branch 'main' into sa/naive_run_compressed (Aug 12, 2024)
5bdb7d3 clarity comments (Aug 15, 2024)
eede715 Merge branch 'main' into sa/naive_run_compressed (Aug 21, 2024)
ca4fa3e Merge branch 'main' into sa/naive_run_compressed (Aug 23, 2024)
db8cec9 Merge branch 'main' into sa/naive_run_compressed (Aug 27, 2024)
e5afcd6 PR comments (Aug 27, 2024)
9d8cf80 fix shape (Aug 27, 2024)
8de28ba fix packed: (Aug 27, 2024)
45609ff wrap linear (Aug 30, 2024)
205 changes: 197 additions & 8 deletions src/compressed_tensors/compressors/base.py
@@ -12,20 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Generator, Tuple, Union
import logging
from typing import Dict, Generator, Optional, Tuple, Union

import torch
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationConfig
from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
from compressed_tensors.registry import RegistryMixin
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from safetensors import safe_open
from torch import Tensor
from torch.nn.modules import Module
from tqdm import tqdm


_LOGGER: logging.Logger = logging.getLogger(__name__)

__all__ = ["Compressor"]


class Compressor(RegistryMixin):
"""
Base class representing a model compression algorithm
Base class representing a model compression algorithm. Each child class should
implement compression_param_info, compress_weight and decompress_weight.

Compressors support compressing/decompressing a full module state dict or a single
quantized PyTorch leaf module.

Model Load Lifecycle (run_compressed=False):
- ModelCompressor.decompress()
- apply_quantization_config()
- Compressor.decompress()
- Compressor.decompress_weight()

Model Save Lifecycle:
- ModelCompressor.compress()
- Compressor.compress()
- Compressor.compress_weight()

Module Lifecycle (run_compressed=True):
- apply_quantization_config()
- compressed_module = CompressedLinear(module)
- initialize_module_for_quantization()
- Compressor.compression_param_info()
- register_parameters()
- compressed_module.forward()
- compressed_module.decompress()


:param config: config specifying compression parameters
"""
@@ -35,26 +68,182 @@ def __init__(
):
self.config = config

def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
def compression_param_info(
self,
weight_shape: torch.Size,
quantization_args: Optional[QuantizationArgs] = None,
) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
"""
Creates a dictionary of expected shapes and dtypes for each compression
parameter used by the compressor

:param weight_shape: uncompressed weight shape
:param quantization_args: quantization parameters for the weight
:return: dictionary mapping compressed parameter names to shape and dtype
"""
raise NotImplementedError()

def compress(
self,
model_state: Dict[str, Tensor],
names_to_scheme: Dict[str, QuantizationArgs],
**kwargs,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict

:param model_state: state dict of uncompressed model
:param names_to_scheme: quantization args for each quantized weight, needed for
quantize function to calculate bit depth
:return: compressed state dict
"""
raise NotImplementedError()
compressed_dict = {}
weight_suffix = ".weight"
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)

for name, value in tqdm(model_state.items(), desc="Compressing model"):
if name.endswith(weight_suffix):
prefix = name[: -(len(weight_suffix))]
scale = model_state.get(merge_names(prefix, "weight_scale"), None)
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
if scale is not None:
# weight is quantized, compress it
quant_args = names_to_scheme[prefix]
compressed_data = self.compress_weight(
weight=value,
scale=scale,
zero_point=zp,
quantization_args=quant_args,
)
for key, value in compressed_data.items():
compressed_dict[merge_names(prefix, key)] = value
else:
compressed_dict[name] = value.to("cpu")
elif name.endswith("zero_point") and torch.all(value == 0):
# all zero_points are 0, no need to include in
# compressed state_dict
continue
else:
compressed_dict[name] = value.to("cpu")

return compressed_dict
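
To make the key handling in compress() concrete, the comment-only trace below follows one quantized layer through the loop above. It assumes a compressor whose compress_weight() returns "weight_packed" and "weight_shape" entries; the layer name and parameter keys are placeholders, not taken from this PR.

# Illustrative trace of one quantized layer through compress(); names are placeholders.
#
# input state_dict entries:
#   "model.layers.0.mlp.down_proj.weight"            -> has a matching weight_scale, so it is compressed
#   "model.layers.0.mlp.down_proj.weight_scale"      -> copied through unchanged (moved to CPU)
#   "model.layers.0.mlp.down_proj.weight_zero_point" -> skipped entirely if every element is zero
#
# resulting compressed_dict entries (prefix + key joined via merge_names):
#   "model.layers.0.mlp.down_proj.weight_packed"
#   "model.layers.0.mlp.down_proj.weight_shape"
#   "model.layers.0.mlp.down_proj.weight_scale"
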

def decompress(
self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
self,
path_to_model_or_tensors: str,
names_to_scheme: Dict[str, QuantizationArgs],
device: str = "cpu",
) -> Generator[Tuple[str, Tensor], None, None]:
"""
Reads a compressed state dict located at path_to_model_or_tensors
and returns a generator for sequentially decompressing back to a
dense state dict

:param model_path: path to compressed safetensors model (directory with
one or more safetensors files) or compressed tensors file
:param path_to_model_or_tensors: path to compressed safetensors model (directory
with one or more safetensors files) or compressed tensors file
:param names_to_scheme: quantization args for each quantized weight
:param device: optional device to load intermediate weights into
:return: generator of decompressed weight names and tensors
"""
weight_mappings = get_nested_weight_mappings(
path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
)
for weight_name in weight_mappings.keys():
weight_data = {}
for param_name, safe_path in weight_mappings[weight_name].items():
full_name = merge_names(weight_name, param_name)
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)

if "weight_scale" in weight_data:
quant_args = names_to_scheme[weight_name]
decompressed = self.decompress_weight(
compressed_data=weight_data, quantization_args=quant_args
)
yield merge_names(weight_name, "weight"), decompressed

def compress_weight(
self,
weight: Tensor,
scale: Tensor,
zero_point: Optional[Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
quantization_args: Optional[QuantizationArgs] = None,
) -> Dict[str, torch.Tensor]:
"""
Compresses a single uncompressed weight

:param weight: uncompressed weight tensor
:param scale: quantization scale for weight
:param zero_point: quantization zero point for weight
:param g_idx: optional mapping from column index to group index
:param quantization_args: quantization parameters for weight
:return: dictionary of compressed weight data
"""
raise NotImplementedError()

def decompress_weight(
self,
compressed_data: Dict[str, Tensor],
quantization_args: Optional[QuantizationArgs] = None,
) -> torch.Tensor:
"""
Decompresses a single compressed weight

:param compressed_data: dictionary of data needed for decompression
:param quantization_args: quantization parameters for the weight
:return: tensor of the decompressed weight
"""
raise NotImplementedError()

def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
"""
Compresses a single quantized leaf PyTorch module. If the module is not
quantized, this function has no effect.

:param module: PyTorch module to compress
:return: dictionary of compressed weight data, or None if module is not
quantized
"""
if not hasattr(module, "quantization_scheme"):
return None # module is not quantized
quantization_scheme = module.quantization_scheme
if not hasattr(quantization_scheme, "weights"):
return None # weights are not quantized

quantization_args = quantization_scheme.weights
weight = getattr(module, "weight", None)
weight_scale = getattr(module, "weight_scale", None)
weight_zero_point = getattr(module, "weight_zero_point", None)

return self.compress_weight(
weight=weight,
scale=weight_scale,
zero_point=weight_zero_point,
quantization_args=quantization_args,
)

def decompress_module(self, module: Module):
"""
Decompresses a single compressed leaf PyTorch module. If the module is not
quantized, this function has no effect.

:param module: PyTorch module to decompress
:return: tensor of the decompressed weight, or None if module is not quantized
"""
if not hasattr(module, "quantization_scheme"):
return None # module is not quantized
quantization_scheme = module.quantization_scheme
if not hasattr(quantization_scheme, "weights"):
return None # weights are not quantized

quantization_args = quantization_scheme.weights
compressed_data = {}
for name, parameter in module.named_parameters():
compressed_data[name] = parameter

return self.decompress_weight(
compressed_data=compressed_data, quantization_args=quantization_args
)
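
The sketch below ties the state-dict and module entry points above to the lifecycles described in the class docstring. It is a usage illustration under assumptions: the registry name "pack-quantized", the load_from_registry helper inherited from RegistryMixin, and the surrounding variables (model, names_to_scheme, quantized_linear) are placeholders, not code from this PR.

# Illustrative usage sketch; registry key and variable names are assumptions.
from compressed_tensors.compressors import Compressor

compressor = Compressor.load_from_registry("pack-quantized")

# Save lifecycle: dense state dict -> compressed state dict
compressed_sd = compressor.compress(model.state_dict(), names_to_scheme=names_to_scheme)

# Load lifecycle (run_compressed=False): stream decompressed weights back in
dense_state_dict = {}
for name, tensor in compressor.decompress(
    "path/to/compressed_model", names_to_scheme=names_to_scheme
):
    dense_state_dict[name] = tensor

# Module lifecycle (run_compressed=True): operate on a single quantized leaf module
packed = compressor.compress_module(quantized_linear)         # None if the module is not quantized
dense_weight = compressor.decompress_module(quantized_linear)  # None if the module is not quantized
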
8 changes: 8 additions & 0 deletions src/compressed_tensors/compressors/model_compressor.py
@@ -176,6 +176,9 @@ def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
if hasattr(compression_config, SPARSITY_CONFIG_NAME):
# for loaded HFQuantizer config
return getattr(compression_config, SPARSITY_CONFIG_NAME)
if SPARSITY_CONFIG_NAME in compression_config:
# for loaded HFQuantizer config from dict
return compression_config[SPARSITY_CONFIG_NAME]

# SparseAutoModel format
return compression_config.get(SPARSITY_CONFIG_NAME, None)
@@ -189,6 +192,10 @@ def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
# for loaded HFQuantizer config
return getattr(compression_config, QUANTIZATION_CONFIG_NAME)

if QUANTIZATION_CONFIG_NAME in compression_config:
# for loaded HFQuantizer config from dict
return compression_config[QUANTIZATION_CONFIG_NAME]

# SparseAutoModel format
quantization_config = deepcopy(compression_config)
quantization_config.pop(SPARSITY_CONFIG_NAME, None)
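
As a hedged aside, the example below shows the plain-dict input shape these parsers handle after this change, assuming QUANTIZATION_CONFIG_NAME and SPARSITY_CONFIG_NAME resolve to the "quantization_config" and "sparsity_config" keys and that both parsers are static methods as in this file; the field values are placeholders.

# Illustrative only; constant names, static-method call style, and field values are assumptions.
from compressed_tensors.compressors import ModelCompressor

compression_config = {
    "quantization_config": {"quant_method": "compressed-tensors", "format": "pack-quantized"},
    "sparsity_config": {"format": "sparse-bitmask"},
}

quant_cfg = ModelCompressor.parse_quantization_config(compression_config)
# -> {"quant_method": "compressed-tensors", "format": "pack-quantized"}
sparsity_cfg = ModelCompressor.parse_sparsity_config(compression_config)
# -> {"format": "sparse-bitmask"}
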
@@ -234,6 +241,7 @@ def compress(
compressed_state_dict = self.quantization_compressor.compress(
state_dict, names_to_scheme=quantized_modules_to_args
)
self.quantization_config.quantization_status = QuantizationStatus.COMPRESSED

if self.sparsity_compressor is not None:
compressed_state_dict = self.sparsity_compressor.compress(