[Kernel] compressed-tensors marlin 24 support #5435

Merged 5 commits on Jun 17, 2024
Changes from 2 commits
22 changes: 19 additions & 3 deletions tests/quantization/test_compressed_tensors.py
@@ -8,7 +8,7 @@

from vllm import SamplingParams
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW4A16,
CompressedTensors24, CompressedTensorsLinearMethod, CompressedTensorsW4A16,
CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)


@@ -51,8 +51,7 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):

def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
with vllm_runner(model_path, enforce_eager=True,
dtype=torch.float16) as llm:
with vllm_runner(model_path, dtype=torch.float16) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]

@@ -83,3 +82,20 @@ def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
assert qkv_proj.weight_packed.pack_factor == 8


def test_compressed_tensors_w4a16_marlin24(vllm_runner):
model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
with vllm_runner(model_path) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj

assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
assert qkv_proj.weight_packed.dtype is torch.int32

sampling_params = SamplingParams()
output = llm.generate("Hello world!", sampling_params=sampling_params)
assert output
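
For context, a minimal end-to-end sketch of exercising the new marlin-24 path through vLLM's offline API, mirroring the new test above (checkpoint name taken from the test; sampling settings are illustrative):

from vllm import LLM, SamplingParams

# 2:4-sparse, 4-bit compressed-tensors checkpoint used by the test above.
llm = LLM(model="nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t")

# Illustrative sampling settings; the test itself uses default SamplingParams().
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

outputs = llm.generate(["Hello world!"], sampling_params)
print(outputs[0].outputs[0].text)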
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -7,17 +7,20 @@
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsW4A16,
CompressedTensors24, CompressedTensorsScheme, CompressedTensorsW4A16,
CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
CompressionFormat, QuantizationArgs, QuantizationStrategy,
find_first_name_or_class_match)


class CompressedTensorsConfig(QuantizationConfig):

def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]):
def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
quant_format: str):
self.ignore = ignore
self.layer_quant_details = layer_quant_details
self.quant_format = quant_format

def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
@@ -46,6 +49,7 @@ def get_quant_method(
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
layer_quant_details: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None)
quant_format: str = config.get("format", None)

# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
@@ -69,7 +73,9 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
except Exception:
layer_quant_details[target]["input_activations"] = None

return cls(layer_quant_details=layer_quant_details, ignore=ignore)
return cls(layer_quant_details=layer_quant_details,
ignore=ignore,
quant_format=quant_format)

@classmethod
def get_config_filenames(cls) -> List[str]:
@@ -110,17 +116,25 @@ def _get_schema(self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme":

if self._is_w4a16(weight_quant, input_quant):
return CompressedTensorsW4A16(num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size)

if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8StaticTensor()

if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8DynamicToken()

raise NotImplementedError("Scheme not supported.")
if self.quant_format == CompressionFormat.marlin_24.value:
return CompressedTensors24(strategy=weight_quant.strategy,
num_bits=weight_quant.num_bits,
group_size=weight_quant.group_size)
if self.quant_format == CompressionFormat.pack_quantized.value:
return CompressedTensorsW4A16(
num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size)

if self.quant_format == CompressionFormat.int_quantized.value:
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8StaticTensor()

if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8DynamicToken()

raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")

def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":

@@ -165,9 +179,9 @@ def create_weights(self, layer: torch.nn.Module,
scheme = self.quantization_config.get_scheme(layer=layer)
scheme.create_weights(
layer=layer,
input_size=input_size,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
input_size=input_size,
output_size=output_size,
params_dtype=params_dtype,
weight_loader=weight_loader)
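
To make the new quant_format plumbing concrete, here is a hedged sketch of the kind of quantization config from_config now consumes; the "format" and "ignore" keys are read directly above, while the field names inside config_groups are illustrative assumptions based on the attributes _get_schema reads (num_bits, strategy, group_size), not an authoritative checkpoint schema:

# Hypothetical quantization config for a 2:4-sparse 4-bit checkpoint (illustrative only).
quantization_config = {
    "format": "marlin-24",     # CompressionFormat.marlin_24.value -> CompressedTensors24
    "ignore": ["lm_head"],     # layers left unquantized (illustrative)
    "config_groups": {
        "group_0": {
            # Parsed into QuantizationArgs; num_bits/strategy/group_size are the
            # fields _get_schema reads off weight_quant.
            "weights": {"num_bits": 4, "strategy": "group", "group_size": 128},
            "input_activations": None,  # weight-only scheme
        },
    },
}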
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -2,6 +2,7 @@
from .compressed_tensors_unquantized import ( # noqa: F401
CompressedTensorsUnquantized)
from .compressed_tensors_w4a16 import CompressedTensorsW4A16 # noqa: F401
from .compressed_tensors_w4a16_24 import CompressedTensors24 # noqa: F401
from .compressed_tensors_w8a8_dynamictoken import ( # noqa: F401, E501
CompressedTensorsW8A8DynamicToken)
from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@@ -0,0 +1,134 @@
from typing import Callable, List, Optional

import torch
from torch.nn import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.utils import set_weight_attrs

__all__ = ["CompressedTensors24"]


class CompressedTensors24(CompressedTensorsScheme):

def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None):
self.strategy = strategy
self.group_size = group_size
self.num_bits = num_bits
self.tile_size = 16

if self.strategy == "group" and self.group_size is None:
raise ValueError(
"group_size must be given when using strategy group")

def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

pack_factor = 32 // self.num_bits
output_size_per_partition = sum(output_partition_sizes)

qweight = Parameter(
torch.empty(
input_size_per_partition // self.tile_size // 2,
output_size_per_partition * self.tile_size // pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": pack_factor,
"marlin_tile_size": self.tile_size,
"weight_loader": weight_loader
},
)

layer.register_parameter("weight_packed", qweight)

input_groups = (1 if self.group_size is None else
input_size_per_partition // self.group_size)

scales = Parameter(
torch.empty(
input_groups,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"output_dim": 1,
"input_dim": None if input_groups == 1 else 0,
"weight_loader": weight_loader
},
)
layer.register_parameter("scale_packed", scales)

weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
requires_grad=False)

layer.register_parameter("weight_shape", weight_shape)
set_weight_attrs(weight_shape, {"weight_loader": weight_loader})

meta = Parameter(
torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
dtype=torch.int16,
),
requires_grad=False,
)
set_weight_attrs(
meta,
{
"input_dim": 0,
"packed_dim": 1,
"pack_factor": 1,
"output_dim": 1,
"marlin_tile_size": 2,
"weight_loader": weight_loader
},
)
layer.register_parameter("meta", meta)

max_workspace_size = (
output_size_per_partition //
GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
requires_grad=False)
layer.workspace = workspace

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
qweight = layer.weight_packed
meta = layer.meta
scales = layer.scale_packed
workspace = layer.workspace

x_2d = x.view(-1, x.shape[-1])

size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = scales.shape[1]

output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace, self.num_bits, size_m,
size_n, size_k)

output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
return output
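
As a sanity check on the packed layouts allocated above, a small standalone sketch that reproduces the parameter shapes create_weights would register for an assumed K=4096, N=4096, 4-bit, group_size=128 partition (sizes are purely illustrative):

# Illustrative shape check for the Marlin 2:4 layout above (assumed sizes).
input_size_per_partition = 4096     # K
output_size_per_partition = 4096    # N
num_bits, group_size, tile_size = 4, 128, 16
pack_factor = 32 // num_bits        # 8 x 4-bit values per int32

# weight_packed: 2:4 sparsity halves K; Marlin tiles and packs along N.
qweight_shape = (
    input_size_per_partition // tile_size // 2,             # 128
    output_size_per_partition * tile_size // pack_factor,   # 8192
)

# scale_packed: one row of scales per group along K.
scales_shape = (input_size_per_partition // group_size,     # 32
                output_size_per_partition)                  # 4096

# meta: int16 2:4 sparsity metadata, packed with marlin_tile_size=2.
meta_shape = (input_size_per_partition // 8 // 2 // 2,      # 128
              output_size_per_partition * 2)                # 8192

print(qweight_shape, scales_shape, meta_shape)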
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
@@ -6,6 +6,14 @@
from torch.nn import Module


class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
int_quantized = "int-quantized"
pack_quantized = "pack-quantized"
marlin_24 = "marlin-24"


class QuantizationType(str, Enum):
"""
Enum storing quantization type options