This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

lazy import CompressedTensorsW8A8StaticTensor #220

Merged
merged 3 commits on May 1, 2024
10 changes: 10 additions & 0 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -6,9 +6,19 @@

static inline __device__ int8_t float_to_int8_rn(float x)
{
#ifdef USE_ROCM
  float dst;
  // Round to nearest even
  asm volatile("v_rndne_f32 %0, %1;\n" : "=r"(dst) : "v"(x));
  // Saturate to the int8 range [-128, 127]
  dst = dst < -128.0f ? -128.0f : dst;
  dst = dst > 127.0f ? 127.0f : dst;
  return static_cast<int8_t>(dst);
#else
  // CUDA: round-to-nearest-even and saturate to s8 in a single PTX conversion
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

namespace vllm {
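For reference, the behaviour both branches implement is round-to-nearest-even followed by saturation to the int8 range. A host-side illustration of those semantics in plain Python (the function name `float_to_int8_rn_ref` is ours, not part of the kernel):

```python
def float_to_int8_rn_ref(x: float) -> int:
    """Reference semantics for float_to_int8_rn: round half to even,
    then saturate to [-128, 127], mirroring v_rndne_f32 / cvt.rni.sat.s8.f32."""
    rounded = round(x)  # Python's round() uses round-half-to-even
    return max(-128, min(127, rounded))

assert float_to_int8_rn_ref(2.5) == 2      # ties round to the even neighbour
assert float_to_int8_rn_ref(300.0) == 127  # saturates instead of wrapping
```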
@@ -6,8 +6,7 @@
from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
    QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme, CompressedTensorsUnquantized,
    CompressedTensorsW8A8StaticTensor)
    CompressedTensorsScheme, CompressedTensorsUnquantized)


class CompressedTensorsConfig(QuantizationConfig):
@@ -80,7 +79,11 @@ def _get_schema(self, weight_quant: Dict, input_quant: Dict):
        is_tensor = weight_strategy == input_strategy == "tensor"
        is_symmetric = weight_symmetric and input_symmetric

        if is_8_bits and is_tensor and is_symmetric:
        if is_8_bits and is_tensor and is_symmetric and \
                torch.cuda.is_available():
            # CompressedTensorsW8A8StaticTensor only supports CUDA path for now.
            from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (  # noqa: E501
                CompressedTensorsW8A8StaticTensor)
            return CompressedTensorsW8A8StaticTensor(
                fake_quant=self.fake_quant)
        raise NotImplementedError(
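The change above moves the `CompressedTensorsW8A8StaticTensor` import from module scope into the branch that actually constructs it, so merely importing the config module no longer loads the CUDA-only scheme. A minimal sketch of the same deferred-import pattern (the `build_scheme` function and `use_cuda_scheme` flag are hypothetical, for illustration only):

```python
def build_scheme(use_cuda_scheme: bool):
    if use_cuda_scheme:
        # Import inside the branch: the CUDA-backed module is only loaded
        # when this scheme is actually selected at runtime.
        from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (  # noqa: E501
            CompressedTensorsW8A8StaticTensor)
        return CompressedTensorsW8A8StaticTensor(fake_quant=True)
    raise NotImplementedError("No scheme available for this quantization config")
```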