feat: add gemma_rmsnorm and gemma_fused_add_rmsnorm #477

Merged · 2 commits merged on Aug 27, 2024
Changes from 1 commit
feat: add gemma_rmsnorm
zhyncs committed Aug 27, 2024
commit 88df5aed5b53e71e7de6874f8ebd8bbb533aa84c
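For reference, the Gemma variant of RMSNorm added here (matching the gemma_rms_norm reference later added in python/tests/test_norm.py) normalizes each row by its root mean square and scales by (1 + w) rather than w:

y_i = \frac{x_i}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} x_j^2 + \varepsilon}} \cdot (1 + w_i)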
86 changes: 86 additions & 0 deletions include/flashinfer/norm.cuh
@@ -212,6 +212,92 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
return cudaSuccess;
}

template <uint32_t VEC_SIZE, typename T>
__global__ void GemmaRMSNormKernel(T* __restrict__ input, T* __restrict__ weight,
                                   T* __restrict__ output, const uint32_t d, float eps) {
  const uint32_t bx = blockIdx.x;
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
  constexpr uint32_t warp_size = 32;
  const uint32_t num_warps = blockDim.y;
  const uint32_t thread_id = tx + ty * warp_size;
  const uint32_t num_threads = num_warps * warp_size;
  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
  extern __shared__ float smem[];

  float sum_sq = 0.f;

  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
    input_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      sum_sq += float(input_vec[j]) * float(input_vec[j]);
    }
  }

  // first, warp reduce sum
#pragma unroll
  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
    sum_sq += math::shfl_xor_sync(sum_sq, offset);
  }

  smem[ty] = sum_sq;
  __syncthreads();
  // then, cross warp reduce sum using only the first warp
  if (ty == 0) {
    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
#pragma unroll
    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
      sum_sq += math::shfl_xor_sync(sum_sq, offset);
    }
    smem[0] = sum_sq;
  }
  __syncthreads();

  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
[Review comment on the line above]
-> math::rsqrt(smem[0] / (float(d) + eps));


  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
    vec_t<T, VEC_SIZE> weight_vec;
    vec_t<T, VEC_SIZE> output_vec;
    input_vec.fill(0.f);
    weight_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      output_vec[j] = float(input_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
  }
}

template <typename T>
cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
                         float eps = 1e-5, cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
  const uint32_t num_warps = ceil_div(block_size, 32);
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = num_warps * sizeof(float);
  void* args[] = {&input, &weight, &output, &d, &eps};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = GemmaRMSNormKernel<VEC_SIZE, T>;
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });
  return cudaSuccess;
}

} // namespace norm

} // namespace flashinfer
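As a rough illustration of the launch configuration chosen by GemmaRMSNorm above, the following Python sketch (not part of the diff; launch_params and dtype_size are hypothetical names) shows how vec_size, the block dimensions, and the shared-memory size are derived for a given hidden size and element width:

import math

def launch_params(hidden_size: int, dtype_size: int = 2):
    # 16-byte vectorized accesses: fp16 (2 bytes) gives at most 8 elements per vector
    vec_size = math.gcd(16 // dtype_size, hidden_size)
    block_size = min(1024, hidden_size // vec_size)
    num_warps = (block_size + 31) // 32  # ceil_div(block_size, 32)
    smem_bytes = num_warps * 4           # one float per warp for the cross-warp reduction
    return vec_size, (32, num_warps), smem_bytes

print(launch_params(4096))  # -> (8, (32, 16), 64) for fp16, hidden_size = 4096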
1 change: 1 addition & 0 deletions python/csrc/flashinfer_ops.cu
@@ -39,6 +39,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Speculative sampling from sequence of probabilities");
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization");
m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul");
2 changes: 2 additions & 0 deletions python/csrc/flashinfer_ops.h
@@ -77,6 +77,8 @@ torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);
void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
                       double eps);

torch::Tensor gemma_rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
25 changes: 25 additions & 0 deletions python/csrc/norm.cu
@@ -73,3 +73,28 @@ void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tenso
    return true;
  });
}

torch::Tensor gemma_rmsnorm(torch::Tensor input, torch::Tensor weight, double eps) {
  CHECK_INPUT(input);
  CHECK_INPUT(weight);
  auto device = input.device();
  CHECK_EQ(weight.device(), device);
  CHECK_DIM(2, input);   // input: (batch_size, hidden_size)
  CHECK_DIM(1, weight);  // weight: (hidden_size)
  CHECK_EQ(input.size(1), weight.size(0));
  unsigned int batch_size = input.size(0);
  unsigned int hidden_size = input.size(1);

  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
  auto output = torch::empty_like(input);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    cudaError_t status = norm::GemmaRMSNorm(static_cast<c_type*>(input.data_ptr()),
                                            static_cast<c_type*>(weight.data_ptr()),
                                            static_cast<c_type*>(output.data_ptr()), batch_size,
                                            hidden_size, eps, torch_current_stream);
    TORCH_CHECK(status == cudaSuccess,
                "GemmaRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
  return output;
}
2 changes: 1 addition & 1 deletion python/flashinfer/__init__.py
@@ -29,7 +29,7 @@
    single_decode_with_kv_cache,
)
from .gemm import SegmentGEMMWrapper, bmm_fp8
from .norm import fused_add_rmsnorm, rmsnorm
from .norm import fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
from .page import append_paged_kv_cache
from .prefill import (
    BatchPrefillWithPagedKVCacheWrapper,
20 changes: 20 additions & 0 deletions python/flashinfer/norm.py
@@ -69,3 +69,23 @@ def fused_add_rmsnorm(
        Epsilon for numerical stability.
    """
    _kernels.fused_add_rmsnorm(input, residual, weight, eps)


def gemma_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    r"""Gemma Root mean square normalization.

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.

    Returns
    -------
    output: torch.Tensor
        Gemma Normalized tensor, shape (batch_size, hidden_size).
    """
    return _kernels.gemma_rmsnorm(input, weight, eps)
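A minimal usage sketch of the new Python API (not part of the diff; assumes a CUDA device and a flashinfer build that includes this kernel):

import torch
import flashinfer

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
w = torch.randn(4096, dtype=torch.float16, device="cuda")

# out[i] = x[i] / sqrt(mean(x[i]**2) + eps) * (1 + w)
out = flashinfer.norm.gemma_rmsnorm(x, w, eps=1e-6)
print(out.shape)  # torch.Size([4, 4096])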
25 changes: 25 additions & 0 deletions python/tests/test_norm.py
@@ -29,6 +29,16 @@ def _norm(x):
    return output * w


def gemma_rms_norm(x, w, eps=1e-6):
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * (1.0 + w)
    x = x.to(orig_dtype)
    return x


def fused_add_rms_norm(x, residual, weight, eps):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
@@ -76,3 +86,18 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):

    torch.testing.assert_close(x_fused, x_native, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_norm(batch_size, hidden_size, dtype):
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)

    y_ref = gemma_rms_norm(x, w)
    y = flashinfer.norm.gemma_rmsnorm(x, w)

    numpy.testing.assert_allclose(
        y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
    )