Support bf16 inputs for GPTQ/Marlin format quantization (#90)
guoqingbao authored Oct 15, 2024
1 parent 1b4b0d4 commit e62b7e1
Showing 19 changed files with 545 additions and 245 deletions.
1 change: 1 addition & 0 deletions .typos.toml
@@ -2,4 +2,5 @@
extend-ignore-identifiers-re = [
"mmaped",
"arange",
"cudaDevAttrMaxSharedMemoryPerBlockOptin",
]
10 changes: 5 additions & 5 deletions README.md
@@ -13,7 +13,7 @@ Efficient, easy-to-use platform for inference and serving local LLMs including a
- Efficient management of key-value cache with PagedAttention.
- Continuous batching.
- `In-situ` quantization
- `GPTQ/Marlin` format quantization
- `GPTQ/Marlin` format quantization (4-bit)

## Development Status

@@ -192,15 +192,15 @@ async def benchmark():
asyncio.run(benchmark())
```

## GPTQ/Marlin quantization
Candle-vllm now supports GPTQ (Marlin kernel); you may supply the `quant` (marlin) and `dtype` (f16) parameters if you have `Marlin` format quantized weights, for example:
## GPTQ/Marlin 4-bit quantization
Candle-vllm now supports GPTQ (Marlin kernel); you may supply the `quant` (marlin) parameter if you have `Marlin` format quantized weights, for example:

```
cargo run --release -- --port 2000 --dtype f16 --weight-path /home/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4-Marlin/ llama3 --quant marlin
```
You may also use `AutoGPTQ` to transform a model to Marlin format by loading the model and supplying `use_marlin=True` in `AutoGPTQ`.
You may also use `AutoGPTQ` to transform a model to Marlin format by loading the (quantized) model, supplying `use_marlin=True` in `AutoGPTQ`, and re-saving it with `save_pretrained` (a sketch follows the note below).

**Note:** only 4-bit GPTQ quantization is supported for the Marlin format at the moment, and the input data type should be `f16` (`--dtype f16`). You also need to rename the transformed Marlin-format weight file to "model.safetensors" and copy "tokenizer.json" from the source model folder.
**Note:** only 4-bit GPTQ (Marlin format) quantization is supported at the moment, and the input data type should be `f16` (`--dtype f16`) or `bf16` (`--dtype bf16`). You need to rename the transformed Marlin weight file to "model.safetensors" and copy "tokenizer.json" from the source model folder.
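
The conversion described above can be scripted. The following is a minimal sketch, assuming an AutoGPTQ version that accepts `use_marlin=True` in `from_quantized`; the model paths are placeholders and the checkpoint is assumed to consist of a single safetensors shard:

```
# Sketch: repack an existing 4-bit GPTQ checkpoint into Marlin format with AutoGPTQ.
# Paths are placeholders; the exact AutoGPTQ API may vary between versions.
import shutil
from pathlib import Path

from auto_gptq import AutoGPTQForCausalLM

src = Path("/models/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4")         # existing GPTQ weights
dst = Path("/models/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4-Marlin")  # Marlin output folder

# Loading with use_marlin=True repacks the GPTQ weights into Marlin format.
model = AutoGPTQForCausalLM.from_quantized(str(src), use_marlin=True, device="cuda:0")
model.save_pretrained(str(dst))

# candle-vllm expects the weight file to be named "model.safetensors"
# and the tokenizer to sit next to it (assumes a single shard).
for f in dst.glob("*.safetensors"):
    f.rename(dst / "model.safetensors")
shutil.copy(src / "tokenizer.json", dst / "tokenizer.json")
```
The repacked folder can then be served with `--quant marlin` and `--dtype f16` or `--dtype bf16`, as in the command above.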

## In-situ quantization for consumer-grade GPUs

23 changes: 17 additions & 6 deletions kernels/src/ffi.rs
@@ -75,20 +75,31 @@ extern "C" {
weight: *const c_int,
scales: *const c_void,
out: *const c_void,
m: c_int,
k: c_int,
n: c_int,
m: c_int,
k: c_int,
n: c_int,
workspace: *const c_void, //tensor with at least `n / 128 * max_par` entries that are all zero
groupsize: c_int,
) -> i32;
);

pub fn marlin_4bit_bf16(
inputs: *const c_void,
weight: *const c_int,
scales: *const c_void,
out: *const c_void,
m: c_int,
k: c_int,
n: c_int,
workspace: *const c_void, //tensor with at least `n / 128 * max_par` entries that are all zero
groupsize: c_int,
);

pub fn gptq_marlin_repack(
weight: *const c_void,
perm: *const c_void,
result: *const c_void,
k: c_int,
k: c_int,
n: c_int,
bits: c_int,
);

}
3 changes: 2 additions & 1 deletion kernels/src/lib.rs
@@ -1,6 +1,7 @@
pub const COPY_BLOCKS_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/copy_blocks_kernel.ptx"));
pub const MARLIN_CUDA_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/marlin_cuda_kernel.ptx"));
pub const MARLIN_CUDA_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/marlin_cuda_kernel.ptx"));
pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedattention.ptx"));
pub const RESHAPE_AND_CACHE_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx"));
118 changes: 118 additions & 0 deletions kernels/src/marlin/marlin.cuh
@@ -0,0 +1,118 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <cassert>

// #define CHECK(cond, ...) \
// assert(cond); \
#define CHECK(cond, ...)

namespace marlin {

// Marlin params

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per scheduler allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.

static constexpr int repack_threads = 256;
static constexpr int repack_stages = 8;
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;

__device__ inline constexpr int ceildiv(int a, int b) {
return (a + b - 1) / b;
}

// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};

using I4 = Vec<int, 4>;

} // namespace marlin
79 changes: 79 additions & 0 deletions kernels/src/marlin/marlin_dtypes.cuh
@@ -0,0 +1,79 @@

#ifndef _data_types_cuh
#define _data_types_cuh
#include "marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace marlin {

template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
public:
using scalar_t = half;
using scalar_t2 = half2;

// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;
using FragZP = Vec<half2, 4>;

static __device__ float inline num2float(const half x) {
return __half2float(x);
}

static __device__ half2 inline num2num2(const half x) {
return __half2half2(x);
}

static __device__ half2 inline nums2num2(const half x1, const half x2) {
return __halves2half2(x1, x2);
}

static __host__ __device__ half inline float2num(const float x) {
return __float2half(x);
}
};

template <>
class ScalarType<nv_bfloat16> {
public:
using scalar_t = nv_bfloat16;
using scalar_t2 = nv_bfloat162;

using FragA = Vec<nv_bfloat162, 4>;
using FragB = Vec<nv_bfloat162, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}

static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
return __bfloat162bfloat162(x);
}

static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
const nv_bfloat16 x2) {
return __halves2bfloat162(x1, x2);
}

static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
return __float2bfloat16(x);
}
#endif
};

} // namespace marlin

#endif