Skip to content

Commit

Permalink
feat: add group gemm operators (#282)
Browse files Browse the repository at this point in the history
First step towards #199 .

Group gemm should also be helpful for MoE.
  • Loading branch information
yzh119 authored Jun 5, 2024
1 parent 7aadc0d commit e08ba42
Show file tree
Hide file tree
Showing 24 changed files with 764 additions and 37 deletions.
4 changes: 4 additions & 0 deletions docs/api/python/decode.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,9 @@ Batch Decoding
.. autoclass:: BatchDecodeWithPagedKVCacheWrapper
:members:

.. automethod:: __init__

.. autoclass:: CUDAGraphBatchDecodeWithPagedKVCacheWrapper
:members:

.. automethod:: __init__
13 changes: 13 additions & 0 deletions docs/api/python/group_gemm.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
.. _apigroup_gemm:

flashinfer.group_gemm
=====================

This module provides a set of functions to group GEMM operations.

.. currentmodule:: flashinfer.group_gemm

.. autoclass:: SegmentGEMMWrapper
:members:

.. automethod:: __init__
5 changes: 4 additions & 1 deletion docs/api/python/prefill.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ Batch Prefill/Append Attention
.. autoclass:: BatchPrefillWithPagedKVCacheWrapper
:members:

.. automethod:: __init__

.. autoclass:: BatchPrefillWithRaggedKVCacheWrapper
:members:


.. automethod:: __init__
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ FlashInfer is a library for Language Languages Models that provides high-perform
api/python/cascade
api/python/page
api/python/sampling
api/python/group_gemm
api/python/norm
44 changes: 44 additions & 0 deletions include/flashinfer/allocator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_ALLOCATOR_H_
#define FLASHINFER_ALLOCATOR_H_

#include <memory>
#include <stdexcept>

namespace flashinfer {

struct AlignedAllocator {
void* ptr;
size_t space;
AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {}
template <typename T>
T* aligned_alloc(size_t size, size_t alignment) {
if (std::align(alignment, size, ptr, space)) {
T* result = reinterpret_cast<T*>(ptr);
ptr = (char*)ptr + size;
space -= size;
return result;
} else {
throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAlloactor");
}
return nullptr;
}
};

} // namespace flashinfer

#endif // FLASHINFER_ALLOCATOR_H_
27 changes: 4 additions & 23 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_HANDLER_CUH_
#define FLASHINFER_HANDLER_CUH_
#ifndef FLASHINFER_ATTENTION_HANDLER_CUH_
#define FLASHINFER_ATTENTION_HANDLER_CUH_

#include <algorithm>
#include <cstddef>
#include <memory>
#include <sstream>
#include <unordered_map>
#include <vector>

#include "../allocator.h"
#include "../page.cuh"
#include "../pos_enc.cuh"
#include "../utils.cuh"
Expand Down Expand Up @@ -241,24 +240,6 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo(
return cudaSuccess;
}

struct AlignedAllocator {
void* ptr;
size_t space;
AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {}
template <typename T>
T* aligned_alloc(size_t size, size_t alignment) {
if (std::align(alignment, size, ptr, space)) {
T* result = reinterpret_cast<T*>(ptr);
ptr = (char*)ptr + size;
space -= size;
return result;
} else {
throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAlloactor");
}
return nullptr;
}
};

class BatchDecodeHandler {
public:
template <typename DType>
Expand Down Expand Up @@ -584,4 +565,4 @@ class BatchPrefillHandler {
};

} // namespace flashinfer
#endif // FLASHINFER_HANDLER_CUH_
#endif // FLASHINFER_ATTENTION_HANDLER_CUH_
65 changes: 65 additions & 0 deletions include/flashinfer/group_gemm/group_gemm_cutlass.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
#define FLASHINFER_GROUP_GEMM_CUTLASS_CUH_

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

namespace flashinfer {

namespace group_gemm {

template <typename T>
struct cutlass_dtype {
using type = T;
};

template <>
struct cutlass_dtype<half> {
using type = cutlass::half_t;
};

template <>
struct cutlass_dtype<nv_bfloat16> {
using type = cutlass::bfloat16_t;
};

template <typename T>
__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_x,
T** ptr_w, T** ptr_y, int64_t* ld_x, int64_t* ld_w,
int64_t* ld_y, T* x, T* w, T* y, int64_t* xy_indptr,
int64_t* w_indices, size_t d_in, size_t d_out,
bool w_column_major) {
int i = blockIdx.x;
int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out;
all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
ptr_w[i] = w + (w_indices == nullptr ? i : w_indices[i]) * d_in * d_out;
ptr_x[i] = x + xy_indptr[i] * d_in;
ptr_y[i] = y + xy_indptr[i] * d_out;
ld_x[i] = k; // m * k
ld_w[i] = w_column_major ? k : n; // k * n if column major, n * k if row major
ld_y[i] = n; // m * n
}

} // namespace group_gemm

} // namespace flashinfer

#endif // FLASHINFER_GROUP_GEMM_CUTLASS_WRAPPER_CUH_
29 changes: 29 additions & 0 deletions include/flashinfer/group_gemm/group_gemm_lora.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GROUP_GEMM_LORA_CUH_
#define FLASHINFER_GROUP_GEMM_LORA_CUH_

namespace flashinfer {

namespace group_gemm {

// TODO(Zihao): port punica's sgmv kernel

} // namespace group_gemm

} // namespace flashinfer

#endif // FLASHINFER_GROUP_GEMM_LORA_CUH_
29 changes: 29 additions & 0 deletions include/flashinfer/group_gemm/group_gemv.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GROUP_GEMV_CUH_
#define FLASHINFER_GROUP_GEMV_CUH_

namespace flashinfer {

namespace group_gemm {

// TODO(Zihao): port punica's bgmv kernel

} // namespace group_gemm

} // namespace flashinfer

#endif // FLASHINFER_GROUP_GEMV_CUH_
66 changes: 66 additions & 0 deletions include/flashinfer/group_gemm/handler.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GROUP_GEMM_HANDLER_CUH_
#define FLASHINFER_GROUP_GEMM_HANDLER_CUH_

#include <cstddef>

#include "../allocator.h"
#include "../utils.cuh"
#include "group_gemm_cutlass.cuh"
#include "group_gemm_lora.cuh"
#include "group_gemv.cuh"

namespace flashinfer {

namespace group_gemm {

enum class GroupGEMMKernelConfig {
kGeneral, // large d_in, d_out
kShrink, // large d_in, small d_out
kExpand, // small d_in, large d_out
};

class CutlassSegmentGEMMHandler {
public:
void RegisterWorkspace(void* buffer, size_t size) {
buffer_ = buffer;
workspace_size_in_bytes_ = size;
}

void* GetWorkspace() const { return buffer_; }

size_t GetWorkspaceSizeInBytes() const { return workspace_size_in_bytes_; }

cudaStream_t GetCUDAStream() const { return stream_; }

void SetCUDAStream(cudaStream_t stream) { stream_ = stream; }

CutlassSegmentGEMMHandler() {}

~CutlassSegmentGEMMHandler() {}

private:
void* buffer_;
size_t workspace_size_in_bytes_;
cudaStream_t stream_;
};

} // namespace group_gemm

} // namespace flashinfer

#endif // FLASHINFER_GROUP_GEMM_HANDLER_CUH_
Loading

0 comments on commit e08ba42

Please sign in to comment.