-
Notifications
You must be signed in to change notification settings - Fork 10.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve cuBLAS performance by dequantizing on the GPU #1065
Merged
+221
−41
Merged
Changes from 2 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
#include <stdint.h> | ||
#include "ggml-cuda.h" | ||
#include <cuda_fp16.h> | ||
|
||
typedef uint16_t ggml_fp16_t; | ||
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size"); | ||
|
||
#define QK4_0 32 | ||
typedef struct { | ||
float d; // delta | ||
uint8_t qs[QK4_0 / 2]; // nibbles / quants | ||
} block_q4_0; | ||
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); | ||
|
||
#define QK4_1 32 | ||
typedef struct { | ||
float d; // delta | ||
float m; // min | ||
uint8_t qs[QK4_1 / 2]; // nibbles / quants | ||
} block_q4_1; | ||
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); | ||
|
||
#define QK4_2 16 | ||
typedef struct { | ||
__half d; // delta | ||
uint8_t qs[QK4_2 / 2]; // nibbles / quants | ||
} block_q4_2; | ||
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding"); | ||
|
||
|
||
static __global__ void dequantize_block_q4_0(const void * vx, float * y) { | ||
const block_q4_0 * x = (const block_q4_0 *) vx; | ||
|
||
int i = blockIdx.x; | ||
|
||
const float d = x[i].d; | ||
|
||
const uint8_t * pp = x[i].qs; | ||
|
||
for (int l = 0; l < QK4_0; l += 2) { | ||
const uint8_t vi = pp[l/2]; | ||
|
||
const int8_t vi0 = vi & 0xf; | ||
const int8_t vi1 = vi >> 4; | ||
|
||
const float v0 = (vi0 - 8)*d; | ||
const float v1 = (vi1 - 8)*d; | ||
|
||
y[i*QK4_0 + l + 0] = v0; | ||
y[i*QK4_0 + l + 1] = v1; | ||
} | ||
} | ||
|
||
static __global__ void dequantize_block_q4_1(const void * vx, float * y) { | ||
const block_q4_1 * x = (const block_q4_1 *) vx; | ||
|
||
int i = blockIdx.x; | ||
|
||
const float d = x[i].d; | ||
const float m = x[i].m; | ||
|
||
const uint8_t * pp = x[i].qs; | ||
|
||
for (int l = 0; l < QK4_1; l += 2) { | ||
const uint8_t vi = pp[l/2]; | ||
|
||
const int8_t vi0 = vi & 0xf; | ||
const int8_t vi1 = vi >> 4; | ||
|
||
const float v0 = vi0*d + m; | ||
const float v1 = vi1*d + m; | ||
|
||
y[i*QK4_1 + l + 0] = v0; | ||
y[i*QK4_1 + l + 1] = v1; | ||
} | ||
} | ||
|
||
static __global__ void dequantize_block_q4_2(const void * vx, float * y) { | ||
const block_q4_2 * x = (const block_q4_2 *) vx; | ||
|
||
int i = blockIdx.x; | ||
|
||
const float d = x[i].d; | ||
|
||
const uint8_t * pp = x[i].qs; | ||
|
||
for (int l = 0; l < QK4_2; l += 2) { | ||
const uint8_t vi = pp[l/2]; | ||
|
||
const int8_t vi0 = vi & 0xf; | ||
const int8_t vi1 = vi >> 4; | ||
|
||
const float v0 = (vi0 - 8)*d; | ||
const float v1 = (vi1 - 8)*d; | ||
|
||
y[i*QK4_2 + l + 0] = v0; | ||
y[i*QK4_2 + l + 1] = v1; | ||
} | ||
} | ||
|
||
extern "C" { | ||
__host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k) { | ||
const int nb = k / QK4_0; | ||
dequantize_block_q4_0<<<nb, 1>>>(vx, y); | ||
} | ||
|
||
__host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k) { | ||
const int nb = k / QK4_1; | ||
dequantize_block_q4_1<<<nb, 1>>>(vx, y); | ||
} | ||
|
||
__host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k) { | ||
const int nb = k / QK4_2; | ||
dequantize_block_q4_2<<<nb, 1>>>(vx, y); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k); | ||
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k); | ||
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k); | ||
|
||
#ifdef __cplusplus | ||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there was a discussion somewhere recently about splitting out the accel specific code into dedicated .c files. what was the state on that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to keep all the cuda code in
ggml-cuda.cu
to avoid having to compile ggml with nvcc, but otherwise nothing changed.