diff --git a/csrc/ops.h b/csrc/ops.h
index 4952e826ec8ac..06b60e748886f 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -97,6 +97,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor const& scale);
 
+void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                               torch::Tensor& scales);
+
 void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                      torch::Tensor lookup_table);
 
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index cdbec4a34d77f..547823aa1b04e 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -70,6 +70,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
           "Compute int8 quantized tensor for given scaling factor");
 
+  ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant,
+          "Compute int8 quantized tensor and scaling factor");
+
   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
   cache_ops.def("swap_blocks", &swap_blocks,
diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
index 11baa5d414c19..280b0327111da 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -3,6 +3,7 @@
 #include <cmath>
 
 #include "../../dispatch_utils.h"
+#include "../../reduction_utils.cuh"
 
 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
@@ -27,17 +28,48 @@ namespace vllm {
 
 template <typename scalar_t, typename scale_type>
 __global__ void static_scaled_int8_quant_kernel(
-    const scalar_t* __restrict__ input, int8_t* __restrict__ out,
-    const scale_type* scale_ptr, const int hidden_size) {
-  const int tid = threadIdx.x;
-  const int token_idx = blockIdx.x;
-  scale_type scale = *scale_ptr;
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type const* scale_ptr, const int hidden_size) {
+  int const tid = threadIdx.x;
+  int const token_idx = blockIdx.x;
+  scale_type const scale = *scale_ptr;
 
   for (int i = tid; i < hidden_size; i += blockDim.x) {
-    out[token_idx * hidden_size + i] =
-        float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
+    out[token_idx * hidden_size + i] = float_to_int8_rn(
+        static_cast<float>(input[token_idx * hidden_size + i]) / scale);
   }
 }
+
+template <typename scalar_t, typename scale_type>
+__global__ void dynamic_scaled_int8_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type* scale, const int hidden_size) {
+  int const tid = threadIdx.x;
+  int const token_idx = blockIdx.x;
+  float absmax_val = 0.0f;
+  float const zero = 0.0f;
+
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    float val = static_cast<float>(input[token_idx * hidden_size + i]);
+    val = val > zero ? val : -val;
+    absmax_val = val > absmax_val ? val : absmax_val;
+  }
+
+  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
+  __shared__ float block_absmax_val;
+  if (tid == 0) {
+    block_absmax_val = block_absmax_val_maybe;
+    scale[token_idx] = block_absmax_val / 127.0f;
+  }
+  __syncthreads();
+
+  float const tmp_scale = 127.0f / block_absmax_val;
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    out[token_idx * hidden_size + i] = float_to_int8_rn(
+        static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
+  }
+}
+
 }  // namespace vllm
 
 void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
@@ -47,10 +79,10 @@ void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
 
-  int hidden_size = input.size(-1);
-  int num_tokens = input.numel() / hidden_size;
-  dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
+  int const hidden_size = input.size(-1);
+  int const num_tokens = input.numel() / hidden_size;
+  dim3 const grid(num_tokens);
+  dim3 const block(std::min(hidden_size, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -60,3 +92,24 @@ void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                                          scale.data_ptr<float>(), hidden_size);
       });
 }
+
+void dynamic_scaled_int8_quant(
+    torch::Tensor& out,          // [..., hidden_size]
+    torch::Tensor const& input,  // [..., hidden_size]
+    torch::Tensor& scales) {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
+
+  int const hidden_size = input.size(-1);
+  int const num_tokens = input.numel() / hidden_size;
+  dim3 const grid(num_tokens);
+  dim3 const block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
+        vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
+            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
+                                         out.data_ptr<int8_t>(),
+                                         scales.data_ptr<float>(), hidden_size);
+      });
+}
diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh
index 9af4aae516151..08063356012b8 100644
--- a/csrc/reduction_utils.cuh
+++ b/csrc/reduction_utils.cuh
@@ -21,29 +21,47 @@
 #include "cuda_compat.h"
 
 namespace vllm {
+
+namespace detail {
+
+template <typename T>
+__inline__ __device__ T _max(T a, T b) {
+  return max(a, b);
+}
+
+template <typename T>
+__inline__ __device__ T _sum(T a, T b) {
+  return a + b;
+}
+
+}  // namespace detail
+
+template <typename T>
+using ReduceFnType = T (*)(T, T);
+
+// Helper function to return the next largest power of 2
+static constexpr int _nextPow2(unsigned int num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
+
 template <typename T, int numLanes = WARP_SIZE>
-__inline__ __device__ T warpReduceSum(T val) {
+__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
   static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
                 "numLanes is not a positive power of 2!");
   static_assert(numLanes <= WARP_SIZE);
 #pragma unroll
   for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
-    val += VLLM_SHFL_XOR_SYNC(val, mask);
-  return val;
-}
+    val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));
 
