-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support bf16 inputs for GPTQ/Marlin format quantization (#90)
Support bf16 inputs for GPTQ/Marlin format quantization
- Loading branch information
1 parent
1b4b0d4
commit e62b7e1
Showing
19 changed files
with
545 additions
and
245 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,5 @@ | |
extend-ignore-identifiers-re = [ | ||
"mmaped", | ||
"arange", | ||
"cudaDevAttrMaxSharedMemoryPerBlockOptin", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
#pragma once | ||
#include <cuda.h> | ||
#include <cuda_fp16.h> | ||
#include <cuda_runtime.h> | ||
#include <iostream> | ||
#include <cassert> | ||
|
||
// Debug-only assertion hook. The assert-based definition is kept above for
// reference but disabled, so CHECK(cond, ...) currently expands to nothing
// (all checks are compiled out in this build).
// #define CHECK(cond, ...) \
// assert(cond); \
#define CHECK(cond, ...)
|
||
namespace marlin { | ||
|
||
// Marlin params | ||
|
||
// 8 warps are a good choice since every SM has 4 schedulers and having more | ||
// than 1 warp per schedule allows some more latency hiding. At the same time, | ||
// we want relatively few warps to have many registers per warp and small tiles. | ||
|
||
// Launch configuration for the repack kernels.
static constexpr int repack_threads = 256;  // threads per block (8 warps of 32)
static constexpr int repack_stages = 8;     // pipeline depth for the cp.async stages below
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

// Tile geometry. NOTE(review): tile_size presumably matches the 16-wide
// m16n8k16 tensor-core fragment documented in the ScalarType header — confirm.
static constexpr int tile_size = 16;
static constexpr int max_par = 16;                     // max problem partitions processed in parallel
static constexpr int tile_k_size = tile_size;          // K extent of one repack tile
static constexpr int tile_n_size = tile_k_size * 4;    // N extent of one repack tile (64)
|
||
// Integer ceiling division: smallest integer >= a/b.
// NOTE(review): correct for a >= 0, b > 0 only, and `a + b - 1` can overflow
// for a near INT_MAX — callers here use small tile counts, so both are fine.
__device__ inline constexpr int ceildiv(int a, int b) {
  return (a + b - 1) / b;
}
|
||
// Predicated asynchronous global->shared copy; used for inputs A where we apply | ||
// predication to handle batchsizes that are not multiples of 16. | ||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, | ||
bool pred = true) { | ||
const int BYTES = 16; | ||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); | ||
asm volatile( | ||
"{\n" | ||
" .reg .pred p;\n" | ||
" setp.ne.b32 p, %0, 0;\n" | ||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n" | ||
"}\n" ::"r"((int)pred), | ||
"r"(smem), "l"(glob_ptr), "n"(BYTES)); | ||
} | ||
|
||
// Asynchronous global->shared copy | ||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { | ||
const int BYTES = 16; | ||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); | ||
asm volatile( | ||
"{\n" | ||
" cp.async.cg.shared.global [%0], [%1], %2;\n" | ||
"}\n" ::"r"(smem), | ||
"l"(glob_ptr), "n"(BYTES)); | ||
} | ||
|
||
// Async copy fence. | ||
__device__ inline void cp_async_fence() { | ||
asm volatile("cp.async.commit_group;\n" ::); | ||
} | ||
|
||
// Wait until at most `n` async copy stages are still pending. | ||
template <int n> | ||
__device__ inline void cp_async_wait() { | ||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); | ||
} | ||
|
||
// Wait until barrier reaches `count`, then lock for current threadblock. | ||
__device__ inline void barrier_acquire(int* lock, int count) { | ||
if (threadIdx.x == 0) { | ||
int state = -1; | ||
do | ||
// Guarantee that subsequent writes by this threadblock will be visible | ||
// globally. | ||
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" | ||
: "=r"(state) | ||
: "l"(lock)); | ||
while (state != count); | ||
} | ||
__syncthreads(); | ||
} | ||
|
||
// Release barrier and increment visitation count. | ||
__device__ inline void barrier_release(int* lock, bool reset = false) { | ||
__syncthreads(); | ||
if (threadIdx.x == 0) { | ||
if (reset) { | ||
lock[0] = 0; | ||
return; | ||
} | ||
int val = 1; | ||
// Make sure that all writes since acquiring this barrier are visible | ||
// globally, while releasing the barrier. | ||
asm volatile("fence.acq_rel.gpu;\n"); | ||
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" | ||
: | ||
: "l"(lock), "r"(val)); | ||
} | ||
} | ||
|
||
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
  T elems[n];
  // Mutable element access; `i` should be a compile-time constant so the
  // array stays in registers (see note above).
  __device__ T& operator[](int i) { return elems[i]; }
  // Read-only access so const-qualified fragments can be indexed too
  // (backward-compatible addition; same compile-time-constant contract).
  __device__ const T& operator[](int i) const { return elems[i]; }
};
|
||
// Group of four 32-bit ints (16 bytes — the cp_async4 copy width).
using I4 = Vec<int, 4>;
|
||
} // namespace marlin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
|
||
#ifndef _data_types_cuh | ||
#define _data_types_cuh | ||
#include "marlin.cuh" | ||
#include <cuda_fp16.h> | ||
#include <cuda_bf16.h> | ||
|
||
namespace marlin { | ||
|
||
// Traits class mapping a compute type to its paired vector type, tensor-core
// fragment types, and conversion helpers. The primary template is empty;
// only the `half` and `nv_bfloat16` specializations are provided.
template <typename scalar_t>
class ScalarType {};
|
||
// fp16 specialization: half/half2 arithmetic with fp32 accumulation.
template <>
class ScalarType<half> {
 public:
  using scalar_t = half;    // single element type
  using scalar_t2 = half2;  // paired (2-lane) vector type

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;  // fp32 accumulator fragment
  using FragS = Vec<half2, 1>;
  using FragZP = Vec<half2, 4>;

  // Widen one half to float.
  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  // Broadcast one half into both lanes of a half2.
  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  // Pack two halves into one half2.
  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  // Round a float to half; usable from both host and device code.
  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }
};
|
||
// bf16 specialization: same fragment shapes as the half specialization, with
// nv_bfloat16/nv_bfloat162 element types and fp32 accumulation.
template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;    // single element type
  using scalar_t2 = nv_bfloat162;  // paired (2-lane) vector type

  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;  // fp32 accumulator fragment
  using FragS = Vec<nv_bfloat162, 1>;
  using FragZP = Vec<nv_bfloat162, 4>;

  // The bf16 conversion intrinsics below require SM80+, so they are guarded.
  // NOTE(review): during the host compilation pass __CUDA_ARCH__ is
  // undefined, so the __host__ __device__ float2num is also compiled out for
  // host code — confirm no host caller needs it.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // Widen one bf16 to float.
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  // Broadcast one bf16 into both lanes of an nv_bfloat162.
  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  // Pack two bf16 values into one nv_bfloat162.
  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  // Round a float to bf16.
  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }
#endif
};
|
||
} // namespace marlin | ||
|
||
#endif |
Oops, something went wrong.