
Commit

Update
leimao committed Dec 12, 2023
1 parent 7cef627 commit f59c267
Showing 10 changed files with 929 additions and 310 deletions.
17 changes: 17 additions & 0 deletions include/cuda_gemm.hpp
@@ -21,18 +21,35 @@ void launch_gemm_kernel_v02(size_t m, size_t n, size_t k, T const* alpha,
T const* beta, T* C, size_t ldc,
cudaStream_t stream);

template <typename T>
void launch_gemm_kernel_v02_vectorized(size_t m, size_t n, size_t k,
T const* alpha, T const* A, size_t lda,
T const* B, size_t ldb, T const* beta,
T* C, size_t ldc, cudaStream_t stream);

template <typename T>
void launch_gemm_kernel_v03(size_t m, size_t n, size_t k, T const* alpha,
T const* A, size_t lda, T const* B, size_t ldb,
T const* beta, T* C, size_t ldc,
cudaStream_t stream);

template <typename T>
void launch_gemm_kernel_v03_vectorized(size_t m, size_t n, size_t k,
T const* alpha, T const* A, size_t lda,
T const* B, size_t ldb, T const* beta,
T* C, size_t ldc, cudaStream_t stream);

template <typename T>
void launch_gemm_kernel_v04(size_t m, size_t n, size_t k, T const* alpha,
T const* A, size_t lda, T const* B, size_t ldb,
T const* beta, T* C, size_t ldc,
cudaStream_t stream);

template <typename T>
void launch_gemm_kernel_v04_vectorized(size_t m, size_t n, size_t k,
T const* alpha, T const* A, size_t lda,
T const* B, size_t ldb, T const* beta,
T* C, size_t ldc, cudaStream_t stream);

template <typename T>
void launch_gemm_kernel_v05(size_t m, size_t n, size_t k, T const* alpha,
T const* A, size_t lda, T const* B, size_t ldb,
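For context, a minimal host-side sketch of how the launchers declared in this header are invoked. The buffer names d_A, d_B, d_C and the unpadded row-major leading dimensions are assumptions for illustration only; they are not part of this commit.

#include <cstddef>

#include <cuda_runtime.h>

#include "cuda_gemm.hpp"

// Hypothetical usage example: d_A (m x k), d_B (k x n), and d_C (m x n) are
// assumed to be device buffers that have already been allocated and filled.
void run_gemm_v02_vectorized_example(float const* d_A, float const* d_B,
                                     float* d_C, size_t m, size_t n, size_t k,
                                     cudaStream_t stream)
{
    float const alpha{1.0f};
    float const beta{0.0f};
    // Leading dimensions for unpadded row-major storage.
    size_t const lda{k};
    size_t const ldb{n};
    size_t const ldc{n};
    launch_gemm_kernel_v02_vectorized<float>(m, n, k, &alpha, d_A, lda, d_B,
                                             ldb, &beta, d_C, ldc, stream);
}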
41 changes: 29 additions & 12 deletions src/02_2d_block_tiling.cu
@@ -61,12 +61,19 @@ __global__ void gemm_v02(size_t m, size_t n, size_t k, T alpha, T const* A,
{
val = A[A_row_idx * lda + A_col_idx];
}
if (A_thread_block_tile_row_idx < BLOCK_TILE_SIZE_Y &&
A_thread_block_tile_col_idx < BLOCK_TILE_SIZE_K)
{
A_thread_block_tile[A_thread_block_tile_row_idx]
[A_thread_block_tile_col_idx] = val;
}
// This if would slow down the kernel.
// The static asserts here and in the host code guarantee that the
// condition is always true, so the check can be removed.
static_assert(BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS ==
0U);
// if (A_thread_block_tile_row_idx < BLOCK_TILE_SIZE_Y &&
// A_thread_block_tile_col_idx < BLOCK_TILE_SIZE_K)
// {
// A_thread_block_tile[A_thread_block_tile_row_idx]
// [A_thread_block_tile_col_idx] = val;
// }
A_thread_block_tile[A_thread_block_tile_row_idx]
[A_thread_block_tile_col_idx] = val;
}
// Load data from B on DRAM to B_thread_block_tile on shared memory.
#pragma unroll
@@ -95,12 +102,19 @@ __global__ void gemm_v02(size_t m, size_t n, size_t k, T alpha, T const* A,
{
val = B[B_row_idx * ldb + B_col_idx];
}
if (B_thread_block_tile_row_idx < BLOCK_TILE_SIZE_K &&
B_thread_block_tile_col_idx < BLOCK_TILE_SIZE_X)
{
B_thread_block_tile[B_thread_block_tile_row_idx]
[B_thread_block_tile_col_idx] = val;
}
// This if would slow down the kernel.
// The static asserts here and in the host code guarantee that the
// condition is always true, so the check can be removed.
static_assert(BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS ==
0U);
// if (B_thread_block_tile_row_idx < BLOCK_TILE_SIZE_K &&
// B_thread_block_tile_col_idx < BLOCK_TILE_SIZE_X)
// {
// B_thread_block_tile[B_thread_block_tile_row_idx]
// [B_thread_block_tile_col_idx] = val;
// }
B_thread_block_tile[B_thread_block_tile_row_idx]
[B_thread_block_tile_col_idx] = val;
}
__syncthreads();

@@ -140,6 +154,9 @@ void launch_gemm_kernel_v02(size_t m, size_t n, size_t k, T const* alpha,
constexpr unsigned int BLOCK_TILE_SIZE_X{32U};
constexpr unsigned int BLOCK_TILE_SIZE_Y{32U};
constexpr unsigned int BLOCK_TILE_SIZE_K{32U};
constexpr unsigned int NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y};
static_assert(BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS == 0U);
static_assert(BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS == 0U);
dim3 const block_dim{BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, 1U};
dim3 const grid_dim{
(static_cast<unsigned int>(n) + block_dim.x - 1U) / block_dim.x,
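As a quick sanity check of why the removed if-statements are safe to drop, here is a small standalone sketch. It mirrors the 32 x 32 x 32 configuration and NUM_THREADS defined in the launcher above; the snippet itself is illustrative and not part of the commit.

// Why the shared-memory store needs no bounds check once the tile size is a
// multiple of the thread count (values copied from the launcher above).
constexpr unsigned int BLOCK_TILE_SIZE_X{32U};
constexpr unsigned int BLOCK_TILE_SIZE_Y{32U};
constexpr unsigned int BLOCK_TILE_SIZE_K{32U};
constexpr unsigned int NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y}; // 1024

// The A tile holds BLOCK_TILE_SIZE_Y * BLOCK_TILE_SIZE_K = 1024 elements and
// the load loop runs ceil(1024 / NUM_THREADS) = 1 iteration, so the largest
// linear index any thread touches is NUM_THREADS - 1 = 1023 < 1024. Hence
//   row = linear_idx / BLOCK_TILE_SIZE_K < BLOCK_TILE_SIZE_Y
//   col = linear_idx % BLOCK_TILE_SIZE_K < BLOCK_TILE_SIZE_K
// always hold, which is exactly what the removed if-statement checked at
// run time. The same argument applies to the B tile.
static_assert(BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS == 0U);
static_assert(BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS == 0U);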
223 changes: 223 additions & 0 deletions src/02_2d_block_tiling_vectorized_memory_access.cu
@@ -0,0 +1,223 @@
#include <cuda_fp16.h>

#include "cuda_gemm.hpp"
#include "cuda_gemm_utils.hpp"

// GEMM kernel v02.
// Coalesced read and write from global memory.
// We guarantee that matrices A, B, and C are 32-byte aligned.
// This implementation is slower because we waste a lot of threads.
template <typename T, size_t BLOCK_TILE_SIZE_X, size_t BLOCK_TILE_SIZE_Y,
size_t BLOCK_TILE_SIZE_K>
__global__ void gemm_v02_vectorized(size_t m, size_t n, size_t k, T alpha,
T const* A, size_t lda, T const* B,
size_t ldb, T beta, T* C, size_t ldc)
{
// Avoid using blockDim.x * blockDim.y as the number of threads per block,
// because it is a runtime value that the compiler cannot use to optimize
// loop unrolling. Use a compile-time constant instead.
constexpr size_t NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y};
size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x};

// Compute the row and column of C that this thread is responsible for.
size_t const C_col_idx{blockIdx.x * blockDim.x + threadIdx.x};
size_t const C_row_idx{blockIdx.y * blockDim.y + threadIdx.y};

// Cache a tile of A and B in shared memory for data reuse.
__shared__ T A_thread_block_tile[BLOCK_TILE_SIZE_Y][BLOCK_TILE_SIZE_K];
__shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X];

size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) /
BLOCK_TILE_SIZE_K};

constexpr size_t NUM_VECTOR_UNITS{sizeof(int4) / sizeof(T)};
static_assert(sizeof(int4) % sizeof(T) == 0U);
static_assert(BLOCK_TILE_SIZE_K % NUM_VECTOR_UNITS == 0U);
static_assert(BLOCK_TILE_SIZE_X % NUM_VECTOR_UNITS == 0U);
constexpr size_t VECTORIZED_BLOCK_TILE_SIZE_K{BLOCK_TILE_SIZE_K /
NUM_VECTOR_UNITS};
constexpr size_t VECTORIZED_BLOCK_TILE_SIZE_X{BLOCK_TILE_SIZE_X /
NUM_VECTOR_UNITS};

T sum{static_cast<T>(0)};
for (size_t thread_block_tile_idx{0U};
thread_block_tile_idx < num_thread_block_tiles;
++thread_block_tile_idx)
{
// Load data from A on DRAM to A_thread_block_tile on shared memory.
#pragma unroll
for (size_t load_idx{0U};
load_idx < (BLOCK_TILE_SIZE_Y * VECTORIZED_BLOCK_TILE_SIZE_K +
NUM_THREADS - 1U) /
NUM_THREADS;
++load_idx)
{
size_t const A_thread_block_tile_row_idx{
(thread_linear_idx + load_idx * NUM_THREADS) /
VECTORIZED_BLOCK_TILE_SIZE_K};
size_t const A_thread_block_tile_col_idx{
(thread_linear_idx + load_idx * NUM_THREADS) %
VECTORIZED_BLOCK_TILE_SIZE_K * NUM_VECTOR_UNITS};
size_t const A_row_idx{blockIdx.y * BLOCK_TILE_SIZE_Y +
A_thread_block_tile_row_idx};
size_t const A_col_idx{thread_block_tile_idx * BLOCK_TILE_SIZE_K +
A_thread_block_tile_col_idx};

// These boundary checks might slow down the kernel to some extent.
// But they guarantee the correctness of the kernel for all
// different GEMM configurations.
int4 A_row_vector_vals{0, 0, 0, 0};
if (A_row_idx < m && A_col_idx < k)
{
A_row_vector_vals = *reinterpret_cast<int4 const*>(
&A[A_row_idx * lda + A_col_idx]);
}
if (A_col_idx + NUM_VECTOR_UNITS > k)
{
// Number of invalid elements in the last vector.
size_t const num_invalid_elements{A_col_idx + NUM_VECTOR_UNITS -
k};
// Mask out the invalid elements.
T* const A_row_vector_vals_ptr{
reinterpret_cast<T*>(&A_row_vector_vals)};
for (size_t i{0U}; i < num_invalid_elements; ++i)
{
A_row_vector_vals_ptr[NUM_VECTOR_UNITS - 1U - i] =
static_cast<T>(0);
}
}
// If the condition below held at compile time, the following if could
// be removed.
// static_assert(VECTORIZED_BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y %
//               NUM_THREADS == 0U);
if (A_thread_block_tile_row_idx < BLOCK_TILE_SIZE_Y &&
A_thread_block_tile_col_idx < BLOCK_TILE_SIZE_K)
{
*reinterpret_cast<int4*>(
&A_thread_block_tile[A_thread_block_tile_row_idx]
[A_thread_block_tile_col_idx]) =
A_row_vector_vals;
}
}
// Load data from B on DRAM to B_thread_block_tile on shared memory.
#pragma unroll
for (size_t load_idx{0U};
load_idx < (BLOCK_TILE_SIZE_K * VECTORIZED_BLOCK_TILE_SIZE_X +
NUM_THREADS - 1U) /
NUM_THREADS;
++load_idx)
{
size_t const B_thread_block_tile_row_idx{
(thread_linear_idx + load_idx * NUM_THREADS) /
VECTORIZED_BLOCK_TILE_SIZE_X};
size_t const B_thread_block_tile_col_idx{
(thread_linear_idx + load_idx * NUM_THREADS) %
VECTORIZED_BLOCK_TILE_SIZE_X * NUM_VECTOR_UNITS};
size_t const B_row_idx{thread_block_tile_idx * BLOCK_TILE_SIZE_K +
B_thread_block_tile_row_idx};
size_t const B_col_idx{blockIdx.x * BLOCK_TILE_SIZE_X +
B_thread_block_tile_col_idx};

// These boundary checks might slow down the kernel to some extent.
// But they guarantee the correctness of the kernel for all
// different GEMM configurations.
int4 B_row_vector_vals{0, 0, 0, 0};
if (B_row_idx < k && B_col_idx < n)
{
B_row_vector_vals = *reinterpret_cast<int4 const*>(
&B[B_row_idx * ldb + B_col_idx]);
}
if (B_col_idx + NUM_VECTOR_UNITS > n)
{
// Number of invalid elements in the last vector.
size_t const num_invalid_elements{B_col_idx + NUM_VECTOR_UNITS -
n};
// Mask out the invalid elements.
T* const B_row_vector_vals_ptr{
reinterpret_cast<T*>(&B_row_vector_vals)};
for (size_t i{0U}; i < num_invalid_elements; ++i)
{
B_row_vector_vals_ptr[NUM_VECTOR_UNITS - 1U - i] =
static_cast<T>(0);
}
}
// If the condition below held at compile time, the following if could
// be removed.
// static_assert(VECTORIZED_BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K %
//               NUM_THREADS == 0U);
if (B_thread_block_tile_row_idx < BLOCK_TILE_SIZE_K &&
B_thread_block_tile_col_idx < BLOCK_TILE_SIZE_X)
{
*reinterpret_cast<int4*>(
&B_thread_block_tile[B_thread_block_tile_row_idx]
[B_thread_block_tile_col_idx]) =
B_row_vector_vals;
}
}
__syncthreads();

#pragma unroll
for (size_t k_i{0U}; k_i < BLOCK_TILE_SIZE_K; ++k_i)
{
// Doing this results in 2 TOPS.
// Suppose blockDim.x = blockDim.y = 32.
// Effectively, in one iteration a warp reads the same location of
// A_thread_block_tile from shared memory, resulting in a broadcast,
// and also reads 32 values from B_thread_block_tile with no bank
// conflicts. Even so, every value has to come from shared memory, so
// shared-memory instructions run very intensively just to feed a
// small number of simple arithmetic instructions, which is not
// efficient.
sum += A_thread_block_tile[threadIdx.y][k_i] *
B_thread_block_tile[k_i][threadIdx.x];
}
__syncthreads();
}
if (C_row_idx < m && C_col_idx < n)
{
C[C_row_idx * ldc + C_col_idx] =
alpha * sum + beta * C[C_row_idx * ldc + C_col_idx];
}
}

template <typename T>
void launch_gemm_kernel_v02_vectorized(size_t m, size_t n, size_t k,
T const* alpha, T const* A, size_t lda,
T const* B, size_t ldb, T const* beta,
T* C, size_t ldc, cudaStream_t stream)
{
// Feel free to play with the block tile sizes.
// The correctness of the algorithm should always be guaranteed.
constexpr unsigned int BLOCK_TILE_SIZE_X{32U};
constexpr unsigned int BLOCK_TILE_SIZE_Y{32U};
constexpr unsigned int BLOCK_TILE_SIZE_K{32U};
constexpr unsigned int NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y};
static_assert(BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS == 0U);
static_assert(BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS == 0U);
dim3 const block_dim{BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, 1U};
dim3 const grid_dim{
(static_cast<unsigned int>(n) + block_dim.x - 1U) / block_dim.x,
(static_cast<unsigned int>(m) + block_dim.y - 1U) / block_dim.y, 1U};
gemm_v02_vectorized<T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y,
BLOCK_TILE_SIZE_K><<<grid_dim, block_dim, 0U, stream>>>(
m, n, k, *alpha, A, lda, B, ldb, *beta, C, ldc);
CHECK_LAST_CUDA_ERROR();
}

// Explicit instantiation.
template void launch_gemm_kernel_v02_vectorized<float>(
size_t m, size_t n, size_t k, float const* alpha, float const* A,
size_t lda, float const* B, size_t ldb, float const* beta, float* C,
size_t ldc, cudaStream_t stream);
template void launch_gemm_kernel_v02_vectorized<double>(
size_t m, size_t n, size_t k, double const* alpha, double const* A,
size_t lda, double const* B, size_t ldb, double const* beta, double* C,
size_t ldc, cudaStream_t stream);
template void launch_gemm_kernel_v02_vectorized<__half>(
size_t m, size_t n, size_t k, __half const* alpha, __half const* A,
size_t lda, __half const* B, size_t ldb, __half const* beta, __half* C,
size_t ldc, cudaStream_t stream);
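To isolate the vectorized-load logic introduced in this new file, here is a small device-side sketch of the two steps: one 16-byte int4 load covering four floats, followed by zeroing the lanes that fall past the logical end of the row. The helper name and signature are hypothetical and not part of the commit; like the kernel, it assumes the source address is 16-byte aligned and the allocation is padded so the full vector read stays in bounds.

#include <cstddef>

// Hypothetical helper illustrating the int4 load plus tail masking above.
__device__ int4 load_float4_with_tail_mask(float const* row, size_t row_len,
                                           size_t col_idx)
{
    constexpr size_t NUM_VECTOR_UNITS{sizeof(int4) / sizeof(float)}; // 4 lanes
    int4 vals{0, 0, 0, 0};
    if (col_idx < row_len)
    {
        // One vectorized 16-byte read instead of four scalar reads.
        vals = *reinterpret_cast<int4 const*>(&row[col_idx]);
        if (col_idx + NUM_VECTOR_UNITS > row_len)
        {
            // Zero the trailing lanes that run past the end of the row.
            size_t const num_invalid{col_idx + NUM_VECTOR_UNITS - row_len};
            float* const lanes{reinterpret_cast<float*>(&vals)};
            for (size_t i{0U}; i < num_invalid; ++i)
            {
                lanes[NUM_VECTOR_UNITS - 1U - i] = 0.0f;
            }
        }
    }
    return vals;
}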
38 changes: 26 additions & 12 deletions src/03_2d_block_tiling_1d_thread_tiling.cu
@@ -65,12 +65,19 @@ __global__ void gemm_v03(size_t m, size_t n, size_t k, T alpha, T const* A,
{
val = A[A_row_idx * lda + A_col_idx];
}
if (A_thread_block_tile_row_idx < BLOCK_TILE_SIZE_Y &&
A_thread_block_tile_col_idx < BLOCK_TILE_SIZE_K)
{
A_thread_block_tile[A_thread_block_tile_row_idx]
[A_thread_block_tile_col_idx] = val;
}
// This if would slow down the kernel.
// The static asserts here and in the host code guarantee that the
// condition is always true, so the check can be removed.
static_assert(BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS ==
0U);
// if (A_thread_block_tile_row_idx < BLOCK_TILE_SIZE_Y &&
// A_thread_block_tile_col_idx < BLOCK_TILE_SIZE_K)
// {
// A_thread_block_tile[A_thread_block_tile_row_idx]
// [A_thread_block_tile_col_idx] = val;
// }
A_thread_block_tile[A_thread_block_tile_row_idx]
[A_thread_block_tile_col_idx] = val;
}
// Load data from B on DRAM to B_thread_block_tile on shared memory.
#pragma unroll
@@ -99,12 +106,19 @@ __global__ void gemm_v03(size_t m, size_t n, size_t k, T alpha, T const* A,
{
val = B[B_row_idx * ldb + B_col_idx];
}
if (B_thread_block_tile_row_idx < BLOCK_TILE_SIZE_K &&
B_thread_block_tile_col_idx < BLOCK_TILE_SIZE_X)
{
B_thread_block_tile[B_thread_block_tile_row_idx]
[B_thread_block_tile_col_idx] = val;
}
// This if would slow down the kernel.
// The static asserts here and in the host code guarantee that the
// condition is always true, so the check can be removed.
static_assert(BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS ==
0U);
// if (B_thread_block_tile_row_idx < BLOCK_TILE_SIZE_K &&
// B_thread_block_tile_col_idx < BLOCK_TILE_SIZE_X)
// {
// B_thread_block_tile[B_thread_block_tile_row_idx]
// [B_thread_block_tile_col_idx] = val;
// }
B_thread_block_tile[B_thread_block_tile_row_idx]
[B_thread_block_tile_col_idx] = val;
}
__syncthreads();

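The 1D thread-tiling kernel relies on the same divisibility guarantee, but its thread count is no longer one thread per output element. A hedged host-side sketch of how the configuration could satisfy the kernel's static asserts; the 64 x 64 x 8 tile, THREAD_TILE_SIZE_Y, and the resulting thread count are illustrative assumptions, since the launcher configuration is not shown in this part of the diff.

// Assumed configuration for a 1D thread-tiling launcher (illustrative only).
constexpr unsigned int BLOCK_TILE_SIZE_X{64U};
constexpr unsigned int BLOCK_TILE_SIZE_Y{64U};
constexpr unsigned int BLOCK_TILE_SIZE_K{8U};
constexpr unsigned int THREAD_TILE_SIZE_Y{8U};
// Each thread produces THREAD_TILE_SIZE_Y outputs, so the block needs fewer
// threads than output elements.
constexpr unsigned int NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y /
                                   THREAD_TILE_SIZE_Y}; // 512
// The conditions the kernel's static asserts enforce: each shared-memory
// tile (64 * 8 = 512 elements) is loaded in whole passes of 512 threads.
static_assert(BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS == 0U);
static_assert(BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS == 0U);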
