llama : add Command R Plus support (#6491)
* Add Command R Plus GGUF

* Add Command R Plus GGUF

* Loading works up to LayerNorm2D

* Export new tensors in 1D so they are not quantized.

* Fix embedding layer based on Noeda's example

* Whitespace

* Add line

* Fix unexpected tokens on MPS. Re-add F16 fix. (Noeda)

* dranger003: Fix block index overflow in CUDA dequantizing.

* Reverted blocked multiplication code as it still has issues and could affect other Llama arches

* export norms as f32

* fix overflow issues during quant and other cleanup

* Type convention

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* dranger003: Fix more int overflow during quant.

---------

Co-authored-by: S <seast@Ss-Mac-Studio.local>
Co-authored-by: S <s@example.com>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
5 people authored Apr 9, 2024
1 parent e11a899 commit 5dc9dd7
Showing 16 changed files with 366 additions and 326 deletions.
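
The common thread across these files is widening tensor-index arithmetic from int to int64_t: Command R Plus (104B parameters) has tensors whose element counts exceed INT_MAX, so 32-bit index products overflow during quantization and dequantization. The following standalone C++ sketch reproduces the failure mode; the dimensions are hypothetical stand-ins, not values read from the model files.

#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative large-tensor shape (hypothetical, not from the model):
    // 256000 * 12288 ~= 3.1e9 elements, above INT_MAX (~2.147e9).
    const int nrows = 256000;
    const int ncols = 12288;

    // Wrong: nrows*ncols would be evaluated in 32 bits and overflow
    // (undefined behavior) before any widening of the result.
    // Right: widen one operand first so the multiply itself is 64-bit.
    const int64_t nelements = (int64_t)nrows * ncols;

    printf("elements: %lld\n", (long long)nelements); // 3145728000
    return 0;
}
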
2 changes: 1 addition & 1 deletion convert-hf-to-gguf.py
@@ -160,7 +160,7 @@ def write_tensors(self):
data = data.astype(np.float32)

# TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
- if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+ if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")):
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
6 changes: 3 additions & 3 deletions ggml-cuda.cu
@@ -1225,7 +1225,7 @@ static void ggml_cuda_op_mul_mat_cublas(

// the main device has a larger memory buffer to hold the results from all GPUs
// ldc == nrows of the matrix that cuBLAS writes into
- int ldc = id == ctx.device ? ne0 : row_diff;
+ int64_t ldc = id == ctx.device ? ne0 : row_diff;

const int compute_capability = ggml_cuda_info().devices[id].cc;

@@ -1377,8 +1377,8 @@ static void ggml_cuda_op_mul_mat(
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];

- const int nb2 = dst->nb[2];
- const int nb3 = dst->nb[3];
+ const int64_t nb2 = dst->nb[2];
+ const int64_t nb3 = dst->nb[3];

GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
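
The ldc, nb2, and nb3 widenings above matter because these values feed pointer-offset math: with a result this large, an expression like ldc*col computed in 32 bits wraps before it is added to the base pointer. A minimal sketch of that offset pattern, illustrative only and not ggml's actual code:

#include <cstdint>

// Column-major element address: ldc is the number of rows in the buffer.
// With 32-bit ldc and col, ldc*col can exceed INT_MAX for very large
// results; 64-bit operands keep the offset exact.
static float * dst_element(float * base, int64_t ldc, int64_t row, int64_t col) {
    return base + ldc * col + row;
}
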
2 changes: 1 addition & 1 deletion ggml-cuda/common.cuh
@@ -394,7 +394,7 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
// TODO: move to ggml-common.h
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

- typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+ typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);


//////////////////////
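
Since dequantize_kernel_t is the contract every per-block dequantizer is dispatched through, widening the block index ib here is what forces the matching signature changes in ggml-cuda/dequantize.cuh further down. A toy illustration, with simplified stand-in types, of why the typedef and its implementations must move together:

#include <cstdint>

struct dfloat2 { float x, y; };  // simplified stand-in for ggml's dfloat2

// The function-pointer type is the single point of truth for the signature.
typedef void (*dequantize_kernel_t)(const void * vx, int64_t ib, int iqs, dfloat2 & v);

// A helper binds to the pointer only if its ib parameter is also 64-bit.
static void dequantize_stub(const void * vx, int64_t ib, int iqs, dfloat2 & v) {
    (void)vx; (void)ib; (void)iqs;
    v.x = 0.0f;
    v.y = 0.0f;
}

int main() {
    dequantize_kernel_t fn = dequantize_stub; // fails to compile on a mismatch
    dfloat2 v;
    fn(nullptr, 0, 0, v);
    return 0;
}
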
74 changes: 37 additions & 37 deletions ggml-cuda/convert.cu
@@ -4,14 +4,14 @@
#define CUDA_Q8_0_NE_ALIGN 2048

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
- static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
- const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+ const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x);

if (i >= k) {
return;
}

- const int ib = i/qk; // block index
+ const int64_t ib = i/qk; // block index
const int iqs = (i%qk)/qr; // quant index
const int iybs = i - i%qk; // y block start index
const int y_offset = qr == 1 ? 1 : qk/2;
@@ -25,7 +25,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
}

template <bool need_check>
- static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
+ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
#if __CUDA_ARCH__ >= CC_PASCAL
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;

@@ -68,13 +68,13 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
template<typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {

- const int i = blockIdx.x;
+ const int64_t i = blockIdx.x;

// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
- const int ib = 8*i + ir;
+ const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
@@ -96,13 +96,13 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
template<typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {

- const int i = blockIdx.x;
+ const int64_t i = blockIdx.x;

// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
- const int ib = 8*i + ir;
+ const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
@@ -313,14 +313,14 @@ template<typename dst_t>
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q6_K * x = (const block_q6_K *) vx;

- const int i = blockIdx.x;
+ const int64_t i = blockIdx.x;
#if QK_K == 256

// assume 64 threads - this is very slightly better than the one below
- const int tid = threadIdx.x;
- const int ip = tid/32; // ip is 0 or 1
- const int il = tid - 32*ip; // 0...32
- const int is = 8*ip + il/16;
+ const int64_t tid = threadIdx.x;
+ const int64_t ip = tid/32; // ip is 0 or 1
+ const int64_t il = tid - 32*ip; // 0...32
+ const int64_t is = 8*ip + il/16;

dst_t * y = yy + i*QK_K + 128*ip + il;

@@ -337,9 +337,9 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
#else

// assume 32 threads
- const int tid = threadIdx.x;
- const int ip = tid/16; // 0 or 1
- const int il = tid - 16*ip; // 0...15
+ const int64_t tid = threadIdx.x;
+ const int64_t ip = tid/16; // 0 or 1
+ const int64_t il = tid - 16*ip; // 0...15

dst_t * y = yy + i*QK_K + 16*ip + il;

@@ -571,12 +571,12 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
#endif

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
- static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+ static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}

- static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
+ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
if (k % CUDA_Q8_0_NE_ALIGN == 0) {
const bool need_check = false;
@@ -588,7 +588,7 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half *
}

template<typename dst_t>
- static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -598,7 +598,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu
}

template<typename dst_t>
- static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -608,27 +608,27 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
}

template<typename dst_t>
- static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}

template<typename dst_t>
- static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
}

template<typename dst_t>
- static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
- static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -638,7 +638,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu
}

template<typename dst_t>
- static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -648,55 +648,55 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
}

template<typename dst_t>
- static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
- static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
- static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
- static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
- static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
- static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
- static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
- static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
- static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
#if QK_K == 64
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
@@ -706,8 +706,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
}

template <typename src_t, typename dst_t>
- static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
+ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+ const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
@@ -719,7 +719,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
}

template <typename src_t, typename dst_t>
- static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
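
The recurring pattern in convert.cu is that the per-thread global element index and the derived block index are now computed in 64 bits, with the cast applied before any multiply. A minimal standalone kernel showing the pattern; this is an illustration in the same style, not ggml's actual kernel:

#include <cuda_fp16.h>
#include <cstdint>

// Convert half to float over k elements, where k may exceed INT_MAX.
// The cast comes before the multiply so the index math never passes
// through a 32-bit intermediate.
static __global__ void convert_f16_to_f32(const half * x, float * y, const int64_t k) {
    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    y[i] = __half2float(x[i]);
}
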
2 changes: 1 addition & 1 deletion ggml-cuda/convert.cuh
@@ -3,7 +3,7 @@
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256

template<typename T>
- using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
+ using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);

typedef to_t_cuda_t<float> to_fp32_cuda_t;
typedef to_t_cuda_t<half> to_fp16_cuda_t;
10 changes: 5 additions & 5 deletions ggml-cuda/dequantize.cuh
@@ -1,6 +1,6 @@
#include "common.cuh"

- static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;

const dfloat d = x[ib].d;
@@ -19,7 +19,7 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
#endif // GGML_CUDA_F16
}

- static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx;

const dfloat d = __low2half(x[ib].dm);
@@ -39,7 +39,7 @@ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const in
#endif // GGML_CUDA_F16
}

- static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q5_0 * x = (const block_q5_0 *) vx;

const dfloat d = x[ib].d;
@@ -62,7 +62,7 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
#endif // GGML_CUDA_F16
}

- static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx;

const dfloat d = __low2half(x[ib].dm);
@@ -86,7 +86,7 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
#endif // GGML_CUDA_F16
}

- static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q8_0 * x = (const block_q8_0 *) vx;

const dfloat d = x[ib].d;
6 changes: 3 additions & 3 deletions ggml-cuda/dmmv.cu
@@ -565,7 +565,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
}
}

- static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
+ static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const half * x = (const half *) vx;

// automatic half -> float type cast if dfloat == float
@@ -577,7 +577,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
// qr = number of quantized weights per data value in x block
- const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;

if (row >= nrows) {
return;
@@ -598,7 +598,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons

for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
- const int ib = (row*ncols + col)/qk; // x block index
+ const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index
const int iqs = (col%qk)/qr; // x quant index
const int iybs = col - col%qk; // y block start index

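
The dmmv.cu change above is worth dwelling on: the cast has to sit inside the expression. Casting the finished product to int64_t does nothing, because the multiply has already wrapped in 32 bits by then. A small standalone comparison of the two placements, illustrative rather than ggml's code:

#include <cstdint>

static int64_t block_index(int row, int ncols, int col, int qk) {
    // Wrong: row*ncols is evaluated in 32 bits and wraps first; the cast
    // then merely widens the already-corrupted value.
    // int64_t ib = (int64_t)(row*ncols + col)/qk;

    // Right: widening one operand promotes the whole expression to 64 bits,
    // matching the fix in dequantize_mul_mat_vec above.
    return ((int64_t)row*ncols + col)/qk;
}
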