Merge pull request apache#61 from octoml/mlc-serve-v0.2.0-feature-fp8
[FP8] Bring fp8 support to OLLM tracking branch
jroesch authored Mar 21, 2024
2 parents 1f52f52 + a502c40 commit 4f5dd41
Showing 6 changed files with 219 additions and 1 deletion.
5 changes: 5 additions & 0 deletions cmake/modules/CUDA.cmake
@@ -110,6 +110,11 @@ if(USE_CUDA)
tvm_file_glob(GLOB RELAX_VM_CUDA_BUILTIN_SRC_CC src/runtime/relax_vm/cuda/*.cc)
list(APPEND RUNTIME_SRCS ${RELAX_VM_CUDA_BUILTIN_SRC_CC})

# Add CUDA contrib kernels
tvm_file_glob(GLOB RUNTIME_CUDA_CONTRIB_SRC_CU src/runtime/contrib/cuda/*.cu)
list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_CONTRIB_SRC_CU})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")

if (USE_CUDA_FP8)
message(STATUS "Build with CUDA FP8 support")
add_definitions(-DUSE_CUDA_FP8=1)
135 changes: 135 additions & 0 deletions src/runtime/contrib/cuda/reduce.cu
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file Externally defined CUDA kernels for use in TVM runtime
*/

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <dlpack/dlpack.h>
#include <tvm/runtime/registry.h>

#include "../../cuda/cuda_common.h"

namespace tvm {
namespace contrib {

using namespace runtime;

template <typename T>
__device__ T device_max(T a, T b) {
return max(a, b);
}

template <>
__device__ __half device_max(__half a, __half b) {
return __hmax(a, b);
}

template <typename T>
__device__ T device_abs(T a) {
return abs(a);
}

template <>
__device__ __half device_abs(__half a) {
return __habs(a);
}

template <typename T>
__inline__ __device__ T warp_reduce_max(T val) {
for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
val = device_max(val, __shfl_down_sync(0xffffffff, val, offset));
}
return val;
}

// Single block reduce, assumes size % 1024 == 0
template <typename T>
__global__ void max_reduce_kernel_single_block(T* input, T* output, int size) {
__shared__ T shared[32];

int tid = threadIdx.x;
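// NOTE: using std::numeric_limits in device code relies on the
// --expt-relaxed-constexpr flag added to CMAKE_CUDA_FLAGS in cmake/modules/CUDA.cmake.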
T max_val = std::numeric_limits<T>::lowest();

// Step 1: Each thread reduces across the elements it owns
for (int i = tid; i < size; i += blockDim.x) {
// use __hmax for float16
max_val = device_max(max_val, device_abs(input[i]));
}

// Step 2: Perform reduce across warps
max_val = warp_reduce_max(max_val);

// Step 3: Write the reduced value from each warp to shared memory
if (tid % warpSize == 0) {
shared[tid / warpSize] = max_val;
}
__syncthreads();

// Step 4: Perform a final reduction in the first warp across shared values
if (tid < warpSize) {
max_val = shared[tid];
max_val = warp_reduce_max(max_val);
if (tid == 0) {
*output = max_val;
}
}
}

template __global__ void max_reduce_kernel_single_block<float>(float* input, float* output,
int size);
template __global__ void max_reduce_kernel_single_block<__half>(__half* input, __half* output,
int size);
template <typename T>
void LaunchMaxReduceKernelSingleBlock(DLTensor* input, DLTensor* output, int size) {
T* input_ptr = static_cast<T*>(input->data);
T* output_ptr = static_cast<T*>(output->data);

int blocks = 1;
int threads = 1024;
max_reduce_kernel_single_block<T><<<blocks, threads>>>(input_ptr, output_ptr, size);
}

TVM_REGISTER_GLOBAL("tvm.contrib.cuda.reduce_max_abs").set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* input = args[0];
DLTensor* output = args[1];

int size = 1;
for (int i = 0; i < input->ndim; ++i) {
size *= input->shape[i];
}

CHECK_EQ(size % 1024, 0) << "tvm.contrib.cuda.reduce_max_abs currently only supports reducing "
"tensors whose element count is a multiple of 1024";

auto dtype = DLDataType2String(input->dtype);

if (dtype == "float32") {
LaunchMaxReduceKernelSingleBlock<float>(input, output, size);
} else if (dtype == "float16") {
LaunchMaxReduceKernelSingleBlock<__half>(input, output, size);
} else {
LOG(FATAL) << "Unsupported input dtype: " << dtype;
}
});

} // namespace contrib
} // namespace tvm
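
For orientation, a minimal NumPy sketch of the semantics that tvm.contrib.cuda.reduce_max_abs is expected to implement. The helper name and the FP8-scale remark are illustrative additions, not part of this commit:

import numpy as np

def reduce_max_abs_reference(x: np.ndarray) -> np.ndarray:
    # Reference semantics of the packed function: a one-element array
    # holding max(|x_i|) over every element of the input.
    return np.array([np.abs(x).max()], dtype=x.dtype)

# In FP8 quantization flows such an amax value is commonly divided by the
# format's largest finite value (448 for e4m3, 57344 for e5m2) to obtain a
# scale; whether downstream code in this branch does so is not shown here.
x = np.random.uniform(-2, 1.4, (4, 4096)).astype("float16")
print(reduce_max_abs_reference(x))
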
28 changes: 28 additions & 0 deletions src/runtime/contrib/cutlass/fp16_group_gemm.cu
@@ -84,9 +84,37 @@ void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDAr
static_cast<ElementC*>(out->data), stream);
}

template <typename ElementA, typename ElementB, typename ElementC>
void tvm_cutlass_group_gemm_sm90_scale(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
NDArray alpha, NDArray out) {
// Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
// Recommended size is 4MB.
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
CHECK_EQ(x->ndim, 2);
CHECK_EQ(weight->ndim, 3);
CHECK_EQ(indptr->ndim, 1);
CHECK_EQ(workspace->ndim, 1);
CHECK_EQ(alpha->dtype.code, kDLFloat);
CHECK_EQ(alpha->dtype.bits, 32);
CHECK_EQ(out->ndim, 2);
int num_groups = weight->shape[0];
int n = weight->shape[1];
int k = weight->shape[2];
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, static_cast<float*>(alpha->data),
static_cast<float*>(nullptr), static_cast<ElementC*>(out->data), stream);
}

TVM_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90")
.set_body_typed(tvm_cutlass_group_gemm_sm90<cutlass::half_t, cutlass::half_t, cutlass::half_t>);

TVM_REGISTER_GLOBAL("cutlass.group_gemm_scale_fp16_sm90")
.set_body_typed(
tvm_cutlass_group_gemm_sm90_scale<cutlass::half_t, cutlass::half_t, cutlass::half_t>);

} // namespace runtime
} // namespace tvm
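
As a reading aid, the sketch below spells out the grouped-GEMM semantics the scaled variant appears to implement. It is illustrative only: it assumes indptr[g] is the exclusive end-row offset of group g in x, and that alpha is a single-element float32 tensor used as a uniform scale (the diff only checks alpha's dtype, so a per-group layout is also possible).

import numpy as np

def group_gemm_scale_reference(x, weight, indptr, alpha):
    # x: (total_rows, k), weight: (num_groups, n, k), out: (total_rows, n).
    # Assumption: indptr[g] is the exclusive end-row offset of group g.
    num_groups, n, k = weight.shape
    out = np.zeros((x.shape[0], n), dtype=x.dtype)
    scale = float(np.asarray(alpha).ravel()[0])  # assumed single scale
    start = 0
    for g in range(num_groups):
        end = int(indptr[g])
        # Rows [start, end) of x are multiplied by group g's weight matrix.
        out[start:end] = scale * (x[start:end] @ weight[g].T)
        start = end
    return out
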

30 changes: 30 additions & 0 deletions src/runtime/contrib/cutlass/fp8_group_gemm.cu
@@ -65,6 +65,29 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr
static_cast<ElementC*>(out->data), stream);
}

template <typename ElementA, typename ElementB, typename ElementC>
void tvm_cutlass_fp8_group_gemm_host_scale(NDArray x, NDArray weight, NDArray indptr,
NDArray workspace, double alpha, NDArray out) {
// Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
// Recommended size is 4MB.
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
CHECK_EQ(x->ndim, 2);
CHECK_EQ(weight->ndim, 3);
CHECK_EQ(indptr->ndim, 1);
CHECK_EQ(workspace->ndim, 1);
CHECK_EQ(out->ndim, 2);
int num_groups = weight->shape[0];
int n = weight->shape[1];
int k = x->shape[1];
double beta = 0.0;
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, static_cast<float>(alpha),
static_cast<float>(beta), static_cast<ElementC*>(out->data), stream);
}

TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16")
.set_body_typed(
tvm_cutlass_fp8_group_gemm<cutlass::float_e5m2_t, cutlass::float_e5m2_t, cutlass::half_t>);
@@ -77,6 +100,13 @@ TVM_REGISTER_GLOBAL("cutlass.group_gemm_e4m3_e4m3_fp16")
.set_body_typed(
tvm_cutlass_fp8_group_gemm<cutlass::float_e4m3_t, cutlass::float_e4m3_t, cutlass::half_t>);

TVM_REGISTER_GLOBAL("cutlass.group_gemm_e4m3_e5m2_fp16")
.set_body_typed(
tvm_cutlass_fp8_group_gemm<cutlass::float_e4m3_t, cutlass::float_e5m2_t, cutlass::half_t>);

TVM_REGISTER_GLOBAL("cutlass.group_gemm_e4m3_e5m2_fp16_host_scale")
.set_body_typed(tvm_cutlass_fp8_group_gemm_host_scale<cutlass::float_e4m3_t,
cutlass::float_e5m2_t, cutlass::half_t>);
} // namespace runtime
} // namespace tvm
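
For context, one plausible way to exercise the host-scale registration from Python, following the FFI pattern used in the test added below. The shapes, the indptr contents, the FP8 dtype strings, and the 4 MB workspace size are assumptions, not values taken from this commit:

import numpy as np
import tvm

f = tvm._ffi.get_global_func("cutlass.group_gemm_e4m3_e5m2_fp16_host_scale")
dev = tvm.cuda(0)
x = tvm.nd.empty((16, 128), "e4m3_float8", dev)      # activations, k = 128 (dtype string assumed)
w = tvm.nd.empty((4, 64, 128), "e5m2_float8", dev)   # 4 groups, n = 64 (dtype string assumed)
indptr = tvm.nd.array(np.array([4, 8, 12, 16], dtype="int64"), dev)  # assumed cumulative row offsets
workspace = tvm.nd.empty((4 * 1024 * 1024,), "uint8", dev)           # assumed 4 MB scratch
out = tvm.nd.empty((16, 64), "float16", dev)
f(x, w, indptr, workspace, 0.05, out)                # alpha is passed as a host-side double
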

2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.cc
@@ -2232,7 +2232,7 @@ llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm)
auto dwarf_type = [&]() -> llvm::dwarf::TypeKind {
if (dtype.is_bool()) {
return llvm::dwarf::DW_ATE_boolean;
} else if (dtype.is_float()) {
} else if (dtype.is_float() || dtype.is_float8()) {
return llvm::dwarf::DW_ATE_float;
} else if (dtype.is_int()) {
return llvm::dwarf::DW_ATE_signed;
20 changes: 20 additions & 0 deletions tests/python/contrib/test_cuda.py
@@ -0,0 +1,20 @@
import numpy as np
import tvm
import tvm.testing


def test_reduce_max_abs():
target = "cuda"
dev = tvm.device(target, 0)
x_shape = (4, 4096)
dtype = "float16"
x = tvm.nd.array(np.random.uniform(-2, 1.4, x_shape).astype(dtype), dev)
scalar = tvm.nd.array(np.array([0], dtype=dtype), dev)

reduce = tvm._ffi.get_global_func("tvm.contrib.cuda.reduce_max_abs")
reduce(x, scalar)
tvm.testing.assert_allclose(scalar.numpy(), np.array([2], dtype=dtype))


if __name__ == "__main__":
tvm.testing.main()
