Improve build time
ghstack-source-id: 4b2445fe3c83eef3282643862ed83cef85dc5997
Pull Request resolved: #539
danthe3rd committed Nov 24, 2022
1 parent 103e863 commit 1e679df
Showing 6 changed files with 67 additions and 67 deletions.
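In short: the PyTorch-dependent TypeTraits helpers move out of the shared gemm_kernel_utils header and into the anonymous namespace of the .cu file that launches the forward kernel; kernel_forward.h and kernel_backward.h drop their ATen/CUDA includes, which the .cu files now include directly; a variadic debug-print macro gains ##__VA_ARGS__ so it compiles with zero arguments; and check_supported returns bool instead of void. The build-time win presumably comes from the kernel headers no longer pulling the heavy ATen headers into every translation unit.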
File 1 of 6:
@@ -1,3 +1,10 @@
+#include <ATen/ScalarOps.h>
+#include <ATen/Tensor.h>
+#include <ATen/TensorOperators.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/library.h>
+
 #include "kernel_backward.h"
 
 #define DISPATCH_MAXK(func) \
File 2 of 6:
@@ -1,3 +1,9 @@
+#include <ATen/ScalarOps.h>
+#include <ATen/Tensor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/library.h>
+
 #include "kernel_forward.h"
 
 #define DISPATCH_BLOCKSIZE(VALUE_HEAD_DIM, FN) \
@@ -62,6 +68,57 @@
 }
 
 namespace {
+template <typename scalar_t>
+struct TypeTraits;
+
+template <>
+struct TypeTraits<cutlass::half_t> {
+  using scalar_t = cutlass::half_t;
+
+  static constexpr __host__ at::ScalarType atScalarType() {
+    return at::ScalarType::Half;
+  }
+  template <int nDim>
+  static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
+      at::Tensor const& tensor) {
+    return at::PackedTensorAccessor32<scalar_t, nDim>(
+        (scalar_t*)(tensor.data_ptr()),
+        tensor.sizes().data(),
+        tensor.strides().data());
+  }
+};
+
+template <>
+struct TypeTraits<cutlass::bfloat16_t> {
+  using scalar_t = cutlass::bfloat16_t;
+
+  static constexpr __host__ at::ScalarType atScalarType() {
+    return at::ScalarType::BFloat16;
+  }
+  template <int nDim>
+  static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
+      at::Tensor const& tensor) {
+    return at::PackedTensorAccessor32<scalar_t, nDim>(
+        (scalar_t*)(tensor.data_ptr()),
+        tensor.sizes().data(),
+        tensor.strides().data());
+  }
+};
+
+template <>
+struct TypeTraits<float> {
+  using scalar_t = float;
+
+  static constexpr __host__ at::ScalarType atScalarType() {
+    return at::ScalarType::Float;
+  }
+  template <int nDim>
+  static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
+      at::Tensor const& tensor) {
+    return tensor.packed_accessor32<scalar_t, nDim>();
+  }
+};
+
 /*
 There are 2 modes for using this function.
 (Mode BMHK) With all the heads having the same seqlen
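For context, here is a minimal host-side sketch of how a TypeTraits specialization is typically used. This is not part of the diff; the tensor name, shape, and the `traits` alias are made up:

  // Hypothetical usage sketch (not from this commit). Assumes a CUDA build
  // with PyTorch available; `q` and `traits` are illustrative names.
  at::Tensor q = at::rand({8, 128, 64}, at::device(at::kCUDA).dtype(at::kHalf));
  using traits = TypeTraits<cutlass::half_t>;
  // Verify the runtime dtype matches the CUTLASS scalar type we compiled for.
  TORCH_CHECK(q.scalar_type() == traits::atScalarType());
  // View the tensor's storage as cutlass::half_t through a 32-bit-indexed accessor.
  auto accessor = traits::packed_accessor<3>(q);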
File 3 of 6:
@@ -23,7 +23,7 @@
 if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \
     threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
     threadIdx.z == 0) { \
-  printf(msg "\n", __VA_ARGS__); \
+  printf(msg "\n", ##__VA_ARGS__); \
 }
 struct __string_view {
   char const* data;
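The single change in this file fixes a real compile error: with plain __VA_ARGS__, invoking the macro with no variadic arguments leaves a dangling comma after the format string. The ##__VA_ARGS__ form (a GCC/Clang/NVCC extension) deletes that comma when the argument list is empty. A self-contained sketch, with a hypothetical macro name:

  #include <cstdio>
  // ## swallows the preceding comma when no variadic arguments are passed.
  #define DPRINT(msg, ...) printf(msg "\n", ##__VA_ARGS__)
  int main() {
    DPRINT("starting");   // expands to printf("starting" "\n"); fine with ##
    DPRINT("x = %d", 42); // expands to printf("x = %d" "\n", 42);
    return 0;
  }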
File 4 of 6:
@@ -77,6 +77,7 @@
     return false; \
   }
 #else
+#include <iostream>
 #define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
   if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
     std::cerr << #PTR " is not correctly aligned\n"; \
@@ -97,59 +98,6 @@

 namespace gemm_kernel_utils {
 
-#ifdef HAS_PYTORCH
-template <typename scalar_t>
-struct TypeTraits;
-
-template <>
-struct TypeTraits<cutlass::half_t> {
-  using scalar_t = cutlass::half_t;
-
-  static constexpr __host__ at::ScalarType atScalarType() {
-    return at::ScalarType::Half;
-  }
-  template <int nDim>
-  static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
-      at::Tensor const& tensor) {
-    return at::PackedTensorAccessor32<scalar_t, nDim>(
-        (scalar_t*)(tensor.data_ptr()),
-        tensor.sizes().data(),
-        tensor.strides().data());
-  }
-};
-
-template <>
-struct TypeTraits<cutlass::bfloat16_t> {
-  using scalar_t = cutlass::bfloat16_t;
-
-  static constexpr __host__ at::ScalarType atScalarType() {
-    return at::ScalarType::BFloat16;
-  }
-  template <int nDim>
-  static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
-      at::Tensor const& tensor) {
-    return at::PackedTensorAccessor32<scalar_t, nDim>(
-        (scalar_t*)(tensor.data_ptr()),
-        tensor.sizes().data(),
-        tensor.strides().data());
-  }
-};
-
-template <>
-struct TypeTraits<float> {
-  using scalar_t = float;
-
-  static constexpr __host__ at::ScalarType atScalarType() {
-    return at::ScalarType::Float;
-  }
-  template <int nDim>
-  static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
-      at::Tensor const& tensor) {
-    return tensor.packed_accessor32<scalar_t, nDim>();
-  }
-};
-#endif
-
 template <typename integer>
 constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
   return (n + m - 1) / m;
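The surviving ceil_div helper is a constexpr ceiling division, the usual way these kernels compute how many fixed-size tiles cover a dimension. Since it is constexpr, its behavior can be pinned down at compile time; a small illustration, not part of the diff:

  static_assert(gemm_kernel_utils::ceil_div(10, 4) == 3, "3 tiles of 4 cover 10 elements");
  static_assert(gemm_kernel_utils::ceil_div(8, 4) == 2, "exact multiples do not round up");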
File 5 of 6:
@@ -1,15 +1,10 @@
 #pragma once
 
-#include <ATen/ATen.h>
-#include <torch/library.h>
 #include <cmath>
 #include <vector>
 
 #include <cuda_fp16.h>
 
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-
 #include "cutlass/gemm/gemm.h"
 #include "cutlass/layout/matrix.h"
 #include "cutlass/layout/vector.h"
@@ -676,7 +671,7 @@ struct AttentionBackwardKernel {
   }
 };
 
-static void __host__ check_supported(Params const& p) {
+static bool __host__ check_supported(Params const& p) {
   CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment);
   CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment);
   CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment);
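The void-to-bool change lets check_supported report failure through its return value: the CHECK_* macros shown earlier in this diff bail out with `return false`, which only type-checks once the enclosing function returns bool. A hypothetical call-site sketch; the surrounding names are illustrative, not from this commit:

  // Hypothetical launcher fragment. `Kernel` stands for the instantiated
  // AttentionBackwardKernel; `params` is a filled-in Kernel::Params.
  if (!Kernel::check_supported(params)) {
    return false; // unsupported configuration; fall back to another kernel
  }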
File 6 of 6:
@@ -1,10 +1,3 @@
-#ifdef HAS_PYTORCH
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-#include <torch/library.h>
-#endif
-
 #include <cmath>
 #include <vector>
 