-// Helper function to return the next largest power of 2
-static constexpr int _nextPow2(unsigned int num) {
-  if (num <= 1) return num;
-  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+  return val;
 }
 
-/* Calculate the sum of all elements in a block */
 template <typename T, int maxBlockSize = 1024>
-__inline__ __device__ T blockReduceSum(T val) {
+__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
   static_assert(maxBlockSize <= 1024);
   if constexpr (maxBlockSize > WARP_SIZE) {
-    val = warpReduceSum<T>(val);
+    val = warpReduce<T>(val, fn);
     // Calculates max number of lanes that need to participate in the last
     // warpReduce
     constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
@@ -56,12 +74,22 @@ __inline__ __device__ T blockReduceSum(T val) {
 
     val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
                                                         : (T)(0.0f);
-    val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
+    val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
   } else {
     // A single warpReduce is equal to blockReduce
-    val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
+    val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
   }
   return val;
 }
 
+template <typename T, int maxBlockSize = 1024>
+__inline__ __device__ T blockReduceMax(T val) {
+  return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
+}
+
+template <typename T, int maxBlockSize = 1024>
+__inline__ __device__ T blockReduceSum(T val) {
+  return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
+}
+
 }  // namespace vllm
diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py
index 29890118c93dc..aab7af9d2cbf6 100644
--- a/tests/kernels/test_int8_quant.py
+++ b/tests/kernels/test_int8_quant.py
@@ -4,27 +4,59 @@
 from vllm._C import ops
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
+HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
+                8193]  # Arbitrary values for testing
 NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
 SEEDS = [0]
 SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
+                                   dtype: torch.dtype, seed: int) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    int8_traits = torch.iinfo(torch.int8)
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
+
+    x_token_max, _ = x.max(dim=1)
+    x_token_max = x_token_max.to(dtype=torch.float32)
+    scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
+                                                      dtype=torch.float32)
+    torch_out = (x / scales).round().clamp(int8_traits.min,
+                                           int8_traits.max).to(torch.int8)
+
+    ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
+    scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
+    ops.dynamic_scaled_int8_quant(ops_out, x, scales_out)
+
+    assert torch.allclose(scales_out, scales)
+    assert torch.allclose(torch_out, ops_out,
+                          atol=1)  # big atol to account for rounding errors
+
+
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("scale", SCALE)
 @torch.inference_mode()
-def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
-               seed: int, scale: float) -> None:
+def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
+                                  dtype: torch.dtype, seed: int,
+                                  scale: float) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    int8_traits = torch.iinfo(torch.int8)
+
     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
 
-    out1 = (x / scale).round().clamp(
-        torch.iinfo(torch.int8).min,
-        torch.iinfo(torch.int8).max).to(torch.int8)
+    out1 = (x / scale).round().clamp(int8_traits.min,
+                                     int8_traits.max).to(torch.int8)
     out2 = torch.empty_like(x, dtype=torch.int8)
     scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
 
diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index b83286992da3d..8b48f418fe49f 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -6,7 +6,8 @@
 import torch
 
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
-    CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
+    CompressedTensorsW8A8StaticTensor)
 
 
 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@@ -34,3 +35,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
     assert qkv_proj.weight_scale.shard_splitter is not None
     assert qkv_proj.weight_scale.logical_widths is not None
     assert qkv_proj.input_scale.dtype is torch.float32
+
+
+def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
+    model_path = "nm-testing/tinyllama-one-shot-dynamic-test"
+    llm = vllm_runner(model_path,
+                      quantization="sparseml",
+                      enforce_eager=True,
+                      dtype=torch.float16)
+    model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
+    layer = model.model.layers[0]
+
+    qkv_proj = layer.self_attn.qkv_proj
+
+    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
+    assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
+    assert qkv_proj.weight.dtype is torch.int8
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 8a6f6d96d81f3..ddcd132079e30 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -264,21 +264,33 @@ def scaled_fp8_quant(
 
 
 # int8
-def static_scaled_int8_quant(input: torch.Tensor,
-                             scale: torch.Tensor) -> torch.Tensor:
+def scaled_int8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
-    Quantize the input tensor to int8 and return the quantized tensor.
+    Quantize the input tensor to int8 and return the quantized tensor and scale.
 
     Args:
         input: The input tensor to be quantized to int8.
-        scale: Scaling factor for the int8 quantization.
+        scale: Optional scaling factor for the int8 quantization.
+            When not provided, we invoke dynamic-per-token quantization.
 
     Returns:
-        torch.Tensor: Output tensor in int8.
+      Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales.
     """
-    q = torch.empty_like(input, dtype=torch.int8)
-    vllm_ops.static_scaled_int8_quant(q, input, scale)
-    return q
+    output = torch.empty_like(input, dtype=torch.int8)
+    if scale is not None:
+        # static-per-tensor quantization.
+        vllm_ops.static_scaled_int8_quant(output, input, scale)
+        return output, scale
+
+    # dynamic-per-token quantization.
+    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
+                               device=input.device,
+                               dtype=torch.float32)
+    vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales)
+    return output, input_scales
 
 
 # moe
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 19e464bd64325..d2b0ce0dbbf0b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -1,12 +1,16 @@
 from typing import Any, Dict, List, Optional
 
 import torch
+from pydantic import BaseModel
 
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme, CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken,
+    CompressedTensorsW8A8StaticTensor)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
 
 
 class CompressedTensorsConfig(QuantizationConfig):
@@ -47,10 +51,12 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
             targets = quant_config.get("targets")
             for target in targets:
                 layer_quant_details[target] = {}
-                layer_quant_details[target]["weight"] = quant_config.get(
-                    "weights")
-                layer_quant_details[target]["input"] = quant_config.get(
-                    "input_activations")
+                layer_quant_details[target][
+                    "weight"] = QuantizationArgs.parse_obj(
+                        quant_config.get("weights"))
+                layer_quant_details[target][
+                    "input"] = QuantizationArgs.parse_obj(
+                        quant_config.get("input_activations"))
 
         return cls(layer_quant_details=layer_quant_details, ignore=ignore)
 
@@ -58,40 +64,46 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
     def get_config_filenames(cls) -> List[str]:
         return []
 
-    def _get_schema(self, weight_quant: Dict, input_quant: Dict):
-        # TODO: Refactor as additional cases are supported
-
-        weight_bit = weight_quant.get("num_bits")
-        input_bit = input_quant.get("num_bits")
-
-        weight_strategy = weight_quant.get("strategy")
-        input_strategy = input_quant.get("strategy")
-
-        weight_symmetric = weight_quant.get("symmetric")
-        input_symmetric = input_quant.get("symmetric")
+    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
+                               input_quant: BaseModel) -> bool:
+        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
+        is_tensor = (weight_quant.strategy == input_quant.strategy ==
+                     QuantizationStrategy.TENSOR.value)
+        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+        is_static = not weight_quant.dynamic and not input_quant.dynamic
+
+        return is_8_bits and is_tensor and is_symmetric and is_static
+
+    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
+                               input_quant: BaseModel) -> bool:
+        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
+        is_token_tensor = (weight_quant.strategy
+                           == QuantizationStrategy.TENSOR.value) and (
+                               input_quant.strategy
+                               == QuantizationStrategy.TOKEN.value)
+        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+        is_dynamic = not weight_quant.dynamic and input_quant.dynamic
+
+        return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
+
+    def _get_schema(self, weight_quant: BaseModel,
+                    input_quant: BaseModel) -> "CompressedTensorsScheme":
+        if self._is_static_tensor_w8a8(weight_quant, input_quant):
+            return CompressedTensorsW8A8StaticTensor()
 
-        is_8_bits = weight_bit == input_bit == 8
-        is_tensor = weight_strategy == input_strategy == "tensor"
-        is_symmetric = weight_symmetric and input_symmetric
+        if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+            return CompressedTensorsW8A8DynamicToken()
 
-        if is_8_bits and is_tensor and is_symmetric and \
-                torch.cuda.is_available():
-            # CompressedTensorsW8A8StaticTensor only supports CUDA path for
-            # now.
-            return CompressedTensorsW8A8StaticTensor()
-        raise NotImplementedError(
-            "Scheme not supported. Only CUDA, 8-bit static symmtetric "
-            "per tensor quantization is currently supported")
+        raise NotImplementedError("Scheme not supported.")
 
     def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
 
-        # TODO: update with matching function from `compressed_tensors`
-        layer_type_name = None
-        layer_name_class = type(layer).__name__.lower()
-        for target in self.layer_quant_details:
-            if target.lower() in layer_name_class:
-                layer_type_name = target
-                break
+        layer_type_name = find_first_name_or_class_match(
+            name="",
+            module=layer,
+            targets=self.layer_quant_details.keys(),
+            check_contains=True)
+
         if layer_type_name is None:
             raise ValueError(f"Could not matching target for layer {layer}")
 
@@ -117,7 +129,9 @@ def create_weights(self, layer: torch.nn.Module,
                        **extra_weight_attrs):
         """
         Use the CompressedTensorsScheme associated with each layer to create 
-        the necessary parameters for the layer.
+        the necessary parameters for the layer. See LinearMethodBase for param
+        details
+
         """
         weight_loader = extra_weight_attrs.get("weight_loader")
 
@@ -139,7 +153,8 @@ def apply(self,
         """
         Use the output of create_weights and the CompressedTensorsScheme 
         associated with the layer to apply the forward pass with the 
-        layer input.
+        layer input.  See LinearMethodBase for param details
+
         """
 
         if bias is not None:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
index 831905b63e2c9..9a910f061f580 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -1,5 +1,7 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
 from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
+from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
+    CompressedTensorsW8A8DynamicToken)
 from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
     CompressedTensorsW8A8StaticTensor)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
new file mode 100644
index 0000000000000..25b707caeef33
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
@@ -0,0 +1,85 @@
+from typing import Callable, List, Tuple, Union
+
+import torch
+from torch.nn import Parameter
+
+from vllm import _custom_ops as custom_ops
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.utils import set_weight_attrs
+
+__all__ = ["CompressedTensorsW8A8DynamicToken"]
+
+
+class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
+
+    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
+        if isinstance(shard_id, int):
+            return shard_id
+
+        assert isinstance(shard_id, str)
+        qkv_idxs = {"q": 0, "k": 1, "v": 2}
+        assert shard_id in qkv_idxs
+        return qkv_idxs[shard_id]
+
+    def scales_shard_splitter(
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Union[str, int],
+            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        shard_id = self._shard_id_as_int(shard_id)
+        offset = sum(logical_widths[:shard_id])
+        size = logical_widths[shard_id]
+        # update loaded weight with copies for broadcast.
+        loaded_weight = loaded_weight.repeat(size)
+        return param[offset:offset + size], loaded_weight
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        # When the scales have a single value, it is required that they be
+        # on the CPU for performance and CUDA Graphs compatibility. Please
+        # refer to the comment in
+        # CompressedTensorsW8A8StaticTensor::create_weights for further
+        # information.
+        is_tensor_partitioned = len(output_partition_sizes) != 1
+        weight_scale_dim = sum(
+            output_partition_sizes) if is_tensor_partitioned else 1
+
+        weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
+                                      requires_grad=False)
+
+        weight_scale = Parameter(torch.empty(weight_scale_dim,
+                                             dtype=torch.float32),
+                                 requires_grad=False)
+
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                       input_size_per_partition,
+                                       dtype=torch.int8),
+                           requires_grad=False)
+
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        set_weight_attrs(weight, {"weight_loader": weight_loader})
+        set_weight_attrs(weight, {"logical_widths": output_partition_sizes})
+
+        layer.register_parameter("weight_scale", weight_scale)
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+        set_weight_attrs(
+            weight_scale, {
+                "shard_splitter": self.scales_shard_splitter,
+                "logical_widths": output_partition_sizes
+            })
+
+        layer.register_parameter("weight_zero_point", weight_zero_point)
+        set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        weight = layer.weight
+        weight_scale = layer.weight_scale
+
+        x_q, input_scales = custom_ops.scaled_int8_quant(x)
+        return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), input_scales,
+                                               weight_scale, x.dtype)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
index 2dfc6e2b07782..7559fc0f95b24 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
@@ -97,7 +97,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         act_scale = layer.input_scale
 
         # Input quantize
-        x_q = custom_ops.static_scaled_int8_quant(x, act_scale)
+        x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
 
         return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
                                                weight_scale, x.dtype)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
new file mode 100644
index 0000000000000..fcc6649101845
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
@@ -0,0 +1,114 @@
+import re
+from enum import Enum
+from typing import Any, Dict, Iterable, Optional
+
+from pydantic import BaseModel, Field
+from torch.nn import Module
+
+
+class QuantizationType(str, Enum):
+    """
+    Enum storing quantization type options
+    """
+
+    INT = "int"
+    FLOAT = "float"
+
+
+class QuantizationStrategy(str, Enum):
+    """
+    Enum storing quantization strategy options
+    """
+
+    TENSOR = "tensor"
+    CHANNEL = "channel"
+    GROUP = "group"
+    BLOCK = "block"
+    TOKEN = "token"
+
+
+class QuantizationArgs(BaseModel):
+    """
+    User facing arguments used to define a quantization config 
+    for weights or activations
+
+    :param num_bits: quantization bit depth
+    :param type: dtype to quantized to, either int or float
+    :param symmetric: whether or not quantization scale is symmetric
+    :param strategy: string determining the scope of scale/zero-point to apply
+    :param group_size: group length to use for the group strategy
+    :param block_structure: 2d block structure to use for the block 
+    strategy, must be of the format "2x4", "8x16", etc.
+    :param dynamic: set True to perform dynamic quantization -
+        values will not be calibrated during calibration phase, 
+        instead during inference new quantization ranges will be 
+        observed with every sample. Defaults to False for static
+        quantization. Note that enabling dynamic quantization 
+        will change the default observer to a memoryless one
+    """
+
+    num_bits: int = 8
+    type: QuantizationType = QuantizationType.INT
+    symmetric: bool = True
+    group_size: Optional[int] = None
+    strategy: Optional[QuantizationStrategy] = None
+    block_structure: Optional[str] = None
+    dynamic: bool = False
+    observer: str = Field(
+        default="minmax",
+        description=("The class to use to compute the quantization param - "
+                     "scale and zero-point'"),
+    )
+    observer_kwargs: Dict[str, Any] = Field(
+        default_factory=dict,
+        description=
+        ("optional dict of kwargs to be passed directly to torch quantization "
+         "Observers constructor excluding quantization range or symmetry"),
+    )
+
+
+def find_first_name_or_class_match(
+        name: str,
+        module: Module,
+        targets: Iterable[str],
+        check_contains: bool = False) -> Optional[str]:
+    """
+    Helper function to map the quantization details listed in the config 
+    for a given list of targets against each model layer. First uses the
+    layer name to try and find a match. If no name match is found, uses
+    the layer class name. Returns None otherwise.
+
+    :param name: layer name
+    :param module: torch.nn.Module
+    :param targets: list of targets to match the layer against
+    :param check_contains: whether or not to do a substring match
+    """
+
+    return _find_first_match(name, targets) or _find_first_match(
+        module.__class__.__name__, targets, check_contains)
+
+
+def _find_first_match(value: str,
+                      targets: Iterable[str],
+                      check_contains: bool = False) -> Optional[str]:
+    """
+    Returns first element of target that matches value either
+    exactly or as a regex after 're:'. If check_contains is set to True,
+    additionally checks if the target string is contained within the value.
+
+    :param value: string to compare the list of targets against
+    :param targets: list of targets to match the layer against
+    :param check_contains: whether or not to do a substring match
+    """
+
+    for target in targets:
+        if target.startswith("re:"):
+            pattern = target[3:]
+            if re.match(pattern, value):
+                return target
+        elif check_contains:
+            if target.lower() in value.lower():
+                return target
+        elif target == value:
+            return target
+    return None