Update to CUTLASS 3.2
ghstack-source-id: 3af42e35dbf32e9b485bcd9e73d2bf5937429c67
Pull Request resolved: https://github.com/fairinternal/xformers/pull/804

__original_commit__ = fairinternal/xformers@87789beb00a3ca8a35d5bcf8d3cd7aee99626c54
danthe3rd authored and xFormers Bot committed Sep 22, 2023
1 parent 326272e commit af6b866
Showing 4 changed files with 14 additions and 10 deletions.
.github/workflows/win-build.yml (1 change: 1 addition & 0 deletions)
@@ -4,6 +4,7 @@ on:
   pull_request:
     paths:
       - "xformers/csrc/**"
+      - "third-party/**"
       - ".github/workflows/win-build.yml"
       - "setup.py"
       - "requirements*.txt"
setup.py (5 changes: 2 additions & 3 deletions)
@@ -176,7 +176,6 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
             sources=[os.path.join(flash_root, path) for path in sources],
             extra_compile_args={
                 **extra_compile_args,
-                "cxx": extra_compile_args["cxx"] + ["-std=c++17"],
                 "nvcc": extra_compile_args.get("nvcc", [])
                 + [
                     "-O3",
@@ -226,7 +225,7 @@ def get_extensions():
 
     define_macros = []
 
-    extra_compile_args = {"cxx": ["-O3"]}
+    extra_compile_args = {"cxx": ["-O3", "-std=c++17"]}
     if sys.platform == "win32":
         define_macros += [("xformers_EXPORTS", None)]
         extra_compile_args["cxx"].extend(["/MP", "/Zc:lambda", "/Zc:preprocessor"])
@@ -253,6 +252,7 @@ def get_extensions():
         "-U__CUDA_NO_HALF_CONVERSIONS__",
         "--extended-lambda",
         "-D_ENABLE_EXTENDED_ALIGNED_STORAGE",
+        "-std=c++17",
     ] + get_extra_nvcc_flags_for_build_type()
     if os.getenv("XFORMERS_ENABLE_DEBUG_ASSERTIONS", "0") != "1":
         nvcc_flags.append("-DNDEBUG")
@@ -266,7 +266,6 @@ def get_extensions():
     ]
     if sys.platform == "win32":
         nvcc_flags += [
-            "-std=c++17",
            "-Xcompiler",
            "/Zc:lambda",
            "-Xcompiler",
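Taken together, the setup.py changes above hoist "-std=c++17" out of the flash-attention-specific cxx flags and the Windows-only nvcc flags into the global "cxx" and "nvcc" flag lists, so every host and device compile now runs at C++17 (CUTLASS 3.x headers require it, for example for if constexpr). A hypothetical guard, not part of this commit, that a C++/CUDA source file could use to fail fast on an older standard:

// Hypothetical check, not in the xformers sources. MSVC reports 199711L in
// __cplusplus unless /Zc:__cplusplus is passed, so _MSVC_LANG is tested too.
#if __cplusplus < 201703L && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)
#error "xformers C++/CUDA extensions must be built as C++17"
#endif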
third_party/cutlass (2 changes: 1 addition & 1 deletion)
Submodule cutlass updated 1094 files
xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu (16 changes: 10 additions & 6 deletions)
@@ -80,6 +80,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
       cutlass::layout::RowMajor,
       scalar_t,
       cutlass::layout::ColumnMajor,
+      cutlass::layout::ColumnMajor,
       ElementOutput,
       cutlass::layout::RowMajor,
       ElementAccumulator,
@@ -106,8 +107,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
   int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1;
   using RefA = typename cutlass::
       TensorRef<typename DualGemm::ElementA, typename DualGemm::LayoutA>;
-  using RefB = typename cutlass::
-      TensorRef<typename DualGemm::ElementB, typename DualGemm::LayoutB>;
+  using RefB0 = typename cutlass::
+      TensorRef<typename DualGemm::ElementB, typename DualGemm::LayoutB0>;
+  using RefB1 = typename cutlass::
+      TensorRef<typename DualGemm::ElementB, typename DualGemm::LayoutB1>;
   using RefC = typename cutlass::
       TensorRef<typename DualGemm::ElementC, typename DualGemm::LayoutC>;
   RefC ref_b0, ref_b1;
@@ -120,20 +123,21 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
         RefC{(scalar_t*)b1->data_ptr(), typename DualGemm::LayoutC::Stride(0)};
   }
   typename DualGemm::Arguments arguments{
+      cutlass::gemm::DualGemmMode::kGemm,
       problem_size,
       RefA{
           (scalar_t*)x.data_ptr(),
           typename DualGemm::LayoutA::Stride(x.stride(0))},
-      RefB{
+      RefB0{
           (scalar_t*)w0.data_ptr(),
-          typename DualGemm::LayoutB::Stride(w0.stride(0))},
+          typename DualGemm::LayoutB0::Stride(w0.stride(0))},
       ref_b0,
       RefC{
           (scalar_t*)d0.data_ptr(),
           typename DualGemm::LayoutC::Stride(d0.stride(0))},
-      RefB{
+      RefB1{
           (scalar_t*)w1.data_ptr(),
-          typename DualGemm::LayoutB::Stride(w1.stride(0))},
+          typename DualGemm::LayoutB1::Stride(w1.stride(0))},
       ref_b1,
       RefC{
           (scalar_t*)d1.data_ptr(),
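The .cu changes above track the CUTLASS 3.2 DualGemm API: the two B operands now carry independent layouts (LayoutB0/LayoutB1), and Arguments takes a leading cutlass::gemm::DualGemmMode. A minimal sketch of the new per-operand TensorRef pattern, with concrete types standing in for the DualGemm:: typedefs used in the real file (half_t and the stride value 64 are illustrative assumptions):

#include <cutlass/half.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/tensor_ref.h>

// Each B operand names its own layout; in this file both are column-major.
using LayoutB0 = cutlass::layout::ColumnMajor;
using LayoutB1 = cutlass::layout::ColumnMajor;
using RefB0 = cutlass::TensorRef<cutlass::half_t, LayoutB0>;
using RefB1 = cutlass::TensorRef<cutlass::half_t, LayoutB1>;

// Building a ref mirrors the diff: a raw pointer plus a leading-dimension
// stride (w0.stride(0) in the real code; 64 is a placeholder here).
void make_refs(cutlass::half_t* w0_ptr, cutlass::half_t* w1_ptr) {
  RefB0 ref_b0{w0_ptr, LayoutB0::Stride(64)};
  RefB1 ref_b1{w1_ptr, LayoutB1::Stride(64)};
  (void)ref_b0;  // silence unused-variable warnings in this sketch
  (void)ref_b1;
}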
