Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin' into kylesayrs/upstream-candidates
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Nov 19, 2024
2 parents 0d23183 + a43dad2 commit 82235b3
Show file tree
Hide file tree
Showing 38 changed files with 567 additions and 1,746 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/trigger-all.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ jobs:
test_configs: '[{"python":"3.11.4","label":"ubuntu-22.04","timeout":"40"},
{"python":"3.10.12","label":"ubuntu-20.04","timeout":"40"},
{"python":"3.9.17","label":"k8s-a100-solo","timeout":"40"},
{"python":"3.8.17","label":"k8s-a100-duo","timeout":"40"}]'
{"python":"3.12.6","label":"k8s-a100-duo","timeout":"40"}]'

secrets: inherit
62 changes: 60 additions & 2 deletions src/compressed_tensors/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from enum import Enum, unique
from typing import List, Optional

from compressed_tensors.registry import RegistryMixin
from pydantic import BaseModel


__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]


@unique
class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
Expand All @@ -32,6 +33,63 @@ class CompressionFormat(Enum):
marlin_24 = "marlin-24"


@unique
class SparsityStructure(Enum):
    """
    An enumeration to represent different sparsity structures.

    Attributes
    ----------
    TWO_FOUR : str
        Represents a 2:4 sparsity structure.
    ZERO_ZERO : str
        Represents a 0:0 sparsity structure.
    UNSTRUCTURED : str
        Represents an unstructured sparsity structure.

    Examples
    --------
    >>> SparsityStructure('2:4')
    <SparsityStructure.TWO_FOUR: '2:4'>
    >>> SparsityStructure('unstructured')
    <SparsityStructure.UNSTRUCTURED: 'unstructured'>
    >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
    True
    >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
    True
    >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
    True
    >>> SparsityStructure('invalid')
    Traceback (most recent call last):
    ...
    ValueError: invalid is not a valid SparsityStructure
    """

    TWO_FOUR = "2:4"
    UNSTRUCTURED = "unstructured"
    ZERO_ZERO = "0:0"

    def __new__(cls, value):
        # Canonicalize member values to lowercase at class-definition time so
        # that case-insensitive lookups in _missing_ compare against a single
        # canonical form. None is passed through untouched (it is never a
        # member value; it is mapped to UNSTRUCTURED in _missing_).
        obj = object.__new__(cls)
        obj._value_ = value.lower() if value is not None else value
        return obj

    @classmethod
    def _missing_(cls, value):
        """
        Resolve values that did not match a member exactly.

        :param value: the raw value passed to the enum constructor
        :return: the matching member; None maps to UNSTRUCTURED
        :raises ValueError: if the value does not correspond to any member
        """
        # None means "no sparsity structure specified" -> unstructured
        if value is None:
            return cls.UNSTRUCTURED
        # Only strings can match case-insensitively. Non-string values
        # (e.g. ints) fall through to the ValueError below instead of
        # raising an AttributeError on .lower().
        if isinstance(value, str):
            lowered = value.lower()
            for member in cls:
                if member.value == lowered:
                    return member
        raise ValueError(f"{value} is not a valid {cls.__name__}")


class SparsityCompressionConfig(RegistryMixin, BaseModel):
"""
Base data class for storing sparsity compression parameters
Expand Down
1 change: 0 additions & 1 deletion src/compressed_tensors/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,3 @@
from .quant_config import *
from .quant_scheme import *
from .lifecycle import *
from .cache import QuantizedKVParameterCache
201 changes: 0 additions & 201 deletions src/compressed_tensors/quantization/cache.py

This file was deleted.

2 changes: 0 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
# flake8: noqa
# isort: skip_file

from .calibration import *
from .forward import *
from .frozen import *
from .initialize import *
from .compressed import *
from .apply import *
Expand Down
17 changes: 1 addition & 16 deletions src/compressed_tensors/quantization/lifecycle/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,9 @@

import torch
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.lifecycle.calibration import (
set_module_for_calibration,
)
from compressed_tensors.quantization.lifecycle.compressed import (
compress_quantized_weights,
)
from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
from compressed_tensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
Expand Down Expand Up @@ -233,6 +229,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
:param model: model to apply quantization to
:param status: status to update the module to
"""

current_status = infer_quantization_status(model)

if status >= QuantizationStatus.INITIALIZED > current_status:
Expand All @@ -243,18 +240,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
)
)

if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
# only quantize weights up front when our end goal state is calibration,
# weight quantization parameters are already loaded for frozen/compressed
quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
model.apply(
lambda module: set_module_for_calibration(
module, quantize_weights_upfront=quantize_weights_upfront
)
)
if current_status < status >= QuantizationStatus.FROZEN > current_status:
model.apply(freeze_module_quantization)

if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
model.apply(compress_quantized_weights)

Expand Down
Loading

0 comments on commit 82235b3

Please sign in to comment.