feat: add gemma_rmsnorm and gemma_fused_add_rmsnorm #477

Merged · 2 commits · Aug 27, 2024

184 changes: 184 additions & 0 deletions include/flashinfer/norm.cuh
@@ -212,6 +212,190 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
return cudaSuccess;
}

template <uint32_t VEC_SIZE, typename T>
__global__ void GemmaRMSNormKernel(T* __restrict__ input, T* __restrict__ weight,
T* __restrict__ output, const uint32_t d, float eps) {
const uint32_t bx = blockIdx.x;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr uint32_t warp_size = 32;
const uint32_t num_warps = blockDim.y;
const uint32_t thread_id = tx + ty * warp_size;
const uint32_t num_threads = num_warps * warp_size;
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
extern __shared__ float smem[];

float sum_sq = 0.f;

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
input_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
sum_sq += float(input_vec[j]) * float(input_vec[j]);
}
}

// first, warp reduce sum
#pragma unroll
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
sum_sq += math::shfl_xor_sync(sum_sq, offset);
}

smem[ty] = sum_sq;
__syncthreads();
// then, cross warp reduce sum using only the first warp
if (ty == 0) {
sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
#pragma unroll
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
sum_sq += math::shfl_xor_sync(sum_sq, offset);
}
smem[0] = sum_sq;
}
__syncthreads();

float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

(Review comment on the line above: -> math::rsqrt(smem[0] / (float(d) + eps));)

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
vec_t<T, VEC_SIZE> weight_vec;
vec_t<T, VEC_SIZE> output_vec;
input_vec.fill(0.f);
weight_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
output_vec[j] = float(input_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}
}

template <typename T>
cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
float eps = 1e-5, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
const uint32_t smem_size = num_warps * sizeof(float);
void* args[] = {&input, &weight, &output, &d, &eps};

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = GemmaRMSNormKernel<VEC_SIZE, T>;
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});
return cudaSuccess;
}

template <uint32_t VEC_SIZE, typename T>
__global__ void GemmaFusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
T* __restrict__ weight, const uint32_t d, float eps) {
const uint32_t bx = blockIdx.x;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr uint32_t warp_size = 32;
const uint32_t num_warps = blockDim.y;
const uint32_t thread_id = tx + ty * warp_size;
const uint32_t num_threads = num_warps * warp_size;
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
extern __shared__ float smem[];

float sum_sq = 0.f;

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
input_vec.fill(0.f);
vec_t<T, VEC_SIZE> residual_vec;
residual_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
float x = float(input_vec[j]);
x += float(residual_vec[j]);
sum_sq += x * x;
residual_vec[j] = (T)x;
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}

// first, warp reduce sum
#pragma unroll
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
sum_sq += math::shfl_xor_sync(sum_sq, offset);
}

smem[ty] = sum_sq;
__syncthreads();
// then, cross warp reduce sum using only the first warp
if (ty == 0) {
sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
#pragma unroll
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
sum_sq += math::shfl_xor_sync(sum_sq, offset);
}
smem[0] = sum_sq;
}
__syncthreads();

float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
vec_t<T, VEC_SIZE> weight_vec;
vec_t<T, VEC_SIZE> residual_vec;
input_vec.fill(0.f);
weight_vec.fill(0.f);
residual_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
input_vec[j] = float(residual_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}
}

template <typename T>
cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
float eps = 1e-5, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
const uint32_t smem_size = num_warps * sizeof(float);
void* args[] = {&input, &residual, &weight, &d, &eps};

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = GemmaFusedAddRMSNormKernel<VEC_SIZE, T>;
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});

return cudaSuccess;
}

} // namespace norm

} // namespace flashinfer
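
For reference, a minimal PyTorch sketch of the math GemmaRMSNormKernel implements (the function name below is illustrative, not part of this PR): unlike the existing RMSNorm, the Gemma variant scales by (1 + weight), and eps is added to the mean of squares inside the rsqrt, matching the reference used in the tests further down.

```python
import torch

def gemma_rmsnorm_reference(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x: (batch_size, hidden_size), w: (hidden_size,)
    x_f32 = x.float()
    # eps is added to mean(x^2), i.e. rsqrt(mean(x^2) + eps), as in the kernel above.
    rms_rcp = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_f32 * rms_rcp * (1.0 + w.float())).to(x.dtype)
```

GemmaFusedAddRMSNormKernel applies the same normalization to x + residual, writing the sum back into residual and the normalized result back into input.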
3 changes: 3 additions & 0 deletions python/csrc/flashinfer_ops.cu
@@ -39,6 +39,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Speculative sampling from sequence of probabilities");
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization");
m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm,
"Gemma Fused add root mean square normalization");
m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul");
5 changes: 5 additions & 0 deletions python/csrc/flashinfer_ops.h
@@ -77,6 +77,11 @@ torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);
void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
double eps);

torch::Tensor gemma_rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);

void gemma_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
double eps);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
54 changes: 54 additions & 0 deletions python/csrc/norm.cu
@@ -73,3 +73,57 @@ void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tenso
return true;
});
}

torch::Tensor gemma_rmsnorm(torch::Tensor input, torch::Tensor weight, double eps) {
CHECK_INPUT(input);
CHECK_INPUT(weight);
auto device = input.device();
CHECK_EQ(weight.device(), device);
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
CHECK_DIM(1, weight); // weight: (hidden_size)
CHECK_EQ(input.size(1), weight.size(0));
unsigned int batch_size = input.size(0);
unsigned int hidden_size = input.size(1);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto output = torch::empty_like(input);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::GemmaRMSNorm(static_cast<c_type*>(input.data_ptr()),
static_cast<c_type*>(weight.data_ptr()),
static_cast<c_type*>(output.data_ptr()), batch_size,
hidden_size, eps, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"GemmaRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
return true;
});
return output;
}

void gemma_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
double eps) {
CHECK_INPUT(input);
CHECK_INPUT(residual);
CHECK_INPUT(weight);
auto device = input.device();
CHECK_EQ(residual.device(), device);
CHECK_EQ(weight.device(), device);
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
CHECK_DIM(2, residual); // residual: (batch_size, hidden_size)
CHECK_DIM(1, weight); // weight: (hidden_size)
CHECK_EQ(input.size(0), residual.size(0));
CHECK_EQ(input.size(1), residual.size(1));
CHECK_EQ(input.size(1), weight.size(0));
unsigned int batch_size = input.size(0);
unsigned int hidden_size = input.size(1);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::GemmaFusedAddRMSNorm(
static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps,
torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "GemmaFusedAddRMSNorm failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});
}
2 changes: 1 addition & 1 deletion python/flashinfer/__init__.py
@@ -29,7 +29,7 @@
single_decode_with_kv_cache,
)
from .gemm import SegmentGEMMWrapper, bmm_fp8
-from .norm import fused_add_rmsnorm, rmsnorm
+from .norm import fused_add_rmsnorm, gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
from .page import append_paged_kv_cache
from .prefill import (
BatchPrefillWithPagedKVCacheWrapper,
39 changes: 39 additions & 0 deletions python/flashinfer/norm.py
@@ -69,3 +69,42 @@ def fused_add_rmsnorm(
Epsilon for numerical stability.
"""
_kernels.fused_add_rmsnorm(input, residual, weight, eps)


def gemma_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
r"""Gemma Root mean square normalization.

Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.

Returns
-------
output: torch.Tensor
Gemma Normalized tensor, shape (batch_size, hidden_size).
"""
return _kernels.gemma_rmsnorm(input, weight, eps)


def gemma_fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
):
r"""Gemma Fused add root mean square normalization.

Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
residual: torch.Tensor
Residual tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
"""
_kernels.gemma_fused_add_rmsnorm(input, residual, weight, eps)
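
A quick usage sketch of the new Python API (shapes and values below are illustrative): gemma_rmsnorm returns a new tensor, while gemma_fused_add_rmsnorm updates both input and residual in place, mirroring the kernels above.

```python
import torch
import flashinfer

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
w = torch.randn(4096, dtype=torch.float16, device="cuda")

# Out-of-place Gemma RMSNorm: x * rsqrt(mean(x^2) + eps) * (1 + w)
y = flashinfer.gemma_rmsnorm(x, w, eps=1e-6)

# In-place fused variant: residual becomes x + residual; x becomes the normalized sum.
residual = torch.randn_like(x)
flashinfer.gemma_fused_add_rmsnorm(x, residual, w, eps=1e-6)
```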
59 changes: 59 additions & 0 deletions python/tests/test_norm.py
@@ -29,6 +29,28 @@ def _norm(x):
return output * w


def gemma_rms_norm(x, w, eps=1e-6):
orig_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
x = x * (1.0 + w)
x = x.to(orig_dtype)
return x


def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
orig_dtype = x.dtype
x = x + residual
residual = x
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
x = x * (1.0 + w)
x = x.to(orig_dtype)
return x, residual


def fused_add_rms_norm(x, residual, weight, eps):
orig_dtype = x.dtype
x = x.to(torch.float32)
@@ -76,3 +98,40 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):

torch.testing.assert_close(x_fused, x_native, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_norm(batch_size, hidden_size, dtype):
x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
w = torch.randn(hidden_size).to(0).to(dtype)

y_ref = gemma_rms_norm(x, w)
y = flashinfer.norm.gemma_rmsnorm(x, w)

numpy.testing.assert_allclose(
y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
eps = 1e-6

x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

x_native, residual_native = gemma_fused_add_rms_norm(
x.clone(), residual.clone(), weight, eps
)

x_fused = x.clone()
residual_fused = residual.clone()
flashinfer.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)

torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
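
The new tests can be run with pytest on a CUDA-capable machine, for example `pytest python/tests/test_norm.py -k gemma` (command shown only as an example invocation).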