diff --git a/CHANGELOG.md b/CHANGELOG.md
index b956c07351..9bf6d239ba 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1,7 @@
# NVIDIA CUTLASS Changelog
-## [1.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.1) (2018-12-19)
- * Resolved issue with sm50 and sm52 architectures
+## [1.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.0) (2019-03-20)
+ * Efficient GEMM kernel targeting Volta Tensor Cores via `mma.sync` instruction added in CUDA 10.1.
## [1.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.0) (2018-10-26)
* Parallelized reductions across threadblocks ("Split-K")
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2ec8cd7bba..25a967b881 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
@@ -20,7 +20,7 @@
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-cmake_minimum_required(VERSION 3.3.0)
+cmake_minimum_required(VERSION 3.3.0 FATAL_ERROR)
set(CUTLASS_LANGUAGES CXX)
@@ -36,7 +36,8 @@ else()
# FindCUDA fails to detect VS 2017 due to a changed directory format of the toolkits.
# For this configuration we need CMake >= 3.9.0 to use the native CUDA support.
if (WIN32 AND MSVC_VERSION GREATER 1800)
- message(FATAL_ERROR "Please upgrade CMake to version >= 3.9.0 to support Visual Studio 2017 or higher")
+ message(SEND_ERROR "Please upgrade CMake to version >= 3.9.0 to support Visual Studio 2017 or higher")
+ cmake_minimum_required(VERSION 3.9.0 FATAL_ERROR)
endif()
# Fall back to the FindCUDA version to create an executable with CUDA files
@@ -52,7 +53,11 @@ if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 )
message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!")
endif()
-find_package(CUDA)
+find_package(CUDA REQUIRED)
+include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
+# Some platforms (e.g. Visual Studio) don't add the CUDA include directories to the system include
+# paths by default, so we add it explicitly here.
+
find_package(Doxygen QUIET)
###################################################################################################
@@ -61,9 +66,18 @@ find_package(Doxygen QUIET)
#
###################################################################################################
-find_library(CUBLAS_LIBRARY cublas HINTS
+#
+# Conditionally enable cuBLAS
+#
+set(CUTLASS_ENABLE_CUBLAS ON CACHE BOOL "Enable CUTLASS Tests to build with cuBLAS library.")
+
+if(CUTLASS_ENABLE_CUBLAS)
+
+ find_library(CUBLAS_LIBRARY cublas HINTS
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
+endif()
+
# By default we want to build in Release mode to ensure that we're getting best performance
if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
@@ -78,26 +92,56 @@ if(WIN32)
endif()
if (WIN32)
- # Enable more warnings and treat as errors
- string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
+ # Enable more warnings and treat as errors
+ string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
- # Disable warning on Unicode characters
- string(APPEND NVCC_FLAGS " -Xcompiler /wd4819")
+ # Disable warning on Unicode characters
+ string(APPEND NVCC_FLAGS " -Xcompiler /wd4819")
- # Disable excess x86 floating point precision that can lead to results being labeled incorrectly
- string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
+ # Disable excess x86 floating point precision that can lead to results being labeled incorrectly
+ string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
- # Verbose option
- if (${CUTLASS_NVCC_VERBOSE})
- string(APPEND NVCC_FLAGS " -v")
- endif()
+ # Verbose option
+ if (${CUTLASS_NVCC_VERBOSE})
+ string(APPEND NVCC_FLAGS " -v")
+ endif()
endif(WIN32)
-set(CUTLASS_NVCC_ARCHS "50;60;61;70;75" CACHE STRING "The SM architectures to build code for.")
+set(CUTLASS_NVCC_ARCHS_DEFAULT "")
+if(NOT CUDA_VERSION VERSION_LESS 7.5)
+ list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 50)
+endif()
+if(NOT CUDA_VERSION VERSION_LESS 8.0)
+ list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 60 61)
+endif()
+if(NOT CUDA_VERSION VERSION_LESS 9.0)
+ list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 70)
+endif()
+if(NOT CUDA_VERSION VERSION_LESS 9.2)
+ list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 72)
+endif()
+if(NOT CUDA_VERSION VERSION_LESS 10.0)
+ list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 75)
+endif()
+set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_DEFAULT} CACHE STRING "The SM architectures to build code for.")
+
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
+# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
+if (CUDA_VERSION VERSION_LESS 10.1)
+ set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT OFF)
+else()
+ set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON)
+endif()
+
+set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
+ "Enable PTX mma instruction for collective matrix multiply operations.")
+
+set(CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST ${CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST} CACHE BOOL
+ "Enable more kernels instantiated in the perf suite. This might result in longer compiler time. ")
+
#
# NOTE: running with asan and CUDA requires the following environment variable:
#
@@ -131,6 +175,18 @@ foreach(ARCH ${CUTLASS_NVCC_ARCHS})
endif()
endforeach()
+if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
+ string(APPEND NVCC_FLAGS " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1")
+endif()
+
+if (CUTLASS_ENABLE_CUBLAS)
+ string(APPEND NVCC_FLAGS " -DCUTLASS_ENABLE_CUBLAS=1")
+endif()
+
+if (CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST)
+ add_definitions(-DEXHAUSTIVE_PROF)
+endif()
+
if (CUTLASS_NVCC_KEEP)
string(APPEND NVCC_FLAGS " -keep")
endif()
@@ -174,6 +230,7 @@ file(GLOB CUTLASS_UTIL RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h)
file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h)
file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
file(GLOB CUTLASS_REDUCTION RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/reduction/*.h )
+file(GLOB CUTLASS_LAYOUT_THREAD RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/layout/thread/*.h)
###################################################################################################
#
@@ -185,16 +242,24 @@ source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM})
source_group("cutlass\\util" FILES ${CUTLASS_UTIL})
source_group("cutlass\\device" FILES ${CUTLASS_DEVICE})
source_group("cutlass\\reduction" FILES ${CUTLASS_REDUCTION})
+source_group("cutlass\\layout\\thread" FILES ${CUTLASS_LAYOUT_THREAD})
source_group("cutlass" FILES ${CUTLASS_CORE})
add_library(CUTLASS INTERFACE)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
+
+# Special policy introduced in CMake 3.13
+if (POLICY CMP0076)
+ cmake_policy(SET CMP0076 NEW)
+endif()
+
target_sources(CUTLASS INTERFACE
${CUTLASS_GEMM}
${CUTLASS_UTIL}
${CUTLASS_DEVICE}
${CUTLASS_CORE}
${CUTLASS_REDUCTION}
+ ${CUTLASS_LAYOUT_THREAD}
)
target_include_directories(CUTLASS INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
@@ -206,6 +271,7 @@ add_custom_target(cutlass_ide SOURCES
${CUTLASS_DEVICE}
${CUTLASS_CORE}
${CUTLASS_REDUCTION}
+ ${CUTLASS_LAYOUT_THREAD}
)
# Doxygen is available. Generate documentation
if (DOXYGEN_FOUND)
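
The `CUTLASS_ENABLE_TENSOR_CORE_MMA` option added above is forwarded to NVCC as a preprocessor define (on the command line this corresponds to configuring with `-DCUTLASS_ENABLE_TENSOR_CORE_MMA=ON` and a suitable `-DCUTLASS_NVCC_ARCHS` value). A minimal sketch of how device code can key off that define together with an architecture check, mirroring the guards used in `cutlass/arch/mma.h` later in this patch; the function name is illustrative, not part of CUTLASS:

```
// Hedged sketch: reports whether the mma.sync path is compiled in for the current target.
// CUTLASS_ENABLE_TENSOR_CORE_MMA comes from the CMake option above; the architecture bound
// mirrors the Volta/Turing guard used by cutlass/arch/mma.h.
__device__ bool tensor_core_mma_enabled() {
#if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) && CUTLASS_ENABLE_TENSOR_CORE_MMA && \
    defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
  return true;   // mma.sync wrappers are available
#else
  return false;  // fall back to FMA- or WMMA-based paths
#endif
}
```
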
diff --git a/CUTLASS.md b/CUTLASS.md
index eb9e25e5f6..b6553db4e9 100644
--- a/CUTLASS.md
+++ b/CUTLASS.md
@@ -14,7 +14,7 @@ CUTLASS core components, and to identify their role in implementing GEMM computa
# 1. Design Patterns
CUTLASS strives to achieve the highest performance possible on NVIDIA GPUs while also offering a
-flexible composition that an be easily applied to solve new problems related to Deep Learning and
+flexible composition that can be easily applied to solve new problems related to Deep Learning and
linear algebra. Though we intend to make CUTLASS as simple and straightforward as possible, given
a tradeoff between simplicity and performance, CUTLASS chooses performance. Consequently, several
design patterns are necessary to yield a composable structure while also satisfying these performance
@@ -31,7 +31,7 @@ CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvla
## Tiles and Iterators
-Efficient dense linear algebra computations emphasize data movement to match the execution of mathemtical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) traverse a sequence of objects embedded in a higher level data structure. These subpartitions are typically defined by compile-time constants
+Efficient dense linear algebra computations emphasize data movement to match the execution of mathematical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) traverse a sequence of objects embedded in a higher level data structure. These subpartitions are typically defined by compile-time constants
specifying element type, size, and data layout. CUTLASS refers to subpartitions as _tiles_.
_Iterators_ are familiar design patterns in C++ that provide an abstraction for accessing individual
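
To make the tile/iterator vocabulary above concrete, here is a hypothetical sketch (not CUTLASS's actual classes) of a compile-time-sized tile and an iterator that loads it and advances to the next tile; CUTLASS's real tile iterators are considerably richer (predication, layouts, vectorized accesses):

```
// Hypothetical illustration of the tile/iterator pattern; names are not CUTLASS APIs.
template <typename T, int kRows, int kColumns>
struct Tile {
  T data[kRows * kColumns];
};

template <typename T, int kRows, int kColumns>
struct TileIterator {
  T const *pointer;        // current position within the larger matrix
  int leading_dimension;   // stride (in elements) between matrix rows

  // Access: load the tile at the current position into registers.
  __device__ void load(Tile<T, kRows, kColumns> &tile) const {
    for (int r = 0; r < kRows; ++r) {
      for (int c = 0; c < kColumns; ++c) {
        tile.data[r * kColumns + c] = pointer[r * leading_dimension + c];
      }
    }
  }

  // Traversal: advance to the next tile along the row dimension.
  __device__ TileIterator &operator++() {
    pointer += kColumns;
    return *this;
  }
};
```
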
@@ -353,7 +353,7 @@ An example of splitK usage can be found [here](examples/06_splitK_gemm/splitK_ge
# Copyright
-Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
```
Redistribution and use in source and binary forms, with or without modification, are permitted
diff --git a/README.md b/README.md
index c612b0d2f4..231eafbab3 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
-# CUTLASS 1.2
+# CUTLASS 1.3
-_CUTLASS 1.2 - October 2018_
+_CUTLASS 1.3.0 - March 2019_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
@@ -20,13 +20,18 @@ multiply-accumulate abstractions for 8-bit integer, half-precision floating
point (FP16), single-precision floating point (FP32), and double-precision floating
point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targeting
the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
-and beyond.
+and beyond. Even faster performance on Volta is possible via direct access to
+Volta Tensor Cores via `mma.sync` (added in CUDA 10.1).
-CUTLASS 1.2 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
+CUTLASS 1.3 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).
We describe the structure of an efficient GEMM in our talk at the
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
+# What's New in CUTLASS 1.3
+_March 2019_
+* CUTLASS 1.3 includes an efficient GEMM implementation with the `mma.sync` instruction added in CUDA 10.1.
+
# What's New in CUTLASS 1.2
_October 2018_
* [Parallelized Reductions](CUTLASS.md#parallel-reductions-across-gemm-k)
@@ -63,8 +68,8 @@ when compiled with CUDA 10.0.
# Compatibility
-CUTLASS performs best when compiled with the [CUDA 10.0 Toolkit](ttps://developer.nvidia.com/cuda-toolkit).
-It is compatible with CUDA 9.0, 9.1, and 9.2, but these versions of the CUDA Toolkit do not support new Turing WMMA features.
+CUTLASS performs best when compiled with the [CUDA 10.1 Toolkit](https://developer.nvidia.com/cuda-toolkit).
+It is also compatible with CUDA 9.0, 9.1, 9.2, and 10.0.
We have tested the following environments.
@@ -77,7 +82,7 @@ We have tested the following environments.
| Ubuntu 18.04 | GCC 7.3.0 |
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
-any Maxwell-, Pascal-, or Volta-architecture NVIDIA GPU.
+any Maxwell-, Pascal-, Volta-, or Turing-architecture NVIDIA GPU.
|**GPU**|
|---|
@@ -220,6 +225,9 @@ Program usage:
# Varies GEMM K dimension for SGEMM and IGEMM with column-major multiplicands
$ ./tools/test/perf/cutlass_perf_test --m=10240 --n=4096 --k=1024:8192:128 --kernels=sgemm_nn,igemm_nn
+
+ # Executes GEMM kernel on Volta Tensor Cores
+ $ ./tools/test/perf/cutlass_perf_test --kernels=s884gemm_nt
```
# About
@@ -230,7 +238,7 @@ CUTLASS is released by NVIDIA Corporation as Open Source software under the
# Copyright
-Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
```
Redistribution and use in source and binary forms, with or without modification, are permitted
@@ -253,4 +261,3 @@ Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```
-
diff --git a/cutlass/arch/mma.h b/cutlass/arch/mma.h
new file mode 100644
index 0000000000..1b01df4663
--- /dev/null
+++ b/cutlass/arch/mma.h
@@ -0,0 +1,380 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Templates wrapping direct issue of MMA instructions to Tensor Cores.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/shape.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace arch {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Specifies internal data type for computation
+struct ComputeType {
+ enum Kind {
+ kBegin,
+ kDefault, /// Compute type implied by operand and accumulator types
+ kEnd
+ };
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Direct wrapper for native MMA instruction
+template <
+ /// Warp-level matrix multiply-accumulate operation
+ typename WmmaTile,
+ /// Layout of A multiplicand
+ MatrixLayout::Kind LayoutA,
+ /// Data type of A multiplicand
+ typename ScalarA,
+ /// Layout of B multiplicand
+ MatrixLayout::Kind LayoutB,
+ /// Data type of B multiplicand
+ typename ScalarB,
+ /// Data type of accumulators
+ typename ScalarC,
+ /// Specifies particular compute type, overriding data types of operands
+ ComputeType::Kind ComputeTy>
+inline __device__ void mma(ScalarA const A[], ScalarB const B[], ScalarC const C[], ScalarC D[]);
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+//
+// 16x16x4
+//
+
+//
+// FP16 accumulation
+//
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+ MatrixLayout::kRowMajor,
+ half,
+ MatrixLayout::kColumnMajor,
+ half,
+ half,
+ ComputeType::kDefault>(half const a[],
+ half const b[],
+ half const c[],
+ half d[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+ unsigned const *A = reinterpret_cast<unsigned const *>(a);
+ unsigned const *B = reinterpret_cast<unsigned const *>(b);
+ unsigned const *C = reinterpret_cast<unsigned const *>(c);
+ unsigned *D = reinterpret_cast<unsigned *>(d);
+
+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+ CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+ MatrixLayout::kColumnMajor,
+ half,
+ MatrixLayout::kColumnMajor,
+ half,
+ half,
+ ComputeType::kDefault>(half const a[],
+ half const b[],
+ half const c[],
+ half d[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+ unsigned const *A = reinterpret_cast<unsigned const *>(a);
+ unsigned const *B = reinterpret_cast<unsigned const *>(b);
+ unsigned const *C = reinterpret_cast<unsigned const *>(c);
+ unsigned *D = reinterpret_cast<unsigned *>(d);
+
+ asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+ CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+ MatrixLayout::kRowMajor,
+ half,
+ MatrixLayout::kRowMajor,
+ half,
+ half,
+ ComputeType::kDefault>(half const a[],
+ half const b[],
+ half const c[],
+ half d[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+ unsigned const *A = reinterpret_cast<unsigned const *>(a);
+ unsigned const *B = reinterpret_cast<unsigned const *>(b);
+ unsigned const *C = reinterpret_cast<unsigned const *>(c);
+ unsigned *D = reinterpret_cast<unsigned *>(d);
+
+ asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+ CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+ MatrixLayout::kColumnMajor,
+ half,
+ MatrixLayout::kRowMajor,
+ half,
+ half,
+ ComputeType::kDefault>(half const a[],
+ half const b[],
+ half const c[],
+ half d[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+ unsigned const *A = reinterpret_cast<unsigned const *>(a);
+ unsigned const *B = reinterpret_cast<unsigned const *>(b);
+ unsigned const *C = reinterpret_cast<unsigned const *>(c);
+ unsigned *D = reinterpret_cast<unsigned *>(d);
+
+ asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+ CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+//
+// FP32 accumulation
+//
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+ MatrixLayout::kRowMajor,
+ half,
+ MatrixLayout::kColumnMajor,
+ half,
+ float,
+ ComputeType::kDefault>(half const a[],
+ half const b[],
+ float const C[],
+ float D[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+ unsigned const *A = reinterpret_cast<unsigned const *>(a);
+ unsigned const *B = reinterpret_cast<unsigned const *>(b);
+
+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
+ "{%12,%13,%14,%15,%16,%17,%18,%19};"
+ : "=f"(D[0]),
+ "=f"(D[1]),
+ "=f"(D[2]),
+ "=f"(D[3]),
+ "=f"(D[4]),
+ "=f"(D[5]),
+ "=f"(D[6]),
+ "=f"(D[7])
+ : "r"(A[0]),
+ "r"(A[1]),
+ "r"(B[0]),
+ "r"(B[1]),
+ "f"(C[0]),
+ "f"(C[1]),
+ "f"(C[2]),
+ "f"(C[3]),
+ "f"(C[4]),
+ "f"(C[5]),
+ "f"(C[6]),
+ "f"(C[7]));
+
+#else
+ CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+ MatrixLayout::kColumnMajor,
+ half,
+ MatrixLayout::kColumnMajor,
+ half,
+ float,
+ ComputeType::kDefault>(half const a[],
+ half const b[],
+ float const C[],
+ float D[]) {
+
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+ unsigned const *A = reinterpret_cast<unsigned const *>(a);
+ unsigned const *B = reinterpret_cast<unsigned const *>(b);
+
+ asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
+ "{%12,%13,%14,%15,%16,%17,%18,%19};"
+ : "=f"(D[0]),
+ "=f"(D[1]),
+ "=f"(D[2]),
+ "=f"(D[3]),
+ "=f"(D[4]),
+ "=f"(D[5]),
+ "=f"(D[6]),
+ "=f"(D[7])
+ : "r"(A[0]),
+ "r"(A[1]),
+ "r"(B[0]),
+ "r"(B[1]),
+ "f"(C[0]),
+ "f"(C[1]),
+ "f"(C[2]),
+ "f"(C[3]),
+ "f"(C[4]),
+ "f"(C[5]),
+ "f"(C[6]),
+ "f"(C[7]));
+
+#else
+ CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+ MatrixLayout::kRowMajor,
+ half,
+ MatrixLayout::kRowMajor,
+ half,
+ float,
+ ComputeType::kDefault>(half const a[],
+ half const b[],
+ float const C[],
+ float D[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+ unsigned const *A = reinterpret_cast<unsigned const *>(a);
+ unsigned const *B = reinterpret_cast<unsigned const *>(b);
+
+ asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
+ "{%12,%13,%14,%15,%16,%17,%18,%19};"
+ : "=f"(D[0]),
+ "=f"(D[1]),
+ "=f"(D[2]),
+ "=f"(D[3]),
+ "=f"(D[4]),
+ "=f"(D[5]),
+ "=f"(D[6]),
+ "=f"(D[7])
+ : "r"(A[0]),
+ "r"(A[1]),
+ "r"(B[0]),
+ "r"(B[1]),
+ "f"(C[0]),
+ "f"(C[1]),
+ "f"(C[2]),
+ "f"(C[3]),
+ "f"(C[4]),
+ "f"(C[5]),
+ "f"(C[6]),
+ "f"(C[7]));
+
+#else
+ CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+ MatrixLayout::kColumnMajor,
+ half,
+ MatrixLayout::kRowMajor,
+ half,
+ float,
+ ComputeType::kDefault>(half const a[],
+ half const b[],
+ float const C[],
+ float D[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+ unsigned const *A = reinterpret_cast<unsigned const *>(a);
+ unsigned const *B = reinterpret_cast<unsigned const *>(b);
+
+ asm volatile ("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
+ "{%12,%13,%14,%15,%16,%17,%18,%19};"
+ : "=f"(D[0]),
+ "=f"(D[1]),
+ "=f"(D[2]),
+ "=f"(D[3]),
+ "=f"(D[4]),
+ "=f"(D[5]),
+ "=f"(D[6]),
+ "=f"(D[7])
+ : "r"(A[0]),
+ "r"(A[1]),
+ "r"(B[0]),
+ "r"(B[1]),
+ "f"(C[0]),
+ "f"(C[1]),
+ "f"(C[2]),
+ "f"(C[3]),
+ "f"(C[4]),
+ "f"(C[5]),
+ "f"(C[6]),
+ "f"(C[7]));
+
+#else
+ CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+} // namespace arch
+} // namespace cutlass
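
For orientation, a hedged sketch of how a warp-level routine might invoke one of the FP32-accumulation wrappers above. It assumes the 16-by-16-by-4 warp tile is spelled `Shape<4, 16, 16>` as in the specializations, and that each thread already holds its operands in the per-thread register layout that `mma.sync.aligned.m8n8k4` expects (CUTLASS's tile iterators normally arrange this). The function name `warp_mma_f32` is illustrative only.

```
#include <cuda_fp16.h>
#include "cutlass/arch/mma.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/shape.h"

// Illustrative only: per thread, 4 halves of A, 4 halves of B, and 8 float
// accumulators, distributed across the warp as mma.sync.m8n8k4 requires.
__device__ void warp_mma_f32(half const a[4], half const b[4], float accum[8]) {
  cutlass::arch::mma<cutlass::Shape<4, 16, 16>,
                     cutlass::MatrixLayout::kRowMajor,
                     half,
                     cutlass::MatrixLayout::kColumnMajor,
                     half,
                     float,
                     cutlass::arch::ComputeType::kDefault>(a, b, accum, accum);
}
```
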
diff --git a/cutlass/convert.h b/cutlass/convert.h
index b4d0f8eddb..23fa5d560f 100644
--- a/cutlass/convert.h
+++ b/cutlass/convert.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/coord.h b/cutlass/coord.h
index e90af8a1b8..7e91d6e99d 100644
--- a/cutlass/coord.h
+++ b/cutlass/coord.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/core_io.h b/cutlass/core_io.h
index 849a7613f4..edfc1ec803 100644
--- a/cutlass/core_io.h
+++ b/cutlass/core_io.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/cutlass.h b/cutlass/cutlass.h
index ac5420d724..26de6c0278 100644
--- a/cutlass/cutlass.h
+++ b/cutlass/cutlass.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -29,11 +29,12 @@
#pragma once
+
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CUTLASS_MAJOR 1
-#define CUTLASS_MINOR 2
-#define CUTLASS_PATCH 1
+#define CUTLASS_MINOR 3
+#define CUTLASS_PATCH 0
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
#ifdef __NVCC__
@@ -47,9 +48,31 @@
// CUTLASS_DEVICE is an error if not compiling device code
#endif
+// CUDA 10.1 introduces the mma instruction
+#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA)
+#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0
+#endif
+
+// CUTLASS assert
#define CUTLASS_ASSERT(x) assert(x)
-#include "cutlass/util/performance_tuning.h"
+// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler.
+#if defined(__CUDA_ARCH__)
+ #define CUTLASS_PRAGMA_UNROLL #pragma unroll
+ #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1
+
+ #define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL
+
+ #define CUTLASS_GEMM_LOOP_HEADER \
+ asm volatile (".pragma \"nounroll\";\n");
+#else
+
+ #define CUTLASS_PRAGMA_UNROLL
+ #define CUTLASS_PRAGMA_NO_UNROLL
+ #define CUTLASS_GEMM_LOOP_HEADER
+ #define CUTLASS_GEMM_LOOP
+
+#endif
// A small helper class to dump a type at compile time
// Usage:: DumpType<Class>::Class
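
A short usage sketch of the unrolling macros introduced above: they are placed directly before loops with compile-time trip counts so the CUDA compiler fully unrolls them in device code, while host compilation sees no pragma at all. The function below is illustrative, not part of CUTLASS.

```
#include "cutlass/cutlass.h"

// Illustrative only: sum a fixed-size array held in registers.
template <int kN>
__device__ __forceinline__ float sum_registers(float const (&x)[kN]) {
  float result = 0;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < kN; ++i) {
    result += x[i];
  }
  return result;
}
```
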
diff --git a/cutlass/fragment.h b/cutlass/fragment.h
index 6a93d779c4..e048c525b1 100644
--- a/cutlass/fragment.h
+++ b/cutlass/fragment.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -98,9 +98,9 @@ struct StorageType<1> {
template <typename Element_, int kElements_, size_t kAlignment_ = 16>
struct Fragment : public AlignedStruct<kAlignment_> {
/// Make sure the alignment makes sense wrt the size of elements.
- static_assert(kAlignment_ == 16 || kAlignment_ >= sizeof(Element_), "Alignment is too small");
+ static_assert(int(kAlignment_) == 16 || int(kAlignment_) >= sizeof(Element_), "Alignment is too small");
/// Alignment must be a power of two
- static_assert(is_pow2<kAlignment_>::value, "Alignment must be a power of two");
+ static_assert(is_pow2<int(kAlignment_)>::value, "Alignment must be a power of two");
/// This class.
typedef Fragment This_;
@@ -109,27 +109,31 @@ struct Fragment : public AlignedStruct {
/// The number of elements.
static int const kElements = kElements_;
/// Alignment
- static int const kAlignment = kAlignment_;
+ static int const kAlignment = int(kAlignment_);
/// Clear a fragment.
CUTLASS_HOST_DEVICE void clear() {
// Avoid element-wise access for sub 32b element type
if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
+ CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) {
ptr[i] = uint64_t(0);
}
} else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) {
uint32_t* ptr = reinterpret_cast<uint32_t*>(storage);
+ CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) {
ptr[i] = uint32_t(0);
}
} else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) {
uint16_t* ptr = reinterpret_cast<uint16_t*>(storage);
+ CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) {
ptr[i] = uint16_t(0);
}
} else {
+ CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kElements; ++i) {
storage[i] = 0;
}
@@ -146,7 +150,7 @@ struct Fragment : public AlignedStruct {
private:
/// Storage type to use for Elements
- typedef typename StorageType<kAlignment_>::Type StorageType;
+ typedef typename StorageType<int(kAlignment_)>::Type StorageType;
/// Number of elements in the storage
static int const kStorageCount =
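
For context, a hedged usage sketch of the fragment type whose vectorized `clear()` path is shown above. The element count is illustrative, and it assumes the default 16-byte alignment, in which case 8 half-precision elements (16 bytes) are zeroed with 64-bit stores.

```
#include <cuda_fp16.h>
#include "cutlass/fragment.h"

// Illustrative only: a per-thread fragment of 8 halves. With the assumed default
// 16-byte alignment, clear() takes the uint64_t path because 8 * sizeof(half) is
// divisible by 8.
__device__ void zero_fragment() {
  cutlass::Fragment<half, 8> fragment;
  fragment.clear();
}
```
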
diff --git a/cutlass/fragment_multiply_add.h b/cutlass/fragment_multiply_add.h
index 8bcf81209a..f29b08e7ca 100644
--- a/cutlass/fragment_multiply_add.h
+++ b/cutlass/fragment_multiply_add.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -68,7 +68,6 @@ struct FragmentMultiplyAdd {
FragmentB_ const& b,
FragmentCd_ const& c,
FragmentCd_& d) {
-
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
for (int j = 0; j < FragmentCd_::kElements; ++j) {
d[j] = b[j * kReduction + 0];
diff --git a/cutlass/gemm/clear_accumulators.h b/cutlass/gemm/clear_accumulators.h
index 3a2f337525..9e336cb6f3 100644
--- a/cutlass/gemm/clear_accumulators.h
+++ b/cutlass/gemm/clear_accumulators.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/device_gemm.h b/cutlass/gemm/device_gemm.h
index aaf4bfe783..1380f90efc 100644
--- a/cutlass/gemm/device_gemm.h
+++ b/cutlass/gemm/device_gemm.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
-* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -33,6 +33,8 @@
#include "cutlass/coord.h"
#include "cutlass/util/platform.h"
+#include "cutlass/gemm/gemm.h"
+
namespace cutlass {
namespace gemm {
@@ -47,7 +49,8 @@ struct DeviceGemm {
#if !defined(__CUDACC_RTC__)
/// Launch the kernels in order
static __host__ cudaError_t launch(Params const& params) {
- Traits::GemmTraits::KernelClass::launch(params.GemmParams);
+ //Traits::GemmTraits::KernelClass::launch(params.GemmParams);
+ Gemm::launch(params.GemmParams);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
return err;
diff --git a/cutlass/gemm/device_gemm_traits.h b/cutlass/gemm/device_gemm_traits.h
index fbcfef3e1f..1ff2a11cc9 100644
--- a/cutlass/gemm/device_gemm_traits.h
+++ b/cutlass/gemm/device_gemm_traits.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
-* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -73,7 +73,7 @@ struct SplitkPIGemmTraits {
/// The pointer to workspace memory
ScalarAccum *workspace_ptr;
///
- int workspace_size;
+ size_t workspace_size;
/// The Params for the first kernel
typename GemmTraits::Params GemmParams;
/// The Params for the second kernel
@@ -112,7 +112,8 @@ struct SplitkPIGemmTraits {
Index ldc_,
ScalarD* d_d_,
Index ldd_,
- ScalarAccum *workspace_ptr_) {
+ ScalarAccum *workspace_ptr_,
+ Index partitionK_multiple = 1) {
workspace_ptr = workspace_ptr_;
@@ -133,7 +134,7 @@ struct SplitkPIGemmTraits {
TensorRef(workspace_ptr, problem_size.m()), /*m = ldc, workspace is not transposed and is packed*/
TensorRef(workspace_ptr, problem_size.m()) /*m = ldd, workspace is not transposed and is packed*/
);
- GemmParams.initialize(desc, ReductionTraits::ReductionSize);
+ GemmParams.initialize(desc, ReductionTraits::ReductionSize, partitionK_multiple);
//call batched reduction (second kernel) param
@@ -155,9 +156,12 @@ struct SplitkPIGemmTraits {
// workspace will be used to store D (output) from the first gemm kernel (not D of the entire gemm)
// note typedef typename GemmTraits::ScalarD ScalarAccum;
// workspace of size of M * N * Reduction
- int required_workspace_memory_in_byte(){
+ size_t required_workspace_memory_in_byte(){
assert(problem_size_initialized == true);
- workspace_size = problem_size.n() * problem_size.m() * ReductionTraits::ReductionSize * static_cast<int>(sizeof(ScalarAccum));
+ workspace_size = static_cast<size_t>(problem_size.n()) *
+ static_cast<size_t>(problem_size.m()) *
+ static_cast<size_t>(ReductionTraits::ReductionSize) *
+ sizeof(ScalarAccum);
return workspace_size;
}
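
The change from `int` to `size_t` above matters because the workspace product can overflow 32 bits at realistic sizes. A hedged, self-contained illustration of the same arithmetic with made-up problem dimensions:

```
#include <cstddef>
#include <cstdio>

// Illustrative numbers only: a 16384 x 16384 split-K GEMM with 32 partitions and
// float accumulators needs 16384 * 16384 * 32 * 4 bytes = 32 GiB of workspace,
// which overflows a 32-bit int but is representable in size_t on 64-bit platforms.
int main() {
  size_t m = 16384, n = 16384, reduction_size = 32;
  size_t workspace_bytes = m * n * reduction_size * sizeof(float);
  std::printf("workspace bytes: %zu\n", workspace_bytes);
  return 0;
}
```
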
diff --git a/cutlass/gemm/dgemm_traits.h b/cutlass/gemm/dgemm_traits.h
index 5c05590207..4005bf9647 100644
--- a/cutlass/gemm/dgemm_traits.h
+++ b/cutlass/gemm/dgemm_traits.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/fp16_sgemm_multiply_add.h b/cutlass/gemm/fp16_sgemm_multiply_add.h
index 534b8c8998..b45ae8cfbb 100644
--- a/cutlass/gemm/fp16_sgemm_multiply_add.h
+++ b/cutlass/gemm/fp16_sgemm_multiply_add.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -29,6 +29,7 @@
#include "cutlass/fragment.h"
#include "cutlass/gemm/thread_multiply_add.h"
+
namespace cutlass {
namespace gemm {
@@ -69,8 +70,10 @@ struct ThreadMultiplyAdd {
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
+
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
+
d[j * AccumulatorsPerThread::kW + i] = static_cast(a[i]) * static_cast(b[j]) + c[j * AccumulatorsPerThread::kW + i];
}
}
diff --git a/cutlass/gemm/fp16_sgemm_traits.h b/cutlass/gemm/fp16_sgemm_traits.h
index 361186455b..6ad6d96385 100644
--- a/cutlass/gemm/fp16_sgemm_traits.h
+++ b/cutlass/gemm/fp16_sgemm_traits.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/gemm.h b/cutlass/gemm/gemm.h
index 3aec792866..0d91919931 100644
--- a/cutlass/gemm/gemm.h
+++ b/cutlass/gemm/gemm.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -33,20 +33,30 @@
#include "cutlass/coord.h"
#include "cutlass/util/platform.h"
+#include
namespace cutlass {
namespace gemm {
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
////////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel with launch bounds specified
template <typename Gemm_>
__global__ __launch_bounds__(Gemm_::kThreads)
void gemm_kernel(typename Gemm_::Params params) {
- // Declare shared memory.
- __shared__ typename Gemm_::SharedStorage shared_storage;
+
+ // Dynamic shared memory base pointer
+ extern __shared__ int GemmSharedStorageBase[];
+
+ // Declare pointer to dynamic shared memory.
+ typename Gemm_::SharedStorage *shared_storage =
+ reinterpret_cast<typename Gemm_::SharedStorage *>(GemmSharedStorageBase);
// Construct the GEMM object.
- Gemm_ gemm(params, shared_storage);
+ Gemm_ gemm(params, *shared_storage);
+
// Run GEMM.
gemm.multiply_add();
}
@@ -57,11 +67,17 @@ void gemm_kernel(typename Gemm_::Params params) {
template <typename Gemm_>
__global__ /* __launch_bounds__(Gemm_::kThreads) */
void gemm_kernel_nolb(typename Gemm_::Params params) {
- // Declare shared memory.
- __shared__ typename Gemm_::SharedStorage shared_storage;
+
+ // Dynamic shared memory base pointer
+ extern __shared__ int GemmSharedStorageBase[];
+
+ // Declare pointer to dynamic shared memory.
+ typename Gemm_::SharedStorage *shared_storage =
+ reinterpret_cast<typename Gemm_::SharedStorage *>(GemmSharedStorageBase);
// Construct the GEMM object.
- Gemm_ gemm(params, shared_storage);
+ Gemm_ gemm(params, *shared_storage);
+
// Run GEMM.
gemm.multiply_add();
}
@@ -72,7 +88,31 @@ void gemm_kernel_nolb(typename Gemm_::Params params) {
template <typename Gemm, bool WithLaunchBounds>
struct Launch {
Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
- gemm_kernel<Gemm><<< grid, block, 0, stream >>>(params);
+
+ int smem_size = int(sizeof(typename Gemm::SharedStorage));
+ if (smem_size >= (48 << 10)) {
+
+ cudaError_t result = cudaFuncSetAttribute(
+ gemm_kernel<Gemm>,
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
+ smem_size
+ );
+
+ if (result != cudaSuccess) {
+ return;
+ }
+
+ result = cudaFuncSetAttribute(
+ gemm_kernel_nolb<Gemm>,
+ cudaFuncAttributePreferredSharedMemoryCarveout,
+ 100);
+
+ if (result != cudaSuccess) {
+ return;
+ }
+ }
+
+ gemm_kernel<Gemm><<< grid, block, sizeof(typename Gemm::SharedStorage), stream >>>(params);
}
};
@@ -82,50 +122,51 @@ struct Launch {
template <typename Gemm>
struct Launch<Gemm, false> {
Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
- gemm_kernel_nolb<Gemm><<< grid, block, 0, stream >>>(params);
+ int smem_size = int(sizeof(typename Gemm::SharedStorage));
+ if (smem_size >= (48 << 10)) {
+
+ cudaError_t result = cudaFuncSetAttribute(
+ gemm_kernel_nolb<Gemm>,
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
+ smem_size
+ );
+
+ if (result != cudaSuccess) {
+ return;
+ }
+
+ result = cudaFuncSetAttribute(
+ gemm_kernel_nolb<Gemm>,
+ cudaFuncAttributePreferredSharedMemoryCarveout,
+ 100);
+
+ if (result != cudaSuccess) {
+ // throw exception?
+ return;
+ }
+ }
+
+ gemm_kernel_nolb<Gemm><<<
+ grid,
+ block,
+ smem_size,
+ stream >>>(params);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
-template <typename GemmTraits_>
+template <typename Traits_>
struct Gemm {
- /// This class.
- typedef Gemm<GemmTraits_> This_;
+
/// The traits.
- typedef GemmTraits_ Traits;
- /// The shared storage.
- typedef typename Traits::SharedStorage SharedStorage;
-
- /// The scalar for A.
- typedef typename Traits::ScalarA ScalarA;
- /// The scalar for B.
- typedef typename Traits::ScalarB ScalarB;
- /// The scalar in the epilogue.
- typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
- /// The scalar for C.
- typedef typename Traits::Epilogue::ScalarC ScalarC;
- /// The scalar for D.
- typedef typename Traits::Epilogue::ScalarD ScalarD;
- /// The index.
- typedef typename Traits::Index Index;
-
- /// Define the mainloop iteration size
- typedef typename Traits::MultiplyAdd MultiplyAdd;
-
- /// The number of threads.
- static int const kThreads = Traits::GemmConfig::kThreads;
-
- // Number of warp-level multiply-accumulate steps executed by each warp.
- static Index const kWarpGemmSteps =
- Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;
-
- // Make sure we have at least 2 unrolling steps or our pipeling is not going to work.
- static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
+ typedef Traits_ Traits;
/// Use the params object defined in traits
typedef typename Traits::Params Params;
+ typedef typename Traits::KernelClass KernelClass;
+
//
// Static function members
//
@@ -137,7 +178,7 @@ struct Gemm {
cudaStream_t stream = cudaStreamDefault) {
// Launch the kernel.
- Launch<This_, GemmTraits_::GemmConfig::kLaunchBounds>(
+ Launch<KernelClass, Traits::GemmConfig::kLaunchBounds>(
params, params.grid, params.block, stream);
return cudaGetLastError();
@@ -164,189 +205,6 @@ struct Gemm {
}
#endif
-
- //
- // Methods
- //
-
- /// Ctor.
- CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
- : params(params_), shared_storage(shared_storage_) {}
-
- /// Computes a warp-level GEMM on data held in shared memory
- template <bool Residue, bool LastIteration>
- CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
- typename Traits::SharedStream& shared_load_stream,
- typename MultiplyAdd::Accumulators& accumulators,
- Index outer_k) {
- // If residue portion and not calculating residue in prolog, update residue predicates now.
- if (Residue && outer_k <= Traits::OutputTile::kD) {
- global_to_shared_stream.residue(outer_k);
- }
-
- // Load data for the next iteration of the main loop (unless it's the last iteration).
- if (!LastIteration) {
- global_to_shared_stream.copy();
- }
-
- CUTLASS_PRAGMA_UNROLL
- for (int step = 0; step < kWarpGemmSteps - 1; ++step) {
- // Trigger the copy from shared memory for the next A/B values.
- shared_load_stream.copy(step + 1);
-
- // Make sure the values are available for the current iteration to do the multiply-add.
- shared_load_stream.commit(step);
-
- MultiplyAdd multiply_add;
-
- // Do the math on the fragments of the current iteration.
- multiply_add.multiply_add(shared_load_stream.fragment_a(step),
- shared_load_stream.fragment_b(step),
- accumulators,
- accumulators);
- }
-
- // Make sure the data from shared memory has been entirely consumed.
- Traits::shared_load_fence(true);
-
- // Commit the data in shared memory for A/B.
- if (!LastIteration) {
- global_to_shared_stream.commit();
- }
- // Make sure the data is in shared memory.
- Traits::shared_store_fence(true);
-
- if (!LastIteration) {
- // Move to the next stage for the load (if it makes sense).
- shared_load_stream.inc_stage();
- // Trigger the copy from shared memory for the next loop iteration.
- shared_load_stream.copy(0);
- }
- // Make sure the values are available for the current iteration to do the multiply-add.
- shared_load_stream.commit(kWarpGemmSteps - 1);
-
- // Do the math on the fragments of the current iteration.
- MultiplyAdd multiply_add;
- multiply_add.multiply_add(shared_load_stream.fragment_a(kWarpGemmSteps - 1),
- shared_load_stream.fragment_b(kWarpGemmSteps - 1),
- accumulators,
- accumulators);
- }
-
- /// Do the GEMM.
- CUTLASS_DEVICE void multiply_add() {
- // Swizzle the IDs of the block (to enable better cache behavior).
- typename Traits::BlockSwizzle block_swizzle;
- Coord<3> threadblock_offset =
- block_swizzle.get_threadblock_offset(make_Coord_from_shape<typename Traits::OutputTile>());
-
- // We may want to use shared memory to clear the registers.
- typedef typename Traits::ClearAccumulators ClearAccumulators;
-
- // Get the bounds for each thread, it maybe different than problem_size
- Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
- params.partitionK_range);
-
- // The streams to read A/B from global memory to shared memory.
- typename Traits::GlobalLoadStream global_to_shared_stream(
- params.global_to_shared_stream,
- shared_storage.main_loop.global_to_shared_stream,
- shared_storage.main_loop.threadblock_tile.reference(),
- bounds,
- threadblock_offset);
-
- // update A and B pointer offset based on batch_id and batch_stride_offset
- global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());
-
- // Create the accumulator clear.
- ClearAccumulators clear;
-
- // Deal with residue in prolog.
- // global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
- global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);
-
- // Fetch the fragments for A and B from global memory.
- global_to_shared_stream.copy();
-
- // Copy the elements to shared memory (after transformation if needed).
- global_to_shared_stream.commit();
-
- // Make sure the data is in shared memory.
- Traits::shared_store_fence(false);
-
- // Rollback to the beginning of the first tile (if residue exists).
- // global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
- global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);
-
- // The stream of data from shared memory to fragments.
- typename Traits::SharedStream shared_load_stream(
- params.shared_stream,
- shared_storage.main_loop.threadblock_tile.reference());
-
- // Trigger the copy from shared memory for the 1st stream.
- shared_load_stream.copy(0);
-
- // Allocate the accumulators.
- typename MultiplyAdd::Accumulators accumulators;
-
- // Clear the accumulators.
- clear.clear(accumulators);
-
- // Initial index
- // Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
- // problem_size[0] might be bigger than bounds[0]
- Index outer_k = bounds[0] - Traits::OutputTile::kD;
- // Check if we are computing residue in prolog or not.
- if (Traits::GemmConfig::kResidueInProlog) {
- // Execute all mainloop iterations but the last one.
-
- CUTLASS_GEMM_LOOP
- for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
- consume_tile(
- global_to_shared_stream, shared_load_stream, accumulators, outer_k);
- }
-
- // Don't load data for the last "residue" portion since we've already computed the residue.
- CUTLASS_GEMM_LOOP
- for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
- consume_tile(
- global_to_shared_stream, shared_load_stream, accumulators, outer_k);
- }
- } else {
- // When kResidueSeparate = true, execute all mainloop iterations but the last two without any
- // consideration for K-residue or predicate updates. This improves the steady state of some
- // kernels.
- if (Traits::GemmConfig::kResidueSeparate) {
-
- CUTLASS_GEMM_LOOP
- for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
- consume_tile(
- global_to_shared_stream, shared_load_stream, accumulators, outer_k);
- }
- }
-
- // Execute remaining tiles with K-residue predicate updates enabled.
- CUTLASS_GEMM_LOOP
- for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
- consume_tile(
- global_to_shared_stream, shared_load_stream, accumulators, outer_k);
- }
- }
-
- // Epilogue.
- typedef typename Traits::Epilogue Epilogue;
- Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
- epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
- }
-
- //
- // Data members
- //
-
- /// The params.
- Params const& params;
- /// The shared storage.
- SharedStorage& shared_storage;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
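
The launch changes above move the GEMM kernels from static to dynamic shared memory so that more than the default 48 KB can be requested on Volta and Turing via `cudaFuncSetAttribute`. A standalone hedged sketch of that pattern with a toy kernel (names and sizes are illustrative, not CUTLASS's):

```
#include <cuda_runtime.h>

// Toy kernel using dynamically sized shared memory, mirroring the
// extern __shared__ pattern adopted by gemm_kernel above.
__global__ void example_kernel(float *out) {
  extern __shared__ float shared_buffer[];
  shared_buffer[threadIdx.x] = float(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = shared_buffer[threadIdx.x];
}

cudaError_t launch_example(float *out, int smem_bytes, cudaStream_t stream) {
  // Dynamic shared memory requests above 48 KB must be enabled explicitly
  // before the launch, as the Launch<> functors above do.
  if (smem_bytes >= (48 << 10)) {
    cudaError_t result = cudaFuncSetAttribute(
        example_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    if (result != cudaSuccess) {
      return result;
    }
  }
  example_kernel<<<1, 256, smem_bytes, stream>>>(out);
  return cudaGetLastError();
}
```
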
diff --git a/cutlass/gemm/gemm_config.h b/cutlass/gemm/gemm_config.h
index 76df0add62..1153cc470e 100644
--- a/cutlass/gemm/gemm_config.h
+++ b/cutlass/gemm/gemm_config.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/gemm_coord.h b/cutlass/gemm/gemm_coord.h
index e029af3522..4d64eb96c8 100644
--- a/cutlass/gemm/gemm_coord.h
+++ b/cutlass/gemm/gemm_coord.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/gemm_desc.h b/cutlass/gemm/gemm_desc.h
index 80f4b36557..597e46d5c7 100644
--- a/cutlass/gemm/gemm_desc.h
+++ b/cutlass/gemm/gemm_desc.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/gemm_epilogue.h b/cutlass/gemm/gemm_epilogue.h
index 0e0cfc5374..086b50901a 100644
--- a/cutlass/gemm/gemm_epilogue.h
+++ b/cutlass/gemm/gemm_epilogue.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/gemm_epilogue_traits.h b/cutlass/gemm/gemm_epilogue_traits.h
index bffd5e516c..7f0f788e66 100644
--- a/cutlass/gemm/gemm_epilogue_traits.h
+++ b/cutlass/gemm/gemm_epilogue_traits.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/gemm_global_stream.h b/cutlass/gemm/gemm_global_stream.h
index 1ae2963c00..6a6d75b36d 100644
--- a/cutlass/gemm/gemm_global_stream.h
+++ b/cutlass/gemm/gemm_global_stream.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/gemm_global_tile.h b/cutlass/gemm/gemm_global_tile.h
index 5174ce67fe..154bc24dc0 100644
--- a/cutlass/gemm/gemm_global_tile.h
+++ b/cutlass/gemm/gemm_global_tile.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/gemm_mainloop.h b/cutlass/gemm/gemm_mainloop.h
new file mode 100644
index 0000000000..a65cb3ae21
--- /dev/null
+++ b/cutlass/gemm/gemm_mainloop.h
@@ -0,0 +1,274 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Implements a software-pipelined efficient GEMM.
+*/
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/coord.h"
+
+namespace cutlass {
+namespace gemm {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Traits_>
+struct GemmMainloop {
+
+ //
+ // Type definitions
+ //
+
+ /// The traits.
+ typedef Traits_ Traits;
+
+ /// The GEMM mainloop
+ typedef typename Traits::KernelClass KernelClass;
+
+ /// The shared storage.
+ typedef typename Traits::SharedStorage SharedStorage;
+
+ /// The scalar for A.
+ typedef typename Traits::ScalarA ScalarA;
+ /// The scalar for B.
+ typedef typename Traits::ScalarB ScalarB;
+ /// The scalar in the epilogue.
+ typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
+ /// The scalar for C.
+ typedef typename Traits::Epilogue::ScalarC ScalarC;
+ /// The scalar for D.
+ typedef typename Traits::Epilogue::ScalarD ScalarD;
+ /// The index.
+ typedef typename Traits::Index Index;
+
+ /// Define the mainloop iteration size
+ typedef typename Traits::MultiplyAdd MultiplyAdd;
+
+ /// The number of threads.
+ static int const kThreads = Traits::GemmConfig::kThreads;
+
+ // Number of warp-level multiply-accumulate steps executed by each warp.
+ static Index const kWarpGemmSteps =
+ Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;
+
+ /*
+ // Make sure we have at least 2 unrolling steps or our pipelining is not going to work.
+ static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
+ */
+
+ /// Use the params object defined in traits
+ typedef typename Traits::Params Params;
+
+ //
+ // Data members
+ //
+
+ /// The params.
+ Params const& params;
+
+ /// SharedStorage object
+ SharedStorage& shared_storage;
+
+ //
+ // Methods
+ //
+
+ /// Ctor.
+ CUTLASS_DEVICE GemmMainloop(Params const& params_, SharedStorage& shared_storage_)
+ : params(params_), shared_storage(shared_storage_) {}
+
+ /// Fetches global stream pair
+ template <bool Residue>
+ CUTLASS_DEVICE void fetch_global(typename Traits::GlobalLoadStream& global_to_shared_stream,
+ Index outer_k) {
+ // If residue portion and not calculating residue in prolog, update residue predicates now.
+ if (Residue) {
+ global_to_shared_stream.residue(outer_k);
+ }
+ global_to_shared_stream.copy();
+ }
+
+ /// Computes a warp-level GEMM on data held in shared memory
+ template <bool Residue, bool LastIteration>
+ CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
+ typename Traits::SharedStream& shared_load_stream,
+ typename MultiplyAdd::Accumulators& accumulators,
+ Index outer_k) {
+
+ // Whether to load global stream before loading shared stream
+ const bool kGlobalStreamFirst = (kWarpGemmSteps <= 4);
+
+ // Load data for the next iteration of the main loop (unless it's the last iteration).
+ if (kGlobalStreamFirst && !LastIteration) {
+ fetch_global<Residue>(global_to_shared_stream, outer_k);
+ }
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int step = 0; step < kWarpGemmSteps; ++step) {
+
+ // Trigger the copy from shared memory for the next A/B values.
+ shared_load_stream.copy((step + 1) % kWarpGemmSteps);
+
+ // Load data for the next iteration of the main loop (unless it's the last iteration).
+ if (!kGlobalStreamFirst && (step == 0) && !LastIteration) {
+ fetch_global<Residue>(global_to_shared_stream, outer_k);
+ }
+
+ if (step == kWarpGemmSteps - 2) {
+ // Make sure the data from shared memory has been entirely consumed.
+ Traits::shared_load_fence(true);
+
+ global_to_shared_stream.commit();
+
+ // Make sure the data is in shared memory.
+ Traits::shared_store_fence(true);
+
+ // Move to the next stage for the load (if it makes sense).
+ shared_load_stream.inc_stage();
+ }
+
+ // Make sure the values are available for the current iteration to do the multiply-add.
+ shared_load_stream.commit(step);
+
+ // Do the math on the fragments of the current iteration.
+ MultiplyAdd multiply_add;
+ multiply_add.multiply_add(shared_load_stream.fragment_a(step),
+ shared_load_stream.fragment_b(step),
+ accumulators,
+ accumulators);
+ }
+ }
+
+ /// Do the GEMM.
+ CUTLASS_DEVICE void multiply_add() {
+ // Swizzle the IDs of the block (to enable better cache behavior).
+ typename Traits::BlockSwizzle block_swizzle;
+ Coord<3> threadblock_offset =
+ block_swizzle.get_threadblock_offset(make_Coord_from_shape<typename Traits::OutputTile>());
+
+ // We may want to use shared memory to clear the registers.
+ typedef typename Traits::ClearAccumulators ClearAccumulators;
+
+ // Get the bounds for each thread; they may be different from problem_size.
+ Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
+ params.partitionK_range);
+
+ // The streams to read A/B from global memory to shared memory.
+ typename Traits::GlobalLoadStream global_to_shared_stream(
+ params.global_to_shared_stream,
+ shared_storage.main_loop.global_to_shared_stream,
+ shared_storage.main_loop.threadblock_tile.reference(),
+ bounds,
+ threadblock_offset);
+
+ // update A and B pointer offset based on batch_id and batch_stride_offset
+ global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());
+
+ // Create the accumulator clear.
+ ClearAccumulators clear;
+
+ // Deal with residue in prolog.
+ // global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
+ global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);
+
+ // Fetch the fragments for A and B from global memory.
+ global_to_shared_stream.copy();
+
+ // Copy the elements to shared memory (after transformation if needed).
+ global_to_shared_stream.commit();
+
+ // Make sure the data is in shared memory.
+ Traits::shared_store_fence(false);
+
+ // Rollback to the beginning of the first tile (if residue exists).
+ // global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
+ global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);
+
+ // The stream of data from shared memory to fragments.
+ typename Traits::SharedStream shared_load_stream(
+ params.shared_stream,
+ shared_storage.main_loop.threadblock_tile.reference());
+
+ // Trigger the copy from shared memory for the 1st stream.
+ shared_load_stream.copy(0);
+
+ // Allocate the accumulators.
+ typename MultiplyAdd::Accumulators accumulators;
+
+ // Clear the accumulators.
+ clear.clear(accumulators);
+
+ // Initial index
+ // Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
+ // problem_size[0] might be bigger than bounds[0]
+ Index outer_k = bounds[0] - Traits::OutputTile::kD;
+ // Check if we are computing residue in prolog or not.
+ if (Traits::GemmConfig::kResidueInProlog) {
+ // Execute all mainloop iterations but the last one.
+
+ CUTLASS_GEMM_LOOP
+ for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
+ CUTLASS_GEMM_LOOP_HEADER
+ consume_tile<false, false>(
+ global_to_shared_stream, shared_load_stream, accumulators, outer_k);
+ }
+
+ consume_tile<false, true>(
+ global_to_shared_stream, shared_load_stream, accumulators, outer_k);
+
+ } else {
+ // When kResidueSeparate = true, execute all mainloop iterations but the last two without any
+ // consideration for K-residue or predicate updates. This improves the steady state of some
+ // kernels.
+ if (Traits::GemmConfig::kResidueSeparate) {
+
+ CUTLASS_GEMM_LOOP
+ for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
+ CUTLASS_GEMM_LOOP_HEADER
+ consume_tile<false, false>(
+ global_to_shared_stream, shared_load_stream, accumulators, outer_k);
+ }
+ }
+
+ // Execute remaining tiles with K-residue predicate updates enabled.
+ CUTLASS_GEMM_LOOP
+ for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
+ CUTLASS_GEMM_LOOP_HEADER
+ consume_tile<true, false>(
+ global_to_shared_stream, shared_load_stream, accumulators, outer_k);
+ }
+ }
+
+ typedef typename Traits::Epilogue Epilogue;
+ Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
+ epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
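
The mainloop above is software-pipelined: the prologue stages the first threadblock tile into shared memory, and each later iteration prefetches tile k+1 from global memory while the warps consume tile k out of a two-stage shared-memory buffer. A minimal host-side sketch of that double-buffered structure, with illustrative names (load_tile, consume) rather than CUTLASS APIs:

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

// Stand-in for the warp-level multiply-accumulate over one staged K tile.
float consume(const std::vector<float>& tile) {
  return std::accumulate(tile.begin(), tile.end(), 0.0f);
}

int main() {
  const int K = 10, kTileK = 4;
  std::vector<float> global(K);
  std::iota(global.begin(), global.end(), 1.0f);   // 1, 2, ..., K

  std::vector<float> stage[2];                     // double-buffered "shared memory"
  auto load_tile = [&](int k0, int s) {            // models copy() + commit() of one tile
    stage[s].assign(global.begin() + k0, global.begin() + std::min(K, k0 + kTileK));
  };

  float acc = 0.0f;
  int s = 0;
  load_tile(0, s);                                 // prologue: fill stage 0
  for (int k0 = 0; k0 < K; k0 += kTileK) {
    if (k0 + kTileK < K) {
      load_tile(k0 + kTileK, s ^ 1);               // prefetch the next tile into the other stage
    }
    acc += consume(stage[s]);                      // math on the current stage
    s ^= 1;
  }
  std::printf("acc = %.1f (expected 55.0)\n", acc);
  return 0;
}

The same idea drives consume_tile(): the global fetch for the next tile is issued before (or early in) the warp-level math for the current one, and the shared-memory fences at step kWarpGemmSteps - 2 separate the two stages.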
diff --git a/cutlass/gemm/gemm_operand.h b/cutlass/gemm/gemm_operand.h
index 2b4dcdc916..eef0e50234 100644
--- a/cutlass/gemm/gemm_operand.h
+++ b/cutlass/gemm/gemm_operand.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/gemm_shared_stream.h b/cutlass/gemm/gemm_shared_stream.h
index ed158d6b23..beb214ab7c 100644
--- a/cutlass/gemm/gemm_shared_stream.h
+++ b/cutlass/gemm/gemm_shared_stream.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -91,8 +91,16 @@ struct SharedLoadStream {
transformer = Transformer();
}
+ /// Clears the fragment
+ CUTLASS_DEVICE void clear() {
+ fetched[0].clear();
+ fetched[1].clear();
+ transformed[0].clear();
+ transformed[1].clear();
+ }
+
/// Load the data from shared memory to the fetch fragment.
- CUTLASS_DEVICE void copy() {
+ CUTLASS_DEVICE void copy() {
iterator.load_post_increment(fetched[0]);
}
diff --git a/cutlass/gemm/gemm_shared_tile.h b/cutlass/gemm/gemm_shared_tile.h
index 78fb1f2054..6882ca9b88 100644
--- a/cutlass/gemm/gemm_shared_tile.h
+++ b/cutlass/gemm/gemm_shared_tile.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/gemm_stream_pair.h b/cutlass/gemm/gemm_stream_pair.h
index f1c22edfc0..5690afa289 100644
--- a/cutlass/gemm/gemm_stream_pair.h
+++ b/cutlass/gemm/gemm_stream_pair.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -140,14 +140,19 @@ struct GlobalLoadStreamPair {
/// Trigger the copies from shared memory to registers.
CUTLASS_DEVICE void copy() {
+
stream_a.copy();
+
stream_b.copy();
+
}
/// Commit the data.
CUTLASS_DEVICE void commit() {
stream_a.commit();
+
stream_b.commit();
+
}
/// Execute the residue code.
@@ -233,6 +238,13 @@ struct SharedStreamPair {
stream_b.commit(step);
}
+ /// Clears all fragments
+ CUTLASS_DEVICE
+ void clear() {
+ stream_a.clear();
+ stream_b.clear();
+ }
+
/// The fragment A.
CUTLASS_DEVICE
typename StreamA::TransformedFragment const &fragment_a(int step) const {
diff --git a/cutlass/gemm/gemm_traits.h b/cutlass/gemm/gemm_traits.h
index b588de0a98..40194487c6 100644
--- a/cutlass/gemm/gemm_traits.h
+++ b/cutlass/gemm/gemm_traits.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -42,7 +42,7 @@
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/gemm/gemm_shared_stream.h"
#include "cutlass/gemm/threadblock_swizzle.h"
-#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/gemm_mainloop.h"
namespace cutlass {
namespace gemm {
@@ -359,7 +359,7 @@ struct GemmTraits {
ClearAccumulators_> This_;
/// The struct that consumes this Traits
- typedef typename cutlass::gemm::Gemm<This_> KernelClass;
+ typedef typename cutlass::gemm::GemmMainloop<This_> KernelClass;
/// The configuration.
typedef GemmConfig_ GemmConfig;
@@ -544,16 +544,26 @@ struct GemmTraits {
/// Helper to construct a partitionedK GEMM params
template <typename GemmDesc_>
- CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc, Index partitionK_count_) {
+ CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc,
+ Index partitionK_count_,
+ Index partitionK_multiple_ = 1 // each partition will be a multiple of partitionK_multiple_
+ ) {
// partitionK GEMM is a specialized batched strided gemm with different K ranges per batch
// the problem_size of each batch is (lastK_size, n, m)
// add more comments here
// the k range for every batch except the last one
//assert(partitionK_count_ > 0);
partitionK_range = partitonK_desc.problem_size.k() / partitionK_count_;
+ partitionK_range = partitionK_range - (partitionK_range % partitionK_multiple_);
// the k range of the last batch
// int lastK_range = (partitonK_desc.problem_size.k() % partitionK_range) + partitionK_range;
int lastK_range = partitonK_desc.problem_size.k() - partitionK_range * (partitionK_count_ - 1);
+
+ assert((partitionK_range % partitionK_multiple_) == 0);
+ assert(partitionK_range > 0);
+ assert((lastK_range % partitionK_multiple_) == 0);
+ assert(lastK_range > 0);
+
int k_size = lastK_range;
int lda = partitonK_desc.A.stride(0);
int ldb = partitonK_desc.B.stride(0);
@@ -641,7 +651,8 @@ struct GemmTraits {
Index ldc,
ScalarD* d_d,
Index ldd,
- Index partitionK_count_) {
+ Index partitionK_count_,
+ Index partitionK_multiple_ = 1) {
GemmDesc desc(
GemmCoord(k, n, m, 1),
@@ -654,7 +665,7 @@ struct GemmTraits {
);
- return this->initialize(desc, partitionK_count_);
+ return this->initialize(desc, partitionK_count_, partitionK_multiple_);
}
};
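
The partitionK path of initialize() splits the K dimension across batches: each of the first partitionK_count - 1 partitions gets a K range rounded down to a multiple of partitionK_multiple, and the last partition absorbs the remainder, which the asserts also require to be a positive multiple. A small worked example of that arithmetic, using hypothetical sizes:

#include <cassert>
#include <cstdio>

int main() {
  int k = 112, partitionK_count = 3, partitionK_multiple = 16;

  int partitionK_range = k / partitionK_count;                      // 37
  partitionK_range -= partitionK_range % partitionK_multiple;       // 32
  int lastK_range = k - partitionK_range * (partitionK_count - 1);  // 112 - 64 = 48

  assert(partitionK_range > 0 && (partitionK_range % partitionK_multiple) == 0);
  assert(lastK_range > 0 && (lastK_range % partitionK_multiple) == 0);

  std::printf("%d partitions of K=%d plus a final partition of K=%d\n",
              partitionK_count - 1, partitionK_range, lastK_range);
  return 0;
}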
diff --git a/cutlass/gemm/hgemm_global_tile.h b/cutlass/gemm/hgemm_global_tile.h
index 9d5ffe8508..a3b38151b6 100644
--- a/cutlass/gemm/hgemm_global_tile.h
+++ b/cutlass/gemm/hgemm_global_tile.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/hgemm_multiply_add.h b/cutlass/gemm/hgemm_multiply_add.h
index 7217d82c58..528e9c9e6f 100644
--- a/cutlass/gemm/hgemm_multiply_add.h
+++ b/cutlass/gemm/hgemm_multiply_add.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -29,7 +29,6 @@
#pragma once
#include "cutlass/fragment.h"
-
#include "cutlass/gemm/thread_multiply_add.h"
namespace cutlass {
@@ -66,6 +65,8 @@ struct ThreadMultiplyAdd {
/// Make sure there's an even number of elements in both dimensions.
static_assert(AccumulatorsPerThread::kH % 2 == 0, "Invalid size");
static_assert(AccumulatorsPerThread::kW % 2 == 0, "Invalid size");
+ static_assert(AccumulatorsPerThread::kH >= 2 && AccumulatorsPerThread::kW >= 2,
+ "HGEMM expects at least 2x2 accmulator tiles per thread.");
/// Ctor.
CUTLASS_DEVICE ThreadMultiplyAdd() {}
@@ -84,7 +85,10 @@ struct ThreadMultiplyAdd {
// The output.
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
+ CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < AccumulatorsPerThread::kH / 2; ++j) {
+
+ CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < AccumulatorsPerThread::kW / 2; ++i) {
// The offsets in the output fragment.
int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i;
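
The loops above operate on packed __half2 pairs so that each issued instruction performs two half-precision FMAs at once. A standalone sketch of that packed arithmetic using the __hfma2 intrinsic (requires sm_53 or newer, e.g. build with -arch=sm_60); the kernel and launch are illustrative only, not CUTLASS code:

#include <cuda_fp16.h>
#include <cstdio>

__global__ void hfma2_demo(__half2 a, __half2 b, __half2 c, __half2* d) {
  *d = __hfma2(a, b, c);   // d.lo = a.lo * b.lo + c.lo, d.hi = a.hi * b.hi + c.hi
}

int main() {
  __half2* d;
  cudaMallocManaged(&d, sizeof(__half2));
  hfma2_demo<<<1, 1>>>(__floats2half2_rn(2.f, 3.f),   // a = (2, 3)
                       __floats2half2_rn(4.f, 5.f),   // b = (4, 5)
                       __floats2half2_rn(1.f, 1.f),   // c = (1, 1)
                       d);
  cudaDeviceSynchronize();
  std::printf("%.1f %.1f\n", __low2float(*d), __high2float(*d));   // 9.0 16.0
  cudaFree(d);
  return 0;
}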
diff --git a/cutlass/gemm/hgemm_swizzle.h b/cutlass/gemm/hgemm_swizzle.h
index 2ecd00881e..527d28b48c 100644
--- a/cutlass/gemm/hgemm_swizzle.h
+++ b/cutlass/gemm/hgemm_swizzle.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/hgemm_traits.h b/cutlass/gemm/hgemm_traits.h
index 2261bb4b3e..bf4b01d26c 100644
--- a/cutlass/gemm/hgemm_traits.h
+++ b/cutlass/gemm/hgemm_traits.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -38,7 +38,7 @@
#include "cutlass/gemm/gemm_traits.h"
#include "cutlass/gemm/hgemm_global_tile.h"
#include "cutlass/gemm/hgemm_multiply_add.h"
-#include "cutlass/gemm/hgemm_swizzle.h"
+#include "cutlass/layout/thread/transform.h"
namespace cutlass {
namespace gemm {
@@ -107,7 +107,8 @@ struct HgemmTransformerA {
template <typename Iterator_>
struct HgemmTransformerA {
- typedef HgemmSwizzle Transformer;
+ typedef typename Iterator_::FragmentShape FragmentShape;
+ typedef cutlass::layout::thread::Transform Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -122,7 +123,8 @@ struct HgemmTransformerB {
template <typename Iterator_>
struct HgemmTransformerB {
- typedef HgemmSwizzle Transformer;
+ typedef typename Iterator_::FragmentShape FragmentShape;
+ typedef cutlass::layout::thread::Transform Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/cutlass/gemm/igemm_epilogue.h b/cutlass/gemm/igemm_epilogue.h
index 2ad24f32cc..5bb1aa0888 100644
--- a/cutlass/gemm/igemm_epilogue.h
+++ b/cutlass/gemm/igemm_epilogue.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/igemm_global_tile.h b/cutlass/gemm/igemm_global_tile.h
index 845678a82a..e169933b01 100644
--- a/cutlass/gemm/igemm_global_tile.h
+++ b/cutlass/gemm/igemm_global_tile.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/igemm_multiply_add.h b/cutlass/gemm/igemm_multiply_add.h
index 2b09cba20e..7892850d0f 100644
--- a/cutlass/gemm/igemm_multiply_add.h
+++ b/cutlass/gemm/igemm_multiply_add.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -28,8 +28,9 @@
*/
#pragma once
-#include "cutlass/fragment.h"
+#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610))
+#include "cutlass/fragment.h"
#include "cutlass/gemm/thread_multiply_add.h"
namespace cutlass {
@@ -44,6 +45,11 @@ struct ThreadMultiplyAdd
typedef Shape<4, 1, 1> InstructionShape;
/// Shape of the thread-level GEMM (K-by-N-by-M)
typedef ThreadGemmShape_ ThreadGemmShape;
+
+ /// Thread-level GEMM (N-by-M) must be a multiple of 32.
+ static_assert((ThreadGemmShape::kH * ThreadGemmShape::kW) % 32 == 0,
+ "Thread-level GEMM (N-by-M) must be multiple of 32");
+
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
typedef ThreadGemmShape AccumulatorsPerThread;
/// The number of threads per warp.
@@ -72,19 +78,18 @@ struct ThreadMultiplyAdd
Accumulators const& c,
Accumulators& d) {
- #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)
// The inputs.
int const* a_int = reinterpret_cast<int const*>(&a[0]);
int const* b_int = reinterpret_cast<int const*>(&b[0]);
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
+
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
: "=r"(d[j * AccumulatorsPerThread::kW + i])
: "r"(a_int[i]), "r"(b_int[j]), "r"(c[j * AccumulatorsPerThread::kW + i]));
}
}
- #endif
}
};
@@ -92,3 +97,5 @@ struct ThreadMultiplyAdd
} // namespace gemm
} // namespace cutlass
+
+#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610))
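
Each dp4a.s32.s32 above multiplies four packed 8-bit values from each operand and adds their sum into a 32-bit accumulator, which is why this header is now guarded for sm_61 and above. A standalone sketch using the __dp4a intrinsic rather than inline PTX (illustrative only; build with -arch=sm_61 or newer):

#include <cstdio>

__global__ void dp4a_demo(int a, int b, int c, int* d) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)
  *d = __dp4a(a, b, c);   // (1*5 + 2*6 + 3*7 + 4*8) + 10 = 80
#endif
}

int main() {
  int a = (4 << 24) | (3 << 16) | (2 << 8) | 1;   // packs the int8 values {1, 2, 3, 4}
  int b = (8 << 24) | (7 << 16) | (6 << 8) | 5;   // packs the int8 values {5, 6, 7, 8}
  int* d;
  cudaMallocManaged(&d, sizeof(int));
  dp4a_demo<<<1, 1>>>(a, b, 10, d);
  cudaDeviceSynchronize();
  std::printf("dp4a = %d\n", *d);                 // prints 80 on sm_61 or newer
  cudaFree(d);
  return 0;
}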
diff --git a/cutlass/gemm/igemm_swizzle.h b/cutlass/gemm/igemm_swizzle.h
index fbb68d1434..0c6a2ffa51 100644
--- a/cutlass/gemm/igemm_swizzle.h
+++ b/cutlass/gemm/igemm_swizzle.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -60,6 +60,7 @@ struct IgemmSwizzle {
/// Transform a fragment.
CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
+
// Expose src/dst as int arrays.
int const* src_int = reinterpret_cast<int const*>(&src[0]);
int* dst_int = reinterpret_cast<int*>(&dst[0]);
diff --git a/cutlass/gemm/igemm_traits.h b/cutlass/gemm/igemm_traits.h
index 5bceeda92e..3f2039afe5 100644
--- a/cutlass/gemm/igemm_traits.h
+++ b/cutlass/gemm/igemm_traits.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -39,7 +39,7 @@
#include "cutlass/gemm/igemm_epilogue.h"
#include "cutlass/gemm/igemm_global_tile.h"
#include "cutlass/gemm/igemm_multiply_add.h"
-#include "cutlass/gemm/igemm_swizzle.h"
+#include "cutlass/layout/thread/transform.h"
#include "cutlass/reshape_tile.h"
namespace cutlass {
@@ -90,9 +90,10 @@ struct IgemmConfig : public GemmConfig<
/// kResidueSeparate
false,
/// kResidueInPrologue
- false,
+ true,
/// kLaunchBounds
- false> {};
+ false>
+{};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -380,7 +381,8 @@ struct IgemmTransformerA {
template <typename Iterator_>
struct IgemmTransformerA {
- typedef IgemmSwizzle Transformer;
+ typedef typename Iterator_::FragmentShape FragmentShape;
+ typedef cutlass::layout::thread::Transform Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -395,7 +397,8 @@ struct IgemmTransformerB {
template <typename Iterator_>
struct IgemmTransformerB {
- typedef IgemmSwizzle Transformer;
+ typedef typename Iterator_::FragmentShape FragmentShape;
+ typedef cutlass::layout::thread::Transform Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/cutlass/gemm/linear_scaling.h b/cutlass/gemm/linear_scaling.h
index a12fc5f19f..e747b218c4 100644
--- a/cutlass/gemm/linear_scaling.h
+++ b/cutlass/gemm/linear_scaling.h
@@ -1,6 +1,5 @@
-
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/linear_scaling_device_ptr.h b/cutlass/gemm/linear_scaling_device_ptr.h
index 5dc845da4a..928b9700ec 100644
--- a/cutlass/gemm/linear_scaling_device_ptr.h
+++ b/cutlass/gemm/linear_scaling_device_ptr.h
@@ -1,5 +1,6 @@
+
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/mma_epilogue.h b/cutlass/gemm/mma_epilogue.h
new file mode 100644
index 0000000000..0e7004d73f
--- /dev/null
+++ b/cutlass/gemm/mma_epilogue.h
@@ -0,0 +1,284 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
+ with
+ the computed matrix product.
+*/
+
+#pragma once
+
+// clang-format off
+
+#include "cutlass/coord.h"
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename EpilogueTraits_>
+struct MMAEpilogue {
+ /// The traits class.
+ typedef EpilogueTraits_ Traits;
+
+ /// The params.
+ typedef typename Traits::Params Params;
+
+ /// The shared storage.
+ typedef typename Traits::SharedStorage SharedStorage;
+
+ /// Defines a tiling of the EpilogueTile over the entire threadblock GEMM tile
+ typedef typename Traits::Iterations Iterations;
+
+ /// The output tile.
+ typedef typename Traits::OutputTile OutputTile;
+
+ /// Accumulators to store in the epilogue
+ typedef typename Traits::Accumulators Accumulators;
+
+ /// A functor to copy a slice of accumulators for a given epilogue iteration
+ typedef typename Traits::SelectAccumulators SelectAccumulators;
+
+ /// The iterator to load source matrix from global memory.
+ typedef typename Traits::GlobalLoadStreamC GlobalLoadStreamC;
+
+ /// The iterator to store the final GEMM computation to global memory.
+ typedef typename Traits::GlobalStoreStreamD GlobalStoreStreamD;
+
+ /// The stream to store matrix product to shared memory
+ typedef typename Traits::SharedStoreStreamD SharedStoreStreamD;
+
+ /// The stream to load the matrix product from shared memory
+ typedef typename Traits::SharedLoadStreamD SharedLoadStreamD;
+
+ /// The functor in charge of the math.
+ typedef typename Traits::Functor Functor;
+
+ /// The scalar type used by the epilogue functor.
+ typedef typename Functor::Scalar Scalar;
+
+ /// The scalar type of the source accumulator matrix.
+ typedef typename Traits::ScalarC ScalarC;
+
+ /// The scalar type of the destination accumulator matrix.
+ typedef typename Traits::ScalarD ScalarD;
+
+ /// The index type.
+ typedef typename Traits::Index Index;
+
+ /// Functor computing the offset from the threadblock origin per iteration of
+ /// the epilogue.
+ typedef typename Traits::GlobalOffset GlobalOffset;
+
+ ///
+ typedef typename Traits::GlobalDataLayout GlobalDataLayout;
+
+ //
+ // Data members
+ //
+
+ /// The params.
+ Params const& params;
+
+ /// The shared storage.
+ SharedStorage& shared_storage;
+
+ /// The dimensions of the GEMM.
+ gemm::GemmCoord problem_size;
+
+ /// Epilogue functor
+ Functor functor;
+
+ // Functor to select a set of accumulators
+ SelectAccumulators select_accumulators;
+
+
+ // Functor to compute the global offset relative to the threadblock for each iteration
+ // of the epilogue.
+ GlobalOffset global_offset;
+
+ //
+ // Methods
+ //
+
+ /// Ctor.
+ CUTLASS_DEVICE MMAEpilogue(
+ Params const& params_,
+ SharedStorage& shared_storage_,
+ Coord<3> const& _problem_size,
+ SelectAccumulators _select_accumulators = SelectAccumulators(),
+ GlobalOffset _global_offset = GlobalOffset()
+ ):
+ params(params_),
+ shared_storage(shared_storage_),
+ problem_size(_problem_size),
+ functor(params_.functor),
+ select_accumulators(_select_accumulators),
+ global_offset(_global_offset) {}
+
+ /// Execute the epilogue.
+ CUTLASS_DEVICE void epilogue(
+ Accumulators& accumulators,
+ Coord<3> const& threadblock_offset = make_Coord(0, 0, 0),
+ int batch_id = 0) {
+
+ if (functor.source_required()) {
+ epilogue_with_or_without_beta<true>(accumulators, threadblock_offset, batch_id);
+ }
+ else {
+ epilogue_with_or_without_beta<false>(accumulators, threadblock_offset, batch_id);
+ }
+ }
+
+ ///
+
+ /// Execute the epilogue.
+ template <bool kSourceRequired>
+ CUTLASS_DEVICE void epilogue_with_or_without_beta(
+ Accumulators& accumulators,
+ Coord<3> const& threadblock_offset = make_Coord(0, 0, 0),
+ int batch_id = 0) {
+
+ /// Global memory mapping function
+ GlobalDataLayout gmem_map_func;
+
+ // Construct shared memory streams
+ SharedStoreStreamD shared_store_stream(
+ params.shared_store_stream_d,
+ shared_storage.reference());
+
+ SharedLoadStreamD shared_load_stream(
+ params.shared_load_stream_d,
+ shared_storage.reference());
+
+ // Map the GEMM problem dimensions into the coordinate system of the output memory
+ Coord<2> gmem_bounds = gmem_map_func(make_Coord(
+ problem_size.m(), // GEMM M - rows
+ problem_size.n())); // GEMM N - columns
+
+ Coord<3> gmem_tile_bounds = make_Coord(
+ problem_size.k(), // GEMM K
+ gmem_bounds[0], // strided
+ gmem_bounds[1]); // contiguous
+
+ // Iterate over the entire Threadblock tile
+ CUTLASS_PRAGMA_UNROLL
+ for (int h = 0; h < Iterations::kH; ++h) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int w = 0; w < Iterations::kW; ++w) {
+ if (!(h == 0)) {
+ //continue;
+ }
+
+ // Offset in GEMM coordinates
+ gemm::GemmCoord offset_in_gemm = threadblock_offset + global_offset(make_Coord(h, w));
+
+ Coord<2> offset_in_memory = gmem_map_func(
+ make_Coord(
+ offset_in_gemm.m(), // GEMM M - rows
+ offset_in_gemm.n())); // GEMM N - columns
+
+ // Offset in
+ Coord<3> global_tile_offset = make_Coord(
+ offset_in_gemm.k(), // GEMM K
+ offset_in_memory[0], // strided
+ offset_in_memory[1]); // contiguous
+
+ GlobalLoadStreamC global_load_stream(
+ params.load_stream_c,
+ gmem_tile_bounds,
+ global_tile_offset);
+
+ GlobalStoreStreamD global_store_stream(
+ params.store_stream_d,
+ gmem_tile_bounds,
+ global_tile_offset);
+
+ // update C pointer offset based on batch_id and batch_stride_offset
+ global_load_stream.iterator.add_pointer_offset(batch_id * params.batch_stride_C);
+
+ // update D pointer offset based on batch_id and batch_stride_offset
+ global_store_stream.iterator.add_pointer_offset(batch_id * params.batch_stride_D);
+
+ // Load the C matrix into fragment.
+ if (kSourceRequired) {
+ global_load_stream.copy();
+ }
+
+ // Make sure we can write to shared memory.
+ shared_load_fence();
+
+ // Store accumulator tile to shared memory
+ shared_store_stream.copy(
+ select_accumulators(accumulators, make_Coord(h, w)));
+
+ shared_store_stream.commit();
+
+ // Make sure the data is in shared memory.
+ shared_store_fence();
+
+ // Load the accumulators back to registers from shared memory.
+ shared_load_stream.copy();
+ shared_load_stream.commit();
+ // Commit the C matrix fragment
+ if (kSourceRequired) {
+ global_load_stream.commit();
+ }
+
+ // Apply epilogue functor
+ if (kSourceRequired) {
+
+ functor.evaluate(shared_load_stream.fragment(),
+ global_load_stream.fragment(),
+ global_store_stream.fragment());
+ }
+ else {
+
+ functor.evaluate(
+ shared_load_stream.fragment(),
+ global_store_stream.fragment());
+ }
+
+ global_store_stream.copy();
+ global_store_stream.commit();
+ }
+ }
+ }
+
+ /// The memory fence for shared loads.
+ CUTLASS_DEVICE void shared_load_fence() { __syncthreads(); }
+
+ /// The memory fence for shared stores.
+ CUTLASS_DEVICE void shared_store_fence() { __syncthreads(); }
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // gemm
+} // namespace cutlass
+
+// clang-format on
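
Per output tile, the epilogue above stages accumulators through shared memory, optionally loads the source matrix C, applies the functor, and stores D. With the usual linear-scaling functor the per-element math reduces to D = alpha * accum + beta * C when a source is required and D = alpha * accum otherwise, mirroring the <true>/<false> dispatch in epilogue(). A host-side sketch of just that evaluation (names are illustrative, not the CUTLASS functor interface):

#include <cstdio>

template <bool kSourceRequired>
float evaluate(float alpha, float accum, float beta, float c) {
  return kSourceRequired ? alpha * accum + beta * c : alpha * accum;
}

int main() {
  float alpha = 2.0f, beta = 0.5f, accum = 3.0f, c = 4.0f;
  std::printf("source required:     D = %.1f\n", evaluate<true>(alpha, accum, beta, c));   // 8.0
  std::printf("source not required: D = %.1f\n", evaluate<false>(alpha, accum, beta, c));  // 6.0
  return 0;
}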
diff --git a/cutlass/gemm/mma_global_stream.h b/cutlass/gemm/mma_global_stream.h
new file mode 100644
index 0000000000..c83c154a5b
--- /dev/null
+++ b/cutlass/gemm/mma_global_stream.h
@@ -0,0 +1,360 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Implements efficient loading of the thread block-level tile from global memory and
+ storing to shared memory.
+*/
+
+#pragma once
+
+// clang-format off
+
+#include "cutlass/convert.h"
+#include "cutlass/gemm/gemm_operand.h"
+#include "cutlass/predicate_vector.h"
+#include "cutlass/tile_allocation.h"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+///! Stream adapter for loading threadblock-scoped GEMM tiles and storing to shared memory
+template <
+ /// Identifies multiplicand
+ GemmOperand::Kind Operand,
+ /// Layout of source matrix in global memory
+ MatrixLayout::Kind Layout,
+ /// Iterator for loading threadblock-scoped tiles
+ typename LoadIterator_,
+ /// Transformation functor for transforming fragments
+ typename Transformer_,
+ /// Iterator for storing threadblock-scoped tiles to shared memory
+ typename StoreIterator_,
+ /// Number of stores before iterator wraps - zero indicates no wrapping
+ int StageCount>
+struct MMAGlobalLoadStream {
+ //
+ // Type definitions
+ //
+
+ /// Identifies the operand
+ static GemmOperand::Kind const kOperand = Operand;
+ /// The layout.
+ static MatrixLayout::Kind const kLayout = Layout;
+ /// The load iterator.
+ typedef LoadIterator_ LoadIterator;
+ /// The transformer.
+ typedef Transformer_ Transformer;
+ /// The store iterator to write to shared memory.
+ typedef StoreIterator_ StoreIterator;
+ /// Number of stages
+ static int const kStageCount = StageCount;
+
+ /// Predicate vector
+ typedef typename LoadIterator::PredicateVector PredicateVector;
+ /// The fragment that is copied from shared memory.
+ typedef typename LoadIterator::Fragment FetchedFragment;
+ /// The fragment that is obtained after the transformation by the transformer.
+ typedef typename Transformer::OutputFragment TransformedFragment;
+ /// Make sure the fragments match.
+ static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
+ "");
+ /// The output fragment.
+ typedef TransformedFragment Fragment;
+ /// Make sure the transformed fragment is the same as the store fragment.
+ static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
+ "");
+
+ /// The scalar type of the iterator.
+ typedef typename LoadIterator::Scalar Scalar;
+ /// The pointer.
+ typedef typename LoadIterator::Pointer Pointer;
+ /// The index.
+ typedef typename LoadIterator::Index Index;
+ /// The index.
+ typedef typename LoadIterator::LongIndex LongIndex;
+ /// The tile.
+ typedef typename LoadIterator::Tile Tile;
+
+ /// The params.
+ struct Params {
+
+ /// Helper
+ static int const kElementsPerLdg = LoadIterator::Tile::kC;
+
+ //
+ // Data members
+ //
+
+ /// The load iterator.
+ typename LoadIterator::Params load_iterator;
+
+ /// Stride within a batch of matrix operands
+ LongIndex batch_stride;
+
+ // Offset to residue.
+ Index offset_to_residue;
+
+ // Offset to residue for the last partition
+ Index offset_to_residue_last_partition;
+
+ //
+ // Methods
+ //
+
+ CUTLASS_HOST_DEVICE
+ Params(): batch_stride(0), offset_to_residue(0), offset_to_residue_last_partition(0) {}
+
+ /// Constructor
+ CUTLASS_HOST_DEVICE
+ Params(
+ TensorRef const &ref,
+ Index _offset_to_residue
+ ):
+ batch_stride(0),
+ offset_to_residue(_offset_to_residue),
+ offset_to_residue_last_partition(0),
+ load_iterator(
+ TensorRef(
+ ref.data(),
+ make_Coord(ref.stride(0) * kElementsPerLdg, ref.stride(0), kElementsPerLdg, 1)
+ )
+ ) {}
+
+ /// Initializer
+ CUTLASS_HOST_DEVICE
+ int initialize(
+ TensorRef const &ref,
+ LongIndex batch_stride_,
+ Index offset_to_residue_,
+ Index offset_to_residue_last_partition_) {
+
+ batch_stride = batch_stride_;
+ offset_to_residue = offset_to_residue_;
+ offset_to_residue_last_partition = offset_to_residue_last_partition_;
+
+ return load_iterator.initialize(
+ TensorRef(
+ ref.data(),
+ make_Coord(static_cast(batch_stride), ref.stride(0), kElementsPerLdg, 1)
+ )
+ );
+ }
+
+ CUTLASS_HOST_DEVICE
+ int initialize(
+ TensorRef const &ref,
+ Index offset_to_residue_) {
+
+ offset_to_residue = offset_to_residue_;
+ return load_iterator.initialize(
+ TensorRef(
+ ref.data(),
+ make_Coord(ref.stride(0) * kElementsPerLdg, ref.stride(0), kElementsPerLdg, 1)
+ )
+ );
+ }
+
+ CUTLASS_HOST_DEVICE int initialize(Index offset_to_residue_) {
+ offset_to_residue = offset_to_residue_;
+ return 0;
+ }
+
+ CUTLASS_DEVICE Index get_offset_to_residue() {
+ if (blockIdx.z == gridDim.z - 1) { //last partition
+ return offset_to_residue_last_partition;
+ }
+ else {
+ return offset_to_residue;
+ }
+ }
+ };
+
+ /// Empty shared storage
+ struct SharedStorage {};
+
+ /// Shared memory allocation for the tile
+ typedef TileAllocation<
+ typename StoreIterator::Scalar,
+ typename ShapeMul<
+ typename StoreIterator::OperandShape,
+ Shape
+ >::Shape
+ > ThreadblockTileStorage;
+
+ /// ZipTensorRef to threadblock tiles
+ typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;
+
+ //
+ // Data members
+ //
+
+ ///! The parameters
+ Params params;
+
+ ///! Dimensions of global memory tile
+ Coord<3> threadblock_offset;
+
+ ///! Dimensions of multiplicand bounds
+ Coord<3> multiplicand_bounds;
+
+ ///! Iterator to load threadblock tiles from global memory
+ LoadIterator load_iterator;
+
+ ///! Predicate vector
+ PredicateVector predicates;
+
+ ///! The fragment to fetch from shared memory.
+ FetchedFragment fetched_fragment;
+
+ ///! Functor to transform fragments after they have been loaded
+ Transformer transformer;
+
+ ///! The fragment to convert the data after it has been fetched from shared memory.
+ TransformedFragment transformed_fragment;
+
+ ///! Iterator to store threadblock tiles to shared memory
+ StoreIterator store_iterator;
+
+ ///! Counter
+ int stage_index;
+
+ //
+ // Static member functions
+ //
+
+ /// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory
+ CUTLASS_HOST_DEVICE
+ static Coord<3> project_coordinate(Coord<3> const &coord, Index d_offset = 0) {
+ bool const kKstrided =
+ gemm::GemmMultiplicandTraits::kKstrided;
+
+ Coord<3> tile_coord = gemm::ProjectOperand::project(coord);
+
+ return make_Coord(
+ tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC);
+ }
+
+ //
+ // Methods
+ //
+
+ /// Constructor
+ CUTLASS_DEVICE MMAGlobalLoadStream(Params const &_params,
+ SharedStorage &shared_storage,
+ ThreadblockTileRef const &threadblock_tile_ref,
+ Coord<3> const bounds,
+ Coord<3> const &block)
+ : params(_params),
+ threadblock_offset(project_coordinate(block)),
+ multiplicand_bounds(project_coordinate(bounds, 1)),
+ load_iterator(params.load_iterator, threadblock_offset),
+ transformer(),
+ store_iterator(threadblock_tile_ref.data()),
+ stage_index(0) {
+ load_iterator.initialize_predicates(
+ predicates.begin(), multiplicand_bounds, threadblock_offset);
+ }
+
+ /// Loads the data from global memory
+ CUTLASS_DEVICE void copy() {
+ load_iterator.load_post_increment(fetched_fragment, predicates.begin());
+ }
+
+ /// Transform and commit the data to shared memory
+ CUTLASS_DEVICE void commit() {
+ transformer.transform(fetched_fragment, transformed_fragment);
+ store_iterator.store_post_increment(transformed_fragment);
+
+ ++stage_index;
+ if (kStageCount && stage_index == kStageCount) {
+ store_iterator -= kStageCount;
+ stage_index = 0;
+ }
+ }
+
+ /// Computes a predicate mask for loads during final threadblock tile load iteration
+ CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
+ // That's the residue!
+ Coord<3> _block_offset = threadblock_offset;
+ if (kOperand == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor) {
+ // K-strided
+ _block_offset =
+ make_Coord(threadblock_offset[0], multiplicand_bounds[1] - k, threadblock_offset[2]);
+ } else {
+ // K-contiguous
+ _block_offset = make_Coord(threadblock_offset[0],
+ threadblock_offset[1],
+ multiplicand_bounds[2] - k / LoadIterator::Tile::kC);
+ }
+
+ load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, _block_offset);
+ fetched_fragment.clear();
+ }
+
+ /// Move to the residue portion.
+ CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
+ Index kResidue = k % kTileK;
+ if (kResidue) {
+ residue(kResidue);
+ Index this_offset_residue = params.get_offset_to_residue();
+ load_iterator.add_pointer_offset(this_offset_residue * load_iterator.stride_advance());
+ }
+ }
+
+ /// Rollback to the beginning of the first tile
+ CUTLASS_DEVICE void rollback(void) {
+ load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, threadblock_offset);
+
+ int const kBlock = kOperand == GemmOperand::kA
+ ? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW)
+ : (kLayout == MatrixLayout::kRowMajor ? Tile::kH : Tile::kW);
+ Index this_offset_residue = params.get_offset_to_residue();
+ load_iterator.add_pointer_offset(-(this_offset_residue + kBlock) *
+ load_iterator.stride_advance());
+ }
+
+ /// Adds a Coord<3> to the underlying global load iterator
+ CUTLASS_DEVICE MMAGlobalLoadStream &operator+=(Coord<3> const &offset) {
+ load_iterator += offset;
+ return *this;
+ }
+
+ /// Adds an offset based on batch stride
+ CUTLASS_DEVICE MMAGlobalLoadStream &add_batch_offset(int batch_id) {
+ load_iterator.add_pointer_offset(batch_id * params.batch_stride);
+ return *this;
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // gemm
+} // namespace cutlass
+
+// clang-format on
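
commit() above stores one transformed fragment and then rewinds the store iterator once kStageCount stages have been written, so the stream cycles through a fixed ring of shared-memory stages. A counter-only sketch of that wrap-around (no CUTLASS types involved; kStageCount is an illustrative value):

#include <cstdio>

int main() {
  const int kStageCount = 3;   // number of shared-memory stages
  int store_offset = 0;        // models the store iterator's running offset
  int stage_index = 0;

  for (int iteration = 0; iteration < 7; ++iteration) {
    std::printf("commit %d writes stage %d (offset %d)\n", iteration, stage_index, store_offset);
    ++store_offset;                      // store_iterator.store_post_increment(...)
    ++stage_index;
    if (kStageCount && stage_index == kStageCount) {
      store_offset -= kStageCount;       // store_iterator -= kStageCount
      stage_index = 0;
    }
  }
  return 0;
}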
diff --git a/cutlass/gemm/mma_global_tile.h b/cutlass/gemm/mma_global_tile.h
new file mode 100644
index 0000000000..ae3a91250c
--- /dev/null
+++ b/cutlass/gemm/mma_global_tile.h
@@ -0,0 +1,201 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
+*/
+
+#pragma once
+
+#include "cutlass/coord.h"
+#include "cutlass/gemm/gemm_operand.h"
+#include "cutlass/reshape_tile.h"
+#include "cutlass/tile_iterator.h"
+#include "cutlass/util/platform.h"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// Iterators used to load multiplicands from global memory specialized for Volta884 access patterns
+//
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Iterator for loading data for congruous access patterns
+template <GemmOperand::Kind Operand, typename Tile_, int WarpCount, int WarpDelta>
+struct MMAThreadblockCongruousLoad {
+ /// Identifies multiplicand of GEMM (A or B)
+ static GemmOperand::Kind const kOperand = Operand;
+
+ /// Specifies layout of data in source memory
+ static MatrixLayout::Kind const kLayout =
+ (Operand == GemmOperand::kA ? MatrixLayout::kColumnMajor : MatrixLayout::kRowMajor);
+
+ /// Shape of thread-block multiplicand
+ typedef Tile_ Tile;
+
+ /// Number of participating warps
+ static int const kWarpCount = WarpCount;
+
+ /// Delta between warp accumulator tiles along the outer dimension
+ static int const kWarpDelta = WarpDelta;
+
+ /// This implementation is specialized for 128b loads
+ static int const kAccessSize = 8;
+
+ /// Projects the threadblock tile
+ typedef typename gemm::GemmMultiplicandTraits::Shape OperandShape;
+
+ /// Reshapes the threadblock tile by access size
+ typedef typename ReshapeTile<OperandShape, kAccessSize>::Tile VectorizedShape;
+
+ /// Shape of tile
+ typedef Shape<1, 4, 8> WarpStoreCoverage;
+
+ /// Shape of tile loaded by each warp per load operation
+ typedef Shape<1, 4, 8> WarpLoadShape;
+
+ //
+ // Load iterator
+ //
+
+ ///
+ typedef Shape<1, WarpLoadShape::kH * kWarpCount, WarpLoadShape::kW> Delta;
+
+ typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
+
+ /// Rakes warps along contiguous dimensions and strip-mines strided
+ /// dimension.
+ typedef Shape<1,
+ VectorizedShape::kH / WarpStoreCoverage::kH / WarpCount,
+ VectorizedShape::kW / WarpStoreCoverage::kW,
+ 1>
+ Iterations;
+
+ /// Functor computing starting offset for each thread
+ struct ThreadOffset {
+ __device__ Coord<4> operator()() const {
+ int warp_id = (threadIdx.x >> 5);
+ int lane_id = (threadIdx.x & 0x1f);
+
+ int lane_k = lane_id / WarpLoadShape::kW;
+ int lane_outer = lane_id % WarpLoadShape::kW;
+
+ Coord<4> offset = make_Coord(0, warp_id * WarpLoadShape::kH + lane_k, lane_outer, 0);
+
+ return offset;
+ }
+ };
+
+ /// Source tile traits
+ typedef TileTraits LoadTileTraits;
+
+ /// Load iterator
+ typedef TileLoadIterator Iterator;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Iterator for loading data for crosswise access patterns
+template <GemmOperand::Kind Operand, typename Tile_, int WarpCount, int WarpDelta>
+struct MMAThreadblockCrosswiseLoad {
+ /// Identifies multiplicand of GEMM (A or B)
+ static GemmOperand::Kind const kOperand = Operand;
+
+ /// Specifies layout of data in source memory
+ static MatrixLayout::Kind const kLayout =
+ (Operand == GemmOperand::kA ? MatrixLayout::kRowMajor : MatrixLayout::kColumnMajor);
+
+ /// Shape of thread-block multiplicand
+ typedef Tile_ Tile;
+
+ /// Number of participating warps
+ static int const kWarpCount = WarpCount;
+
+ /// Delta between warp accumulator tiles along the outer dimension
+ static int const kWarpDelta = WarpDelta;
+
+ /// This implementation is specialized for 128b loads
+ static int const kAccessSize = 8;
+
+ /// Projects the threadblock tile
+ typedef typename gemm::GemmMultiplicandTraits::Shape OperandShape;
+
+ /// Reshapes the threadblock tile by access size
+ typedef typename ReshapeTile<OperandShape, kAccessSize>::Tile VectorizedShape;
+
+ /// Shape of tile
+ typedef Shape<1, 8, 4> WarpStoreCoverage;
+
+ /// Shape of tile loaded by each warp per load operation
+ typedef Shape<1, 8, 4> WarpLoadShape;
+
+ //
+ // Load iterator
+ //
+
+ ///
+ typedef Shape<1, WarpLoadShape::kH, WarpLoadShape::kW> Delta;
+
+ typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
+
+ /// Rakes warps along contiguous dimensions and strip-mines strided
+ /// dimension.
+ typedef Shape<1,
+ VectorizedShape::kH / WarpStoreCoverage::kH / WarpCount,
+ VectorizedShape::kW / WarpStoreCoverage::kW,
+ 1>
+ Iterations;
+
+ /// Functor computing starting offset for each thread
+ struct ThreadOffset {
+ __device__ Coord<4> operator()() const {
+
+ int warp_id = (threadIdx.x >> 5);
+ int lane_id = (threadIdx.x & 0x1f);
+
+ int lane_k = lane_id % WarpLoadShape::kW;
+ int lane_outer = lane_id / WarpLoadShape::kW;
+
+ Coord<4> offset =
+ make_Coord(0, warp_id * Iterations::kH * WarpLoadShape::kH + lane_outer, lane_k, 0);
+
+ return offset;
+ }
+ };
+
+ /// Source tile traits
+ typedef TileTraits LoadTileTraits;
+
+ /// Load iterator
+ typedef TileLoadIterator Iterator;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // gemm
+} // namespace cutlass
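
Both iterators above decompose threadIdx.x into a warp index and a lane index, then split each warp's 32 lanes into a strided-by-contiguous grid of 128-bit accesses (4x8 for the congruous pattern, 8x4 for the crosswise one). A host-side sketch of the congruous ThreadOffset arithmetic for a few sample threads:

#include <cstdio>

int main() {
  const int kWarpLoadShapeH = 4;   // strided extent per warp (Shape<1, 4, 8>)
  const int kWarpLoadShapeW = 8;   // contiguous extent per warp

  const int samples[] = {0, 7, 8, 31, 32, 63};
  for (int thread : samples) {
    int warp_id = thread >> 5;
    int lane_id = thread & 0x1f;
    int lane_k = lane_id / kWarpLoadShapeW;       // strided index within the warp
    int lane_outer = lane_id % kWarpLoadShapeW;   // contiguous index within the warp
    std::printf("thread %2d -> (h = %d, w = %d)\n",
                thread, warp_id * kWarpLoadShapeH + lane_k, lane_outer);
  }
  return 0;
}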
diff --git a/cutlass/gemm/mma_shared_stream.h b/cutlass/gemm/mma_shared_stream.h
new file mode 100644
index 0000000000..af11ebab1e
--- /dev/null
+++ b/cutlass/gemm/mma_shared_stream.h
@@ -0,0 +1,155 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Implements efficient loading of fragments from shared memory into registers for
+ warp-level matrix multiply-accumulate.
+*/
+
+#pragma once
+
+#include "cutlass/convert.h"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Stream from shared memory to fragments for warp-level matrix multiply-accumulate
+template <
+ /// The load iterator.
+ typename Iterator_,
+ /// The transformer to be applied after the data has been copied from shared memory.
+ typename Transformer_ = Copy<typename Iterator_::Fragment>,
+ /// Number of increments before iterator wraps - zero indicates no wrapping
+ int StageCount = 1>
+struct MMASharedLoadStream {
+ /// The load iterator.
+ typedef Iterator_ Iterator;
+ /// The transformer.
+ typedef Transformer_ Transformer;
+
+ /// Number of increments before iterator wraps - zero indicates no wrapping
+ static int const kStageCount = StageCount;
+
+ /// The fragment that is copied from shared memory.
+ typedef typename Iterator::Fragment FetchedFragment;
+ /// The fragment that is obtained after the transformation by the transformer.
+ typedef typename Transformer::OutputFragment TransformedFragment;
+ /// Make sure the fragments match.
+ static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
+ "");
+ /// The output fragment.
+ typedef TransformedFragment Fragment;
+
+ /// Element type
+ typedef typename Iterator::Scalar Scalar;
+
+ /// Reference type to a tensor
+ typedef TensorRef TensorRef;
+
+ /// Parameters passed from host
+ struct Params {};
+
+ //
+ // Data members
+ //
+
+ /// Iterator for loading fragments for warp-level matrix multiply-accumulate
+ Iterator iterator;
+
+ /// Fetched fragment
+ FetchedFragment fetched[2];
+
+ /// The transformer.
+ Transformer transformer;
+
+ /// Transformed fragment
+ TransformedFragment transformed[2];
+
+ /// Counts the number of stages
+ int stage_index;
+
+ //
+ // Methods
+ //
+
+ /// Ctor.
+ CUTLASS_DEVICE MMASharedLoadStream() : stage_index(0) {}
+
+ /// Ctor.
+ CUTLASS_DEVICE MMASharedLoadStream(
+ Params const &_params,
+ TensorRef const &ref,
+ Coord<4> warp_offset = make_Coord(0, 0, 0, 0)
+ ):
+ iterator(ref.data(), warp_offset), stage_index(0) {
+
+ }
+
+ /// Load the data from shared memory to the fetch fragment.
+ CUTLASS_DEVICE void copy(int step) {
+ iterator.load(
+ fetched[step % 2],
+ make_Coord(step + stage_index * Iterator::VectorizedShape::kD, 0, 0, 0)
+ );
+ }
+
+ /// Commit the data.
+ CUTLASS_DEVICE void commit(int step) {
+ transformer.transform(fetched[step % 2], transformed[step % 2]);
+ }
+
+  /// Clears all fetched and transformed fragments
+ CUTLASS_DEVICE void clear() {
+ fetched[0].clear();
+ fetched[1].clear();
+ transformed[0].clear();
+ transformed[1].clear();
+ }
+
+ /// Gets the transformed fragment
+ CUTLASS_DEVICE
+ TransformedFragment &fragment(int step) { return transformed[step % 2]; }
+
+ /// Gets the transformed fragment
+ CUTLASS_DEVICE
+ TransformedFragment const &fragment(int step) const { return transformed[step % 2]; }
+
+ /// Increment the stage.
+ CUTLASS_DEVICE void inc_stage() {
+
+ ++stage_index;
+ if (kStageCount && stage_index == StageCount) {
+ stage_index = 0;
+ }
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
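
The `MMASharedLoadStream` above double-buffers its fragments: `copy(step)` fetches into `fetched[step % 2]`, `commit(step)` runs the transformer into `transformed[step % 2]`, `fragment(step)` hands the result to the warp-level multiply-accumulate, and `inc_stage()` wraps the shared-memory stage counter. The following host-side sketch mimics that two-slot pipeline; `ToyStream`, its `Fragment`, and the `* 10` transform are illustrative stand-ins, not CUTLASS types.

```cpp
// Minimal host-side sketch of the two-slot (double-buffered) pipeline used by the
// shared-memory load stream above. All names here are illustrative only.
#include <array>
#include <cstdio>

struct Fragment { int value = 0; };

template <int kStages>
struct ToyStream {
  std::array<Fragment, 2> fetched;      // data "loaded from shared memory"
  std::array<Fragment, 2> transformed;  // data after the transformer runs
  int stage_index = 0;

  // Analogue of copy(step): fill the slot the given step will consume.
  void copy(int step, int data) { fetched[step % 2].value = data; }

  // Analogue of commit(step): transform the fetched slot for this step.
  void commit(int step) { transformed[step % 2].value = fetched[step % 2].value * 10; }

  // Analogue of fragment(step): hand the transformed slot to the math stage.
  Fragment const &fragment(int step) const { return transformed[step % 2]; }

  // Analogue of inc_stage(): advance and wrap the shared-memory stage counter.
  void inc_stage() {
    ++stage_index;
    if (kStages && stage_index == kStages) stage_index = 0;
  }
};

int main() {
  ToyStream<2> stream;
  stream.copy(0, /*data=*/100);            // prologue: fetch slot 0 before the loop
  for (int step = 0; step < 4; ++step) {
    stream.commit(step);                   // transform what was fetched for this step
    stream.copy(step + 1, 101 + step);     // prefetch the slot for the next step
    std::printf("step %d consumes %d (stage %d)\n",
                step, stream.fragment(step).value, stream.stage_index);
    stream.inc_stage();
  }
  return 0;
}
```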
diff --git a/cutlass/gemm/scalar_or_pointer.h b/cutlass/gemm/scalar_or_pointer.h
index 7c4b4b75d0..9e29295141 100644
--- a/cutlass/gemm/scalar_or_pointer.h
+++ b/cutlass/gemm/scalar_or_pointer.h
@@ -1,6 +1,5 @@
-
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/sgemm_traits.h b/cutlass/gemm/sgemm_traits.h
index 8ce7f58e26..6c54756e30 100644
--- a/cutlass/gemm/sgemm_traits.h
+++ b/cutlass/gemm/sgemm_traits.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/thread_multiply_add.h b/cutlass/gemm/thread_multiply_add.h
index b95dee58a0..784377ae5d 100644
--- a/cutlass/gemm/thread_multiply_add.h
+++ b/cutlass/gemm/thread_multiply_add.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -73,16 +73,27 @@ struct ThreadMultiplyAdd {
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
+
if(kLayout_ == MatrixLayout::kColumnMajor) {
+
+ CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
+
+ CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
+
d[j * AccumulatorsPerThread::kW + i] = a[i] * b[j] + c[j * AccumulatorsPerThread::kW + i];
}
}
}
else {
+
+ CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < AccumulatorsPerThread::kW; ++i) {
+
+ CUTLASS_PRAGMA_UNROLL
for(int j = 0; j < AccumulatorsPerThread::kH; ++j) {
+
d[i * AccumulatorsPerThread::kH + j] = a[i] * b[j] + c[i * AccumulatorsPerThread::kH + j];
}
}
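
The unrolled loops above compute a per-thread outer-product (rank-1) update; the only difference between the two branches is the accumulator indexing, `d[j * kW + i]` for the column-major case versus `d[i * kH + j]` otherwise. A scalar sketch of both indexings, assuming an illustrative 2x4 per-thread accumulator tile:

```cpp
// Scalar sketch of the per-thread outer-product update shown in the hunk above,
// assuming AccumulatorsPerThread::kH = 2 and kW = 4 purely for illustration.
#include <cstdio>

int main() {
  const int kH = 2, kW = 4;
  float a[kW] = {1, 2, 3, 4};   // fragment of A held by this thread
  float b[kH] = {10, 20};       // fragment of B held by this thread
  float d_col[kH * kW] = {0};   // column-major accumulator ordering
  float d_row[kH * kW] = {0};   // row-major accumulator ordering

  // Column-major variant: element (i, j) lives at d[j * kW + i].
  for (int j = 0; j < kH; ++j)
    for (int i = 0; i < kW; ++i)
      d_col[j * kW + i] += a[i] * b[j];

  // Other variant: element (i, j) lives at d[i * kH + j].
  for (int i = 0; i < kW; ++i)
    for (int j = 0; j < kH; ++j)
      d_row[i * kH + j] += a[i] * b[j];

  // Both orderings hold the same outer product a * b^T, just laid out differently.
  for (int j = 0; j < kH; ++j)
    for (int i = 0; i < kW; ++i)
      std::printf("(%d,%d): col=%g row=%g\n", i, j, d_col[j * kW + i], d_row[i * kH + j]);
  return 0;
}
```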
diff --git a/cutlass/gemm/threadblock_swizzle.h b/cutlass/gemm/threadblock_swizzle.h
index eab8595a68..737b89a9cf 100644
--- a/cutlass/gemm/threadblock_swizzle.h
+++ b/cutlass/gemm/threadblock_swizzle.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cutlass/gemm/volta884_complex_gemm_epilogue_traits.h b/cutlass/gemm/volta884_complex_gemm_epilogue_traits.h
new file mode 100644
index 0000000000..059dcaeeeb
--- /dev/null
+++ b/cutlass/gemm/volta884_complex_gemm_epilogue_traits.h
@@ -0,0 +1,348 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
+ with the computed matrix product.
+*/
+
+#pragma once
+
+// clang-format off
+
+#include "cutlass/zip_fragment.h"
+#include "cutlass/zip_tile_iterator.h"
+#include "cutlass/util/complex.h"
+#include "cutlass/gemm/volta884_gemm_epilogue_traits.h"
+#include "cutlass/gemm/split_complex_linear_scaling.h"
+#include "cutlass/util/pair.h"
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Enables treating the accumulators selection as one object
+template <typename First_, typename Second_>
+struct ZipSelectAccumulators {
+
+ /// Underlying selection function
+ typedef First_ First;
+ typedef Second_ Second;
+
+ /// Accumulators
+ typedef ZipFragment<
+ typename First::Accumulators,
+ typename Second::Accumulators> Accumulators;
+
+ /// Fragment
+ typedef ZipFragment<
+ typename First::Fragment,
+ typename Second::Fragment> Fragment;
+
+ //
+ // Data members
+ //
+
+ /// Selects the accumulators for the first part
+ First first;
+
+  /// Selects the accumulators for the second part
+ Second second;
+
+ //
+ // Methods
+ //
+
+ /// Default ctor
+ CUTLASS_DEVICE
+ ZipSelectAccumulators() { }
+
+ /// Basic constructor
+ CUTLASS_DEVICE
+ ZipSelectAccumulators(First const &_first, Second const &_second): first(_first), second(_second) { }
+
+ /// Selects accumulators for a given iteration of the epilogue
+ CUTLASS_DEVICE
+ Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const {
+ return make_ZipFragment(first(accum.first, idx), second(accum.second, idx));
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Defines epilogue traits for complex-valued mma.sync GEMM
+template <
+ typename GemmConfig_,
+ typename EpilogueFunctor_ = SplitComplexLinearScaling,
+ typename Index_ = int>
+struct Volta884ComplexGemmEpilogueTraits {
+
+ /// GEMM configuration
+ typedef GemmConfig_ GemmConfig;
+
+ /// Epilogue functor
+ typedef EpilogueFunctor_ Functor;
+
+ /// Global memory mapping function
+ typedef MatrixLayout::ColumnMajor GlobalDataLayout;
+
+ /// Index type
+ typedef Index_ Index;
+
+ /// Long index used for offsets
+ typedef long long LongIndex;
+
+ /// Defines epilogue traits for real-valued Volta884 GEMM epilogue
+ typedef typename Volta884GemmEpilogueTraitsHelper<
+ GemmConfig,
+ Functor,
+ typename GemmConfig::MultiplyAdd::RealMultiplyAdd,
+ Index>::EpilogueTraits RealEpilogueTraits;
+
+ /// The output tile.
+ typedef typename RealEpilogueTraits::OutputTile OutputTile;
+
+ /// The warp-level GEMM tile
+ typedef typename RealEpilogueTraits::WarpGemmTile WarpGemmTile;
+
+ /// Tiling of warp accumulator elements
+ typedef typename RealEpilogueTraits::WarpGemmTile WarpDelta;
+
+ /// Multiply-add operation
+ typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
+
+ /// The accumulators fragment type.
+ typedef typename MultiplyAdd::Accumulators Accumulators;
+
+ /// Selects a subset of accumulators for a given epilogue iteration
+ typedef ZipSelectAccumulators<
+ typename RealEpilogueTraits::SelectAccumulators,
+ typename RealEpilogueTraits::SelectAccumulators> SelectAccumulators;
+
+ /// The iterator to load source matrix from global memory.
+ typedef cutlass::PredicatedTileLoadStream<
+ ZipTileIterator<
+ typename RealEpilogueTraits::GlobalLoadStreamC::Iterator,
+ typename RealEpilogueTraits::GlobalLoadStreamC::Iterator
+ >,
+ typename RealEpilogueTraits::GlobalLoadStreamC::PredicateFunctor,
+ ZipConvert<
+ typename RealEpilogueTraits::GlobalLoadStreamC::Transformer,
+ typename RealEpilogueTraits::GlobalLoadStreamC::Transformer
+ >
+ > GlobalLoadStreamC;
+
+ /// The iterator to store the final GEMM computation to global memory.
+ typedef cutlass::PredicatedTileStoreStream<
+ ZipTileIterator<
+ typename RealEpilogueTraits::GlobalStoreStreamD::Iterator,
+ typename RealEpilogueTraits::GlobalStoreStreamD::Iterator
+ >,
+ typename RealEpilogueTraits::GlobalStoreStreamD::PredicateFunctor,
+ ZipConvert<
+ typename RealEpilogueTraits::GlobalStoreStreamD::Transformer,
+ typename RealEpilogueTraits::GlobalStoreStreamD::Transformer
+ >
+ > GlobalStoreStreamD;
+
+ /// The stream to store matrix product to shared memory
+ typedef cutlass::TileStoreStream<
+ ZipTileIterator<
+ typename RealEpilogueTraits::SharedStoreStreamD::Iterator,
+ typename RealEpilogueTraits::SharedStoreStreamD::Iterator
+ >,
+ ZipConvert<
+ typename RealEpilogueTraits::SharedStoreStreamD::Transformer,
+ typename RealEpilogueTraits::SharedStoreStreamD::Transformer
+ >
+ > SharedStoreStreamD;
+
+ /// The stream to load the matrix product from shared memory
+ typedef cutlass::TileLoadStream<
+ ZipTileIterator<
+ typename RealEpilogueTraits::SharedLoadStreamD::Iterator,
+ typename RealEpilogueTraits::SharedLoadStreamD::Iterator
+ >,
+ ZipConvert<
+ typename RealEpilogueTraits::SharedLoadStreamD::Transformer,
+ typename RealEpilogueTraits::SharedLoadStreamD::Transformer
+ >
+ > SharedLoadStreamD;
+
+ /// The scalar type of the source accumulator matrix.
+ typedef typename RealEpilogueTraits::ScalarC ScalarC;
+
+ /// The scalar type of the destination accumulator matrix.
+ typedef typename RealEpilogueTraits::ScalarD ScalarD;
+
+ //
+ // Dependent types
+ //
+
+ /// Cover an entire warp-level tile
+ typedef typename RealEpilogueTraits::Iterations Iterations;
+
+ /// Parameters structure initialized on the host
+ struct Params {
+ /// The params for the C iterator.
+ typename GlobalLoadStreamC::Params load_stream_c;
+
+ /// The params for the D global iterator.
+ typename GlobalStoreStreamD::Params store_stream_d;
+
+ /// Epilogue functor params
+ typename Functor::Params functor;
+
+ /// The params for the D shared store iterator.
+ typename SharedStoreStreamD::Params shared_store_stream_d;
+
+ /// The params for the D shared load stream.
+ typename SharedLoadStreamD::Params shared_load_stream_d;
+
+    /// Batch strides for the real and imaginary planes of C
+    platform::Pair<LongIndex, LongIndex> batch_stride_C;
+
+    /// Batch strides for the real and imaginary planes of D
+    platform::Pair<LongIndex, LongIndex> batch_stride_D;
+
+ //
+ // Methods
+ //
+
+ /// Default constructor
+ CUTLASS_HOST_DEVICE
+ Params() {
+ batch_stride_C.first = 0;
+ batch_stride_C.second = 0;
+
+ batch_stride_D.first = 0;
+ batch_stride_D.second = 0;
+ }
+
+ /// Setup the params.
+ CUTLASS_HOST_DEVICE int initialize(
+ platform::complex alpha,
+ platform::complex beta,
+ ScalarC const* real_C,
+ Index real_ldc,
+ ScalarC const* imag_C,
+ Index imag_ldc,
+ ScalarD* real_D,
+ Index real_ldd,
+ ScalarD* imag_D,
+ Index imag_ldd) {
+
+ int result = functor.initialize(alpha, beta);
+ if (result) {
+ return result;
+ }
+
+ // Setup the params for the global memory iterator for C.
+ result = load_stream_c.iterator.first.initialize(
+ real_C, real_ldc, real_ldc, 1);
+
+ if (result) {
+ return result;
+ }
+
+ result = load_stream_c.iterator.second.initialize(
+ imag_C, imag_ldc, imag_ldc, 1);
+
+ if (result) {
+ return result;
+ }
+
+ // Setup the params for the global memory iterator for D.
+ result = store_stream_d.iterator.first.initialize(
+ real_D, real_ldd, real_ldd, 1);
+
+ if (result) {
+ return result;
+ }
+
+ result = store_stream_d.iterator.second.initialize(
+ imag_D, imag_ldd, imag_ldd, 1);
+
+ if (result) {
+ return result;
+ }
+
+ return result;
+ }
+
+ /// Setup the params.
+ CUTLASS_HOST_DEVICE int initialize(
+ platform::complex alpha,
+ platform::complex beta,
+ ScalarC const* real_C,
+ Index real_ldc,
+ LongIndex stride_C_real,
+ ScalarC const* imag_C,
+ Index imag_ldc,
+ LongIndex stride_C_imag,
+ ScalarD* real_D,
+ Index real_ldd,
+ LongIndex stride_D_real,
+ ScalarD* imag_D,
+ Index imag_ldd,
+ LongIndex stride_D_imag) {
+
+ batch_stride_C.first = stride_C_real;
+ batch_stride_C.second = stride_C_imag;
+
+ batch_stride_D.first = stride_D_real;
+ batch_stride_D.second = stride_D_imag;
+
+ return initialize(alpha, beta, real_C, real_ldc, imag_C, imag_ldc, real_D, real_ldd, imag_D, imag_ldd);
+ }
+ };
+
+ /// Shared memory buffer used by epilogue
+ typedef ZipTileAllocation<
+ typename RealEpilogueTraits::SharedStorage,
+ typename RealEpilogueTraits::SharedStorage> SharedStorage;
+
+ /// Functor computing the offset from the threadblock origin per iteration of
+ /// the epilogue.
+ typedef typename RealEpilogueTraits::GlobalOffset GlobalOffset;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+
+namespace platform {
+
+/// Scales both elements of a pair of batch strides by an integer factor
+CUTLASS_HOST_DEVICE
+Pair<long long, long long> operator*(int s, Pair<long long, long long> _pair) {
+  return Pair<long long, long long>(s * _pair.first, s * _pair.second);
+}
+
+} // namespace platform
+
+} // namespace cutlass
+
+// clang-format on
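
The epilogue parameters above take separate pointers, leading dimensions, and batch strides for the real and imaginary planes, matching the split-complex (planar) layout used throughout these headers. A self-contained sketch of that layout, assuming column-major planes (`GlobalDataLayout` above is `MatrixLayout::ColumnMajor`); the `PlanarComplexMatrix` type is local to the example, not a CUTLASS type:

```cpp
// Illustrative sketch of the split-complex ("planar") layout described by the
// epilogue parameters above: real and imaginary parts live in two disjoint,
// column-major planes, each with its own leading dimension and batch stride.
#include <complex>
#include <cstdio>
#include <vector>

struct PlanarComplexMatrix {
  std::vector<float> real, imag;
  int rows, cols, ld;        // leading dimension of each plane (>= rows, column-major)
  long long batch_stride;    // elements between consecutive batches within a plane

  PlanarComplexMatrix(int m, int n, int batches)
      : real(size_t(m) * n * batches), imag(size_t(m) * n * batches),
        rows(m), cols(n), ld(m), batch_stride(static_cast<long long>(m) * n) {}

  std::complex<float> at(int b, int i, int j) const {
    size_t idx = size_t(b) * batch_stride + size_t(j) * ld + i;
    return {real[idx], imag[idx]};
  }
};

int main() {
  PlanarComplexMatrix C(/*m=*/4, /*n=*/3, /*batches=*/2);
  // The pointers and strides a host wrapper would forward to an initialize() call:
  // real_C = C.real.data(), real_ldc = C.ld, imag_C = C.imag.data(), imag_ldc = C.ld,
  // and batch_stride_C_real = batch_stride_C_imag = C.batch_stride.
  C.real[1] = 2.0f; C.imag[1] = -1.0f;   // element (1, 0) of batch 0
  std::complex<float> c = C.at(0, 1, 0);
  std::printf("C[0](1,0) = %f + %fi\n", c.real(), c.imag());
  return 0;
}
```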
diff --git a/cutlass/gemm/volta884_complex_gemm_traits.h b/cutlass/gemm/volta884_complex_gemm_traits.h
new file mode 100644
index 0000000000..593b28dd46
--- /dev/null
+++ b/cutlass/gemm/volta884_complex_gemm_traits.h
@@ -0,0 +1,558 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Defines structural properties for complex-valued GEMM targeting Volta's mma.sync
+ instruction.
+
+  At present, it expects a split complex representation in global memory in which the real and
+  imaginary parts of a complex-valued matrix are stored as disjoint planes (a structure of
+  arrays). This is in contrast with an interleaved complex representation, which is an array of
+  structures.
+*/
+
+#pragma once
+
+// clang-format off
+
+#include "cutlass/gemm/clear_accumulators.h"
+#include "cutlass/gemm/gemm_config.h"
+#include "cutlass/gemm/gemm_stream_pair.h"
+#include "cutlass/gemm/threadblock_swizzle.h"
+#include "cutlass/gemm/linear_scaling.h"
+#include "cutlass/kernel_launch.h"
+#include "cutlass/tensor_ref_collection.h"
+
+#include "cutlass/gemm/gemm_desc.h"
+
+#include "cutlass/gemm/volta884_multiplicand.h"
+#include "cutlass/gemm/mma_shared_stream.h"
+#include "cutlass/gemm/volta884_gemm_traits.h"
+
+#include "cutlass/gemm/volta884_complex_multiply_add.h"
+#include "cutlass/gemm/volta884_complex_global_stream.h"
+#include "cutlass/gemm/volta884_complex_shared_stream.h"
+#include "cutlass/gemm/volta884_complex_gemm_epilogue_traits.h"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Defines configuration for Volta884 GEMM
+template <
+ /// The layout for A.
+ MatrixLayout::Kind LayoutA,
+ /// Indicates matrix transform on multiplicand A
+ MatrixTransform::Kind TransformA,
+ /// The layout for B.
+ MatrixLayout::Kind LayoutB,
+ /// Indicates matrix transform on multiplicand B
+ MatrixTransform::Kind TransformB,
+ /// The tile size for the GEMM KxNxM.
+ typename OutputTile_,
+ /// Tile size for warp-level GEMM (K-by-N-by-M)
+ typename WarpGemmShape_,
+ /// The accumulator type.
+ typename Accumulator_,
+  /// The source matrix type.
+ typename ScalarC_,
+ /// The destination matrix type
+ typename ScalarD_,
+ /// Number of stages in shared memory
+ int StageCount,
+ /// Enables or disables launch bounds
+ bool LaunchBounds>
+struct Volta884ComplexGemmConfig : public GemmConfig<
+ /// The scalar type for A.
+ half,
+ /// The scalar type for B.
+ half,
+ /// The scalar type for C.
+ ScalarC_,
+ /// The scalar type for D.
+ ScalarD_,
+ /// The threadblock tile size
+ OutputTile_,
+ /// The functor to do the math in the main loop.
+ Volta884ComplexMultiplyAdd,
+ /// The number of scalars per LDG for A.
+ 8,
+ /// The number of scalars per STS for A.
+ 8,
+ /// The number of scalars per LDS for A.
+ 8,
+ /// The number of scalars per LDG for B.
+ 8,
+ /// The number of scalars per STS for B.
+ 8,
+ /// The number of scalars per LDS for B.
+ 8,
+ /// The number of scalars per LDG for C and STG for D.
+ 16 / int(sizeof(ScalarD_)),
+ /// The number of scalars per STS for D.
+ 16 / int(sizeof(ScalarD_)),
+ /// The number of scalars per LDS for D.
+ 16 / int(sizeof(ScalarD_)),
+ /// The number of stages in shared memory.
+ StageCount,
+ /// If true, separate mainloop is instantiated
+ true,
+ /// If true, compute residue in prolog
+ false,
+ /// Launch bounds not used
+ LaunchBounds> {};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Defines components of Volta884 GEMM
+template <
+ /// The layout for A.
+ MatrixLayout::Kind LayoutA,
+ /// Indicates matrix transform on multiplicand A
+ MatrixTransform::Kind TransformA,
+ /// The layout for B.
+ MatrixLayout::Kind LayoutB,
+ /// Indicates matrix transform on multiplicand B
+ MatrixTransform::Kind TransformB,
+ /// The tile size for the GEMM KxNxM.
+ typename OutputTile_,
+ /// Tile size for warp-level GEMM (K-by-N-by-M)
+ typename WarpGemmShape_,
+ /// The accumulator type.
+ typename Accumulator_,
+  /// The input matrix type.
+  typename ScalarC_,
+  /// The output matrix type.
+ typename ScalarD_,
+ /// Number of buffers in shared memory to use
+ int StageCount,
+ /// The functor to do the math in the epilogue.
+ typename EpilogueFunctor_ = SplitComplexLinearScaling,
+ /// Enables or disables launch bounds
+ bool LaunchBounds = false
+>
+struct Volta884ComplexGemmTraits {
+
+  /// Type alias for this traits class, used to instantiate the device-side mainloop.
+ typedef Volta884ComplexGemmTraits<
+ LayoutA,
+ TransformA,
+ LayoutB,
+ TransformB,
+ OutputTile_,
+ WarpGemmShape_,
+ Accumulator_,
+ ScalarC_,
+ ScalarD_,
+ StageCount,
+ EpilogueFunctor_,
+ LaunchBounds> This;
+
+ /// The actual device-side GEMM
+  typedef GemmMainloop<This> KernelClass;
+
+ /// Layout of multiplicand A matrix
+ static MatrixLayout::Kind const kLayoutA = LayoutA;
+
+ /// If true, A operand is conjugated
+ static MatrixTransform::Kind const kTransformA = TransformA;
+
+ /// Layout of multiplicand B matrix
+ static MatrixLayout::Kind const kLayoutB = LayoutB;
+
+ /// If true, B operand is conjugated
+ static MatrixTransform::Kind const kTransformB = TransformB;
+
+ /// Dimensions of threadblock tile (concept Shape)
+ typedef OutputTile_ OutputTile;
+
+ /// Shape of warp-level accumulators
+ typedef WarpGemmShape_ WarpGemmShape;
+
+ /// Multiplicand A scalar type
+ typedef half ScalarA;
+
+ /// Multiplicand B scalar type
+ typedef half ScalarB;
+
+ /// Data type of internal accumulator
+ typedef Accumulator_ Accumulator;
+
+ /// Data type of input accumulator matrix operand
+ typedef ScalarC_ ScalarC;
+
+ /// Data type of output accumulator matrix operand
+ typedef ScalarD_ ScalarD;
+
+ /// Shape of individual mma.sync instruction
+ typedef Shape<4, 16, 16> InstructionShape;
+
+ /// Tile size for an individual warp-level multiply-add
+ typedef Shape WarpTile;
+
+ /// Defines properties about GEMM needed by host code
+ typedef Volta884ComplexGemmConfig<
+ kLayoutA,
+ kTransformA,
+ kLayoutB,
+ kTransformB,
+ OutputTile,
+ WarpGemmShape,
+ Accumulator,
+ ScalarC,
+ ScalarD,
+ StageCount,
+ LaunchBounds>
+ GemmConfig;
+
+ //
+ // Derived types
+ //
+
+ /// Index type
+ typedef int Index;
+
+ /// Long index type
+ typedef long long LongIndex;
+
+ /// Partitioning of threadblock into warps
+ typedef typename ShapeDiv::Shape WarpDelta;
+
+ /// Number of warps per threadblock
+  static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
+
+ /// Defines iterators for A matrix
+ typedef Volta884Multiplicand
+ MultiplicandA;
+
+ /// Defines iterators for B matrix
+ typedef Volta884Multiplicand
+ MultiplicandB;
+
+ //
+ // GemmTraits mandatory type definitions
+ //
+
+ /// Maps hardware threadblocks to logical partitions of the GEMM
+ typedef IdentityBlockSwizzle BlockSwizzle;
+
+ /// Clears accumulators
+ typedef ClearAccumulators ClearAccumulators;
+
+ /// Loads multiplicands from global memory
+ typedef GlobalLoadStreamPair<
+ Volta884ComplexGlobalLoadStream,
+ typename MultiplicandA::StoreIterator,
+ StageCount>,
+ Volta884ComplexGlobalLoadStream,
+ typename MultiplicandB::StoreIterator,
+ StageCount>,
+ GemmConfig::kResidueInProlog >
+ GlobalLoadStream;
+
+ /// Memory needed to store the threadblock-scoped GEMM tile
+ typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage;
+
+ /// Shared memory storage for mainloop phase
+ union MainLoopStorage {
+
+ /// Stores the threadblock tile
+ ThreadblockTileStorage threadblock_tile;
+
+ /// Storage for GEMM global stream
+ typename GlobalLoadStream::SharedStorage global_to_shared_stream;
+ };
+
+ /// Loads multiplicands from shared memory
+ typedef SharedStreamPair<
+ Volta884ComplexSharedLoadStream,
+ StageCount>,
+ Volta884ComplexSharedLoadStream,
+ StageCount> >
+ SharedStream;
+
+ // Multiply-add object specialized for Volta mma.sync
+ typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
+
+ #if 0
+ /// Naive epilogue for updating the output matrix
+ typedef Volta884ComplexNaiveEpilogue
+ Epilogue;
+
+ #else
+
+ /// Efficient epilogue
+ typedef MMAEpilogue<
+ Volta884ComplexGemmEpilogueTraits
+ > Epilogue;
+
+ #endif
+
+ /// Tensor reference to A multiplicand
+ typedef ZipTensorRef<
+ TensorRef,
+ TensorRef
+ > TensorRefA;
+
+ /// Tensor reference to B multiplicand
+ typedef ZipTensorRef<
+ TensorRef,
+ TensorRef
+ > TensorRefB;
+
+ /// Tensor reference to C multiplicand
+ typedef ZipTensorRef<
+ TensorRef,
+ TensorRef
+ > TensorRefC;
+
+ /// Tensor reference to D multiplicand
+ typedef ZipTensorRef<
+ TensorRef,
+ TensorRef
+ > TensorRefD;
+
+ /// gemm::ProblemDesc<>
+ typedef GemmDesc<
+ TensorRefA,
+ TensorRefB,
+ TensorRefC,
+ TensorRefD,
+ float
+ > GemmDesc;
+
+ /// Parameters structure
+ struct Params : public KernelLaunchConfiguration {
+ /// The dimensions of the GEMM.
+ GemmCoord problem_size;
+
+ /// PartitionK_range
+ int partitionK_range;
+
+ /// The params for the global load stream
+ typename GlobalLoadStream::Params global_to_shared_stream;
+
+ /// The params for the shared load stream
+ typename SharedStream::Params shared_stream;
+
+ /// The params for the epilogue.
+ typename Epilogue::Params epilogue;
+
+ //
+ // Methods
+ //
+
+ CUTLASS_HOST_DEVICE
+ Params() {}
+
+ /// Initialize the Params struct
+ CUTLASS_HOST_DEVICE int initialize(
+ Index m,
+ Index n,
+ Index k,
+ platform::complex alpha,
+ ScalarA const* real_A,
+ Index real_lda,
+ ScalarA const* imag_A,
+ Index imag_lda,
+ ScalarB const* real_B,
+ Index real_ldb,
+ ScalarB const* imag_B,
+ Index imag_ldb,
+ platform::complex beta,
+ ScalarC const* real_C,
+ Index real_ldc,
+ ScalarC const* imag_C,
+ Index imag_ldc,
+ ScalarD* real_D,
+ Index real_ldd,
+ ScalarD* imag_D,
+ Index imag_ldd) {
+
+ problem_size = make_Coord(k, n, m, 1);
+
+ partitionK_range = problem_size.k();
+
+ // Compute grid dimensions
+ BlockSwizzle block_swizzle;
+ this->block = dim3(GemmConfig::kThreads);
+ this->grid = block_swizzle.get_grid_layout(
+ problem_size,
+        make_Coord_from_shape<OutputTile>());
+
+ // Initialize global load streams
+ global_to_shared_stream.stream_a.initialize(
+ make_ZipTensorRef(
+ TensorRefBatchStrided(TensorRef(real_A, real_lda), 0),
+ TensorRefBatchStrided(TensorRef(imag_A, imag_lda), 0)
+ ),
+ 0
+ );
+
+ global_to_shared_stream.stream_b.initialize(
+ make_ZipTensorRef(
+ TensorRefBatchStrided(TensorRef(real_B, real_ldb), 0),
+ TensorRefBatchStrided(TensorRef(imag_B, imag_ldb), 0)
+ ),
+ 0
+ );
+
+ return epilogue.initialize(
+ alpha,
+ beta,
+ real_C,
+ real_ldc,
+ imag_C,
+ imag_ldc,
+ real_D,
+ real_ldd,
+ imag_D,
+ imag_ldd
+ );
+ }
+
+ /// Initialize the Params struct
+ CUTLASS_HOST_DEVICE int initialize(
+ Index m,
+ Index n,
+ Index k,
+ platform::complex alpha,
+ ScalarA const* real_A,
+ Index real_lda,
+ LongIndex batch_stride_A_real,
+ ScalarA const* imag_A,
+ Index imag_lda,
+ LongIndex batch_stride_A_imag,
+ ScalarB const* real_B,
+ Index real_ldb,
+ LongIndex batch_stride_B_real,
+ ScalarB const* imag_B,
+ Index imag_ldb,
+ LongIndex batch_stride_B_imag,
+ platform::complex beta,
+ ScalarC const* real_C,
+ Index real_ldc,
+ LongIndex batch_stride_C_real,
+ ScalarC const* imag_C,
+ Index imag_ldc,
+ LongIndex batch_stride_C_imag,
+ ScalarD* real_D,
+ Index real_ldd,
+ LongIndex batch_stride_D_real,
+ ScalarD* imag_D,
+ Index imag_ldd,
+ LongIndex batch_stride_D_imag,
+ int batch_count) {
+
+ problem_size = make_Coord(k, n, m, batch_count);
+ partitionK_range = problem_size.k();
+
+ // Compute grid dimensions
+ BlockSwizzle block_swizzle;
+ this->block = dim3(GemmConfig::kThreads);
+ this->grid = block_swizzle.get_grid_layout(
+ problem_size,
+        make_Coord_from_shape<OutputTile>());
+
+ // Initialize global load streams
+ global_to_shared_stream.stream_a.initialize(
+ make_ZipTensorRef(
+ TensorRefBatchStrided(TensorRef(real_A, real_lda), batch_stride_A_real),
+ TensorRefBatchStrided(TensorRef(imag_A, imag_lda), batch_stride_A_imag)
+ ),
+ 0
+ );
+
+ global_to_shared_stream.stream_b.initialize(
+ make_ZipTensorRef(
+ TensorRefBatchStrided(TensorRef(real_B, real_ldb), batch_stride_B_real),
+ TensorRefBatchStrided(TensorRef(imag_B, imag_ldb), batch_stride_B_imag)
+ ),
+ 0
+ );
+
+ return epilogue.initialize(
+ alpha,
+ beta,
+ real_C,
+ real_ldc,
+ batch_stride_C_real,
+ imag_C,
+ imag_ldc,
+ batch_stride_C_imag,
+ real_D,
+ real_ldd,
+ batch_stride_D_real,
+ imag_D,
+ imag_ldd,
+ batch_stride_D_imag
+ );
+ }
+ };
+
+ /// Shared memory storage
+ union SharedStorage {
+ /// Storage required during mainloop phase
+ MainLoopStorage main_loop;
+
+ /// Shared storage needed for epilogue
+ typename Epilogue::SharedStorage epilogue;
+ };
+
+ /// The memory fence for shared loads.
+ static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
+ if (StageCount < 2) {
+ __syncthreads();
+ }
+ }
+
+ /// The memory fence for shared stores.
+ static CUTLASS_DEVICE void shared_store_fence(bool in_loop) { __syncthreads(); }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
+
+// clang-format on
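
In the `Params::initialize` methods above, the problem size is packed in K-by-N-by-M order (`make_Coord(k, n, m, batch_count)`) and the block swizzle derives the launch grid from the M and N extents. A small sketch of that computation, assuming an illustrative 128x128x32 threadblock tile and, for this sketch only, mapping batches to `grid.z`:

```cpp
// Sketch of how the launch shape falls out of the traits above. The 128x128x32
// threadblock tile is an assumed example shape, not something mandated by this file.
#include <cstdio>

struct TileShape { int k, n, m; };

int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  TileShape output_tile{32, 128, 128};       // threadblock tile, K-by-N-by-M
  int m = 1024, n = 512, k = 4096, batch_count = 3;

  // Identity-style swizzle: one threadblock per (M, N) tile; batches along grid.z here.
  int grid_x = ceil_div(m, output_tile.m);
  int grid_y = ceil_div(n, output_tile.n);
  int grid_z = batch_count;

  std::printf("problem (k, n, m, batch) = (%d, %d, %d, %d)\n", k, n, m, batch_count);
  std::printf("grid = (%d, %d, %d) threadblocks\n", grid_x, grid_y, grid_z);
  return 0;
}
```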
diff --git a/cutlass/gemm/volta884_complex_global_stream.h b/cutlass/gemm/volta884_complex_global_stream.h
new file mode 100644
index 0000000000..7e3a92cb2f
--- /dev/null
+++ b/cutlass/gemm/volta884_complex_global_stream.h
@@ -0,0 +1,315 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Implements efficient loading of the thread block-level tile from global memory and
+      storing to shared memory.
+*/
+
+#pragma once
+
+// clang-format off
+
+#include "cutlass/convert.h"
+#include "cutlass/zip_tile_iterator.h"
+#include "cutlass/zip_tensor_ref.h"
+#include "cutlass/gemm/gemm_operand.h"
+#include "cutlass/predicate_vector.h"
+#include "cutlass/util/pair.h"
+
+#include "cutlass/gemm/mma_global_stream.h"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+///! Stream adapter for loading threadblock-scoped GEMM tiles and storing to shared memory
+template <
+ /// Identifies multiplicand
+ GemmOperand::Kind Operand,
+ /// Layout of source matrix in global memory
+ MatrixLayout::Kind Layout,
+ /// Iterator for loading threadblock-scoped tiles
+ typename LoadIterator_,
+ /// Transformation functor for transforming fragments
+ typename Transformer_,
+ /// Iterator for storing threadblock-scoped tiles to shared memory
+ typename StoreIterator_,
+ /// Number of stores before iterator wraps - zero indicates no wrapping
+ int StageCount>
+struct Volta884ComplexGlobalLoadStream {
+
+ //
+ // Type definitions
+ //
+
+ /// Identifies the operand
+ static GemmOperand::Kind const kOperand = Operand;
+
+ /// The layout.
+ static MatrixLayout::Kind const kLayout = Layout;
+
+ /// Load-store stream for real-valued matrices
+ typedef MMAGlobalLoadStream RealLoadStoreStream;
+
+ /// Loads a pair of real-valued fragments
+  typedef ZipTileIterator<LoadIterator_, LoadIterator_> LoadIterator;
+
+  /// Zips a pair of transformers
+  typedef ZipConvert<Transformer_, Transformer_> Transformer;
+
+  /// Stores a pair of real-valued fragments
+  typedef ZipTileIterator<StoreIterator_, StoreIterator_> StoreIterator;
+
+ /// Number of stages
+ static int const kStageCount = StageCount;
+
+ /// Predicate vector
+ typedef typename RealLoadStoreStream::PredicateVector PredicateVector;
+
+ /// The fragment that is copied from shared memory.
+ typedef typename LoadIterator::Fragment FetchedFragment;
+ /// The fragment that is obtained after the transformation by the transformer.
+ typedef typename Transformer::OutputFragment TransformedFragment;
+ /// Make sure the fragments match.
+  static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
+                "");
+  /// The output fragment.
+  typedef TransformedFragment Fragment;
+  /// Make sure the transformed fragment is the same as the store fragment.
+  static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
+                "");
+
+ /// Index type
+ typedef typename RealLoadStoreStream::Index Index;
+
+ /// Long index type
+ typedef typename RealLoadStoreStream::LongIndex LongIndex;
+
+ /// The params.
+ struct Params {
+
+ //
+ // Type definitions
+ //
+
+ /// Matrix reference
+ typedef ZipTensorRef<
+ TensorRefBatchStrided,
+ TensorRefBatchStrided > SourceTensorRef;
+
+ /// Helper
+ static int const kElementsPerLdg = LoadIterator::First::Tile::kC;
+
+ //
+ // Data members
+ //
+
+    /// Batch strides for the real and imaginary planes of the source tensor
+    platform::Pair<LongIndex, LongIndex> batch_stride;
+
+ // The load iterator.
+ typename LoadIterator::Params load_iterator;
+
+ // Offset to residue.
+ Index offset_to_residue;
+ //
+ // Methods
+ //
+
+ CUTLASS_HOST_DEVICE
+ Params() {}
+
+ ///
+ CUTLASS_HOST_DEVICE
+ Params(SourceTensorRef const &ref, Index _offset_to_residue) {
+ initialize(ref, _offset_to_residue);
+ }
+
+ CUTLASS_HOST_DEVICE
+ int initialize(SourceTensorRef const &ref, Index _offset_to_residue) {
+
+ batch_stride.first = ref.first.tensor_stride;
+ batch_stride.second = ref.second.tensor_stride;
+
+ offset_to_residue = _offset_to_residue;
+ load_iterator.first.initialize(
+ TensorRef(
+ ref.first.at().data(),
+ make_Coord(ref.first.at().stride(0) * kElementsPerLdg, ref.first.at().stride(0), kElementsPerLdg)
+ )
+ );
+ load_iterator.second.initialize(
+ TensorRef(
+ ref.second.at().data(),
+ make_Coord(ref.second.at().stride(0) * kElementsPerLdg, ref.second.at().stride(0), kElementsPerLdg)
+ )
+ );
+ return 0;
+ }
+ };
+
+ /// Empty shared storage
+ struct SharedStorage {};
+
+ /// Shared memory allocation for the tile
+ typedef TileAllocation<
+ typename RealLoadStoreStream::StoreIterator::Scalar,
+ typename ShapeMul<
+ typename RealLoadStoreStream::StoreIterator::OperandShape,
+ Shape
+ >::Shape
+ > RealThreadblockTileStorage;
+
+ /// Threadblock tile allocation
+ typedef ZipTileAllocation<
+ RealThreadblockTileStorage,
+ RealThreadblockTileStorage
+ > ThreadblockTileStorage;
+
+ /// Reference to ThreadblockTileStorage
+ typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;
+
+ //
+ // Data members
+ //
+
+ ///! The parameters
+ Params params;
+
+ ///! Dimensions of global memory tile
+ Coord<3> threadblock_offset;
+
+ ///! Multiplicand bounds
+ Coord<3> multiplicand_bounds;
+
+ ///! Iterator to load threadblock tiles from global memory
+ LoadIterator load_iterator;
+
+ ///! Predicate vector
+ PredicateVector predicates;
+
+ ///! The fragment to fetch from shared memory.
+ FetchedFragment fetched_fragment;
+
+ ///! Functor to transform fragments after they have been loaded
+ Transformer transformer;
+
+ ///! The fragment to convert the data after it has been fetched from shared memory.
+ TransformedFragment transformed_fragment;
+
+ ///! Iterator to store threadblock tiles to shared memory
+ StoreIterator store_iterator;
+
+ ///! Counter
+ int stage_index;
+
+ //
+ // Methods
+ //
+
+ /// Constructor
+ CUTLASS_DEVICE Volta884ComplexGlobalLoadStream(Params const &_params,
+ SharedStorage &shared_storage,
+ ThreadblockTileRef const &threadblock_tile_ref,
+ Coord<3> const bounds,
+ Coord<3> const &block)
+ : params(_params),
+ threadblock_offset(RealLoadStoreStream::project_coordinate(block)),
+ multiplicand_bounds(RealLoadStoreStream::project_coordinate(bounds, 1)),
+ load_iterator(params.load_iterator, threadblock_offset),
+ transformer(),
+ store_iterator(threadblock_tile_ref),
+ stage_index(0) {
+
+ // initialize predicates used to guard loads
+ load_iterator.initialize_predicates(
+ predicates.begin(), multiplicand_bounds, threadblock_offset);
+ }
+
+ /// Loads the data from global memory
+ CUTLASS_DEVICE void copy() {
+ load_iterator.load_post_increment(fetched_fragment, predicates.begin());
+ }
+
+ /// Transform and commit the data to shared memory
+ CUTLASS_DEVICE void commit() {
+
+ transformer.transform(fetched_fragment, transformed_fragment);
+ store_iterator.store_post_increment(transformed_fragment);
+
+ ++stage_index;
+ if (kStageCount && stage_index == kStageCount) {
+ store_iterator -= kStageCount;
+ stage_index = 0;
+ }
+ }
+
+ /// Computes a predicate mask for loads during final threadblock tile load iteration
+ CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
+ // That's the residue!
+ Coord<3> _block_offset = threadblock_offset;
+ if (kOperand == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor) {
+ // K-strided
+ _block_offset =
+ make_Coord(threadblock_offset[0], multiplicand_bounds[1] - k, threadblock_offset[2]);
+ } else {
+ // K-contiguous
+ _block_offset = make_Coord(threadblock_offset[0],
+ threadblock_offset[1],
+ multiplicand_bounds[2] - k / LoadIterator::First::Tile::kC);
+ }
+
+ load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, _block_offset);
+ fetched_fragment.clear();
+ }
+
+ CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {}
+
+ CUTLASS_DEVICE void rollback() {}
+
+ /// Adds a Coord<3> to the underlying global load iterator
+ CUTLASS_DEVICE Volta884ComplexGlobalLoadStream &operator+=(Coord<3> const &offset) {
+ load_iterator += offset;
+ return *this;
+ }
+
+ /// Adds an offset based on batch stride
+ CUTLASS_DEVICE Volta884ComplexGlobalLoadStream &add_batch_offset(int batch_id) {
+ load_iterator.first.add_pointer_offset(params.batch_stride.first * batch_id);
+ load_iterator.second.add_pointer_offset(params.batch_stride.second * batch_id);
+ return *this;
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
+
+// clang-format on
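
`residue()` above adjusts the predicate origin differently depending on whether the operand's K dimension is strided or contiguous in memory, which it encodes as `(operand == A) XOR (layout == row-major)`. The sketch below enumerates that decision with local stand-in enums (not the CUTLASS enumerations):

```cpp
// Small sketch of the K-strided vs. K-contiguous decision used by residue() above.
// Operand and Layout are illustrative stand-ins for the CUTLASS enumerations.
#include <cstdio>

enum class Operand { kA, kB };
enum class Layout { kColumnMajor, kRowMajor };

bool is_k_strided(Operand op, Layout layout) {
  // Mirrors: kOperand == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor
  return (op == Operand::kA) ^ (layout == Layout::kRowMajor);
}

int main() {
  const Operand ops[] = {Operand::kA, Operand::kB};
  const Layout layouts[] = {Layout::kColumnMajor, Layout::kRowMajor};
  for (Operand op : ops)
    for (Layout layout : layouts)
      std::printf("%s, %s -> %s\n",
                  op == Operand::kA ? "A" : "B",
                  layout == Layout::kColumnMajor ? "column-major" : "row-major",
                  is_k_strided(op, layout) ? "K-strided" : "K-contiguous");
  return 0;
}
```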
diff --git a/cutlass/gemm/volta884_complex_multiply_add.h b/cutlass/gemm/volta884_complex_multiply_add.h
new file mode 100644
index 0000000000..5120c77b4a
--- /dev/null
+++ b/cutlass/gemm/volta884_complex_multiply_add.h
@@ -0,0 +1,319 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Implements warp-level multiply-accumulate operations using Volta's mma.sync instruction
+ for complex-valued data types.
+*/
+
+#pragma once
+
+#include "cutlass/util/complex.h"
+#include "cutlass/zip_fragment.h"
+#include "cutlass/gemm/volta884_multiply_add.h"
+#include "cutlass/zip_fragment.h"
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+ /// Shape of a warp-level GEMM (K-by-N-by-M)
+ typename WarpGemmShape_,
+ /// Layout of multiplicand A
+ MatrixLayout::Kind LayoutA,
+ /// Indicates matrix transform on multiplicand A
+ MatrixTransform::Kind TransformA,
+ /// Data type of multiplicand A
+ typename ScalarA_,
+ /// Layout of multiplicand B
+ MatrixLayout::Kind LayoutB,
+ /// Indicates matrix transform on multiplicand B
+ MatrixTransform::Kind TransformB,
+ /// Data type of multiplicand B
+ typename ScalarB_,
+ /// Data type of accumulators
+ typename ScalarC_,
+ /// If true, A operand is conjugated
+ bool ConjugateA = false,
+ /// If true, B operand is conjugated
+ bool ConjugateB = false,
+ /// If true, infinite results are saturated to +-MAX_FLOAT
+ bool SatFinite = false>
+struct Volta884ComplexMultiplyAdd {
+ //
+ // Constant and type definitions
+ //
+
+ /// Shape of a warp-level GEMM (K-by-N-by-M)
+ typedef WarpGemmShape_ WarpGemmShape;
+
+ /// Shape of a warp-level GEMM (K-by-N-by-M)
+ typedef WarpGemmShape_ AccumulatorsPerWarp;
+
+ /// Most of the Volta884 code assumes interleaved 32x32 tiles
+ typedef Shape<4, 32, 32> InterleavedTileShape;
+
+ /// Shape of an individual warp-wide mma.sync instruction
+ typedef Shape<4, 16, 16> InstructionShape;
+
+ /// Shape of a warp-level matrix multiply operation
+ typedef Shape WarpTile;
+
+ /// Verify WarpTile is a multiple of fundamental 32x32 interleaved tile
+ static_assert(!(WarpTile::kH % InterleavedTileShape::kH) &&
+ !(WarpTile::kW % InterleavedTileShape::kW) && WarpTile::kD == 4,
+ "WarpTile must be a multiple of InterleavedTileShape.");
+
+ /// Layout of A multiplicand
+ static MatrixLayout::Kind const kLayoutA = LayoutA;
+
+  /// Indicates matrix transform on multiplicand A
+ static MatrixTransform::Kind const kTransformA = TransformA;
+
+ /// Layout of B multiplicand
+ static MatrixLayout::Kind const kLayoutB = LayoutB;
+
+ /// Indicates matrix transform on multiplicand B
+ static MatrixTransform::Kind const kTransformB = TransformB;
+
+ /// The type for A.
+ typedef ScalarA_ ScalarA;
+ /// The type for B.
+ typedef ScalarB_ ScalarB;
+ /// The type for C and D.
+ typedef ScalarC_ ScalarC;
+
+ /// If true, infinite results are saturated to +-MAX_FLOAT
+ static bool const kSatFinite = SatFinite;
+
+  /// Hard-coded compute type supported on Volta
+ static arch::ComputeType::Kind const kComputeType = arch::ComputeType::kDefault;
+
+ /// Underlying matrix multiply-add operator
+ typedef Volta884MultiplyAdd
+ RealMultiplyAdd;
+
+ /// Fragment definition for A multiplicand
+  typedef ZipFragment<typename RealMultiplyAdd::FragmentA, typename RealMultiplyAdd::FragmentA>
+    FragmentA;
+
+  /// Fragment definition for B multiplicand
+  typedef ZipFragment<typename RealMultiplyAdd::FragmentB, typename RealMultiplyAdd::FragmentB>
+    FragmentB;
+
+  /// Fragment definition for accumulators
+  typedef ZipFragment<typename RealMultiplyAdd::Accumulators, typename RealMultiplyAdd::Accumulators>
+    Accumulators;
+
+ /// Number of mma.sync operations performed. See Volta884MultiplyAdd::Iterations for details.
+ typedef typename RealMultiplyAdd::Iterations Iterations;
+
+ //
+ // Methods
+ //
+
+ /// Ctor.
+ CUTLASS_DEVICE Volta884ComplexMultiplyAdd() {}
+
+  /// Complex-valued multiply-accumulate: D = A * B + C.
+ CUTLASS_DEVICE void multiply_add(FragmentA const& A,
+ FragmentB const& B,
+ Accumulators const& C,
+ Accumulators& D) {
+ RealMultiplyAdd op;
+
+ // complex-valued multiply-add
+ op.multiply_add(A.first, B.first, C.first, D.first);
+ op.multiply_add(A.first, B.second, C.second, D.second, kTransformB == MatrixTransform::kConjugate);
+ op.multiply_add(A.second, B.first, C.second, D.second, kTransformA == MatrixTransform::kConjugate);
+ op.multiply_add(A.second, B.second, C.first, D.first,
+ !((kTransformA == MatrixTransform::kConjugate) ^ (kTransformB == MatrixTransform::kConjugate)));
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Complex-valued epilogue
+template
+struct Volta884ComplexNaiveEpilogue {
+ /// Accumulator data type
+ typedef Accumulator ScalarC;
+
+ /// Output accumulator type
+ typedef Accumulator ScalarD;
+
+ /// BLAS Scalar type
+ typedef Accumulator Scalar;
+
+ /// Real-valued epilogue
+ typedef Volta884NaiveEpilogue RealEpilogue;
+
+ /// Params object
+ struct Params {
+ /// Parameters for the real-valued part
+ typename RealEpilogue::Params real;
+
+ /// Parameters for the imaginary-valued part
+ typename RealEpilogue::Params imag;
+
+ //
+ // Methods
+ //
+
+ /// Default constructor
+ CUTLASS_HOST_DEVICE Params() {}
+
+ /// Constructs from params object
+ CUTLASS_HOST_DEVICE Params(typename RealEpilogue::Params const& _real,
+ typename RealEpilogue::Params const& _imag)
+ : real(_real), imag(_imag) {}
+
+ /// Construct from pointers
+ CUTLASS_HOST_DEVICE Params(ScalarC* _real, int _ldr, ScalarC* _imag, int _ldi)
+ : real(_real, _ldr), imag(_imag, _ldi) {}
+
+ /// Construct from pointers
+ CUTLASS_HOST_DEVICE Params(
+ platform::complex const &alpha,
+ platform::complex const &beta,
+ ScalarC const *real_C,
+ int real_ldc,
+ ScalarC const *imag_C,
+ int imag_ldc,
+ ScalarD *real_D,
+ int real_ldd,
+ ScalarD *imag_D,
+ int imag_ldd
+ ):
+ real(real_D, real_ldd, alpha.real(), beta.real()),
+ imag(imag_D, imag_ldd, alpha.real(), beta.real()) { }
+
+ /// Initializer method
+ CUTLASS_HOST_DEVICE
+ int initialize(
+ platform::complex const &alpha,
+ platform::complex const &beta,
+ ScalarC const *real_C,
+ int real_ldc,
+ ScalarC const *imag_C,
+ int imag_ldc,
+ ScalarD *real_D,
+ int real_ldd,
+ ScalarD *imag_D,
+ int imag_ldd
+ ) {
+
+ real = typename RealEpilogue::Params(real_D, real_ldd, alpha.real(), beta.real());
+ imag = typename RealEpilogue::Params(imag_D, imag_ldd, alpha.real(), beta.real());
+
+ return 0;
+ }
+ };
+
+  /// Shared storage
+ struct SharedStorage {};
+
+ /// Accumulator fragment definition
+ typedef ZipFragment<
+ typename RealEpilogue::Accumulators,
+ typename RealEpilogue::Accumulators> Accumulators;
+
+ //
+ // Data members
+ //
+
+ /// Epilogue for real part
+ RealEpilogue real;
+
+ /// Epilogue for imaginary part
+ RealEpilogue imag;
+
+ //
+ // Methods
+ //
+
+ /// Constructs a complex-valued epilogue
+ CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(
+ Params const& _params, Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
+ : real(_params.real, _problem_size), imag(_params.imag, _problem_size) {}
+
+ /// Constructs a complex-valued epilogue
+ CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(ScalarC* _real,
+ int _ldr,
+ ScalarC* _imag,
+ int _ldi,
+ Coord<3> const& _problem_size = make_Coord(1024,
+ 1024,
+ 1024))
+ : real(_real, _ldr, _problem_size), imag(_imag, _ldi, _problem_size) {}
+
+ /// Constructs a complex-valued epilogue
+ CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(Params const& _params,
+ SharedStorage& shared_storage,
+ Coord<3> const& _problem_size = make_Coord(1024,
+ 1024,
+ 1024))
+ : real(_params.real, _problem_size), imag(_params.imag, _problem_size) {}
+
+ /// Sets accumulators to zero
+ CUTLASS_DEVICE void clear(Accumulators& C) {
+ C.first.clear();
+ C.second.clear();
+ }
+
+ /// Naive load operation for debugging
+ CUTLASS_DEVICE void load(Accumulators& C,
+ Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
+ real.load(C.first, threadblock_offset);
+ imag.load(C.second, threadblock_offset);
+ }
+
+ /// Naive store operation for debugging
+ CUTLASS_DEVICE void store(Accumulators const& C,
+ Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
+ real.store(C.first, threadblock_offset);
+ imag.store(C.second, threadblock_offset);
+ }
+
+ /// CUTLASS Epilogue interface
+ CUTLASS_DEVICE void epilogue(Accumulators const& C,
+ Coord<3> const& threadblock_offset = make_Coord(0, 0, 0),
+ int batch_id = 0) {
+ real.store(C.first, threadblock_offset);
+ imag.store(C.second, threadblock_offset);
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
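
`Volta884ComplexMultiplyAdd::multiply_add` above issues four real multiply-accumulates per complex product, negating the terms that involve an imaginary part whenever the corresponding operand is conjugated, and negating the `a.imag * b.imag` contribution to the real part unless exactly one operand is conjugated. A scalar reference of the same decomposition, checked against `std::complex`; the helper names are illustrative only:

```cpp
// Scalar reference for the four real multiply-accumulate terms issued by the
// complex multiply-add above. The `negate` flag plays the role of the final
// argument of the underlying real multiply-add.
#include <complex>
#include <cstdio>

// d += (negate ? -a * b : a * b), the scalar analogue of one real mma term.
void real_mma(float a, float b, float &d, bool negate) { d += negate ? -a * b : a * b; }

std::complex<float> complex_mma(std::complex<float> a, std::complex<float> b,
                                std::complex<float> c, bool conj_a, bool conj_b) {
  float d_real = c.real(), d_imag = c.imag();
  real_mma(a.real(), b.real(), d_real, false);
  real_mma(a.real(), b.imag(), d_imag, conj_b);
  real_mma(a.imag(), b.real(), d_imag, conj_a);
  real_mma(a.imag(), b.imag(), d_real, !(conj_a ^ conj_b));
  return {d_real, d_imag};
}

int main() {
  std::complex<float> a(1, 2), b(3, -4), c(0.5f, 0.25f);
  std::complex<float> got = complex_mma(a, b, c, /*conj_a=*/false, /*conj_b=*/true);
  std::complex<float> want = a * std::conj(b) + c;   // library reference
  std::printf("got  %f + %fi\nwant %f + %fi\n",
              got.real(), got.imag(), want.real(), want.imag());
  return 0;
}
```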
diff --git a/cutlass/gemm/volta884_complex_shared_stream.h b/cutlass/gemm/volta884_complex_shared_stream.h
new file mode 100644
index 0000000000..17d2afc008
--- /dev/null
+++ b/cutlass/gemm/volta884_complex_shared_stream.h
@@ -0,0 +1,152 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Implements efficient streaming of complex-valued warp-level fragments from shared
+    memory for mma.sync-based matrix multiply-accumulate.
+*/
+
+#pragma once
+
+#include "cutlass/convert.h"
+#include "cutlass/zip_fragment.h"
+#include "cutlass/zip_tensor_ref.h"
+#include "cutlass/zip_tile_iterator.h"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Stream from shared memory to fragments for warp-level matrix multiply-accumulate
+template <
+ /// The load iterator.
+ typename Iterator_,
+ /// The transformer to be applied after the data has been copied from shared memory.
+    typename Transformer_ = Copy<typename Iterator_::Fragment>,
+ /// Number of increments before iterator wraps - zero indicates no wrapping
+ int StageCount = 1>
+struct Volta884ComplexSharedLoadStream {
+ /// The load iterator.
+ typedef Iterator_ RealIterator;
+
+ /// Zips two real-valued iterators together
+  typedef ZipTileIterator<RealIterator, RealIterator> Iterator;
+
+  /// The transformer.
+  typedef Transformer_ RealTransformer;
+
+  /// Zips two transformers
+  typedef ZipConvert<RealTransformer, RealTransformer> Transformer;
+
+ /// Number of increments before iterator wraps - zero indicates no wrapping
+ static int const kStageCount = StageCount;
+
+ /// The fragment that is copied from shared memory.
+ typedef typename Iterator::Fragment FetchedFragment;
+
+ /// The fragment that is obtained after the transformation by the transformer.
+ typedef typename Transformer::OutputFragment TransformedFragment;
+
+ /// Make sure the fragments match.
+  static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
+                "");
+
+ /// The output fragment.
+ typedef TransformedFragment Fragment;
+
+ /// Reference type
+ typedef ZipTensorRef<
+ TensorRef,
+ TensorRef
+ > TensorRef;
+
+ /// Parameters passed from host
+ struct Params { };
+
+ //
+ // Data members
+ //
+
+ /// Iterator for loading fragments for warp-level matrix multiply-accumulate
+ Iterator iterator;
+
+ /// Fetched fragment
+ FetchedFragment fetched[2];
+
+ /// The transformer.
+ Transformer transformer;
+
+ /// Transformed fragment
+ TransformedFragment transformed[2];
+
+ /// Counts the number of stages
+ int stage_index;
+
+ //
+ // Methods
+ //
+
+ /// Ctor.
+ CUTLASS_DEVICE Volta884ComplexSharedLoadStream() : stage_index(0) {}
+
+ /// Ctor.
+ CUTLASS_DEVICE Volta884ComplexSharedLoadStream(Params const &_params,
+ TensorRef const &ref)
+ : iterator(ref), stage_index(0) {}
+
+ /// Load the data from shared memory to the fetch fragment.
+ CUTLASS_DEVICE void copy(int step) {
+ iterator.load(fetched[step % 2],
+ make_Coord(step + stage_index * Iterator::First::VectorizedShape::kD, 0, 0, 0));
+ }
+
+ /// Commit the data.
+ CUTLASS_DEVICE void commit(int step) {
+ transformer.transform(fetched[step % 2], transformed[step % 2]);
+ }
+
+ /// Gets the transformed fragment
+ CUTLASS_DEVICE
+ TransformedFragment &fragment(int step) { return transformed[step % 2]; }
+
+ /// Gets the transformed fragment
+ CUTLASS_DEVICE
+ TransformedFragment const &fragment(int step) const { return transformed[step % 2]; }
+
+ /// Increment the stage.
+ CUTLASS_DEVICE void inc_stage() {
+ ++stage_index;
+ if (kStageCount && stage_index == StageCount) {
+ stage_index = 0;
+ }
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
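
The complex-valued streams in this and the preceding headers are built from the same "zip" idiom: two identical real-valued components (iterators, transformers, fragments) driven in lockstep, one per plane. A toy sketch of the pattern; `ZipPair` and `RealLoader` are stand-ins, not the `ZipTileIterator`/`ZipConvert` templates:

```cpp
// Minimal sketch of the "zip" idiom used by the complex-valued streams above:
// a pair of identical real-valued components, one for the real plane and one
// for the imaginary plane, driven by a single call.
#include <cstdio>
#include <utility>
#include <vector>

struct RealLoader {
  const float *data;
  float load(int index) const { return data[index]; }
};

template <typename First, typename Second>
struct ZipPair {
  First first;
  Second second;
  // Apply the same operation to both halves and return the results as a pair.
  std::pair<float, float> load(int index) const {
    return {first.load(index), second.load(index)};
  }
};

int main() {
  std::vector<float> real_plane = {1.0f, 2.0f, 3.0f};
  std::vector<float> imag_plane = {-1.0f, -2.0f, -3.0f};
  ZipPair<RealLoader, RealLoader> zipped{{real_plane.data()}, {imag_plane.data()}};
  for (int i = 0; i < 3; ++i) {
    std::pair<float, float> v = zipped.load(i);
    std::printf("element %d: %g + %gi\n", i, v.first, v.second);
  }
  return 0;
}
```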
diff --git a/cutlass/gemm/volta884_gemm_epilogue_traits.h b/cutlass/gemm/volta884_gemm_epilogue_traits.h
new file mode 100644
index 0000000000..02b154204f
--- /dev/null
+++ b/cutlass/gemm/volta884_gemm_epilogue_traits.h
@@ -0,0 +1,771 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
+ with the computed matrix product.
+*/
+
+#pragma once
+
+// clang-format off
+
+#include "cutlass/tile_stream.h"
+#include "cutlass/tile_allocation.h"
+
+#include "cutlass/gemm/mma_shared_stream.h"
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Abstraction to select accumulators from an accumulator tile for each iteration of the epilogue
+template <typename WarpGemmShape_, typename WarpDelta_, typename Scalar_>
+struct Volta884SelectAccumulators;
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Selects accumulators from Volta mma.sync.F32 layout
+template <typename WarpGemmShape_, typename WarpDelta_>
+struct Volta884SelectAccumulators<WarpGemmShape_, WarpDelta_, float> {
+ /// Shape of the warp-level matrix multiply operation
+ typedef WarpGemmShape_ WarpGemmShape;
+
+ /// Describes tiling of warp elements
+ typedef WarpDelta_ WarpDelta;
+
+ /// Data type of scalar
+ typedef float Scalar;
+
+ //
+ // Derived types and constants
+ //
+
+ /// (Actual) number of accumulators held by each individual thread
+ static int const kAccumulatorsPerThread = (WarpGemmShape::kH * WarpGemmShape::kW) / kWarpSize;
+
+ /// Accumulators fragment
+  typedef Fragment<Scalar, kAccumulatorsPerThread> Accumulators;
+
+ /// Number of warps
+  static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
+
+ /// Interleaved mma.sync shape
+ typedef Shape<4, 32, 32> MmaTileShape;
+
+ /// Hard-coded for FP32 layouts
+ typedef Shape<1, WarpGemmShape::kW / MmaTileShape::kW, 4> Elements;
+
+ /// Number of elements
+  static int const kElements = ShapeCount<Elements>::kCount;
+
+ /// Slice of accumulators
+  typedef Fragment<Scalar, kElements> Fragment;
+
+ //
+ // Methods
+ //
+
+ /// Selects accumulators for a given iteration of the epilogue
+ CUTLASS_DEVICE
+ Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const {
+ Fragment frag;
+
+ static int const kAccumPerOp = 8;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Elements::kH; ++j) {
+
+ // selects the 32x32 tile
+ Coord<2> tile_32x32 = make_Coord(idx[0] / 8, j);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < Elements::kW; ++i) {
+ Coord<2> mma_op = make_Coord(((idx[0] >> 1) & 1), i / 2);
+
+ int element = ((i & 1) << 1) | (idx[0] & 1) | (idx[0] & 4);
+
+ int mma_op_idx = mma_op[1] + mma_op[0] * 2 + 4 * (tile_32x32[1] + 2 * tile_32x32[0]);
+
+ frag[i + j * Elements::kW] = accum[element + kAccumPerOp * mma_op_idx];
+ }
+ }
+
+ return frag;
+ }
+};
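+
+// Worked example of the F32 selection above (illustrative; assumes WarpGemmShape::kW == 64, so
+// Elements = Shape<1, 2, 4> and each epilogue iteration selects kElements = 8 accumulators):
+//
+//   idx = (0, 0):
+//     j = 0, i = 0..3  ->  element = {0, 2, 0, 2}, mma_op_idx = {0, 0, 1, 1}  ->  accum[{0, 2, 8, 10}]
+//     j = 1, i = 0..3  ->  same elements, mma_op_idx offset by 4              ->  accum[{32, 34, 40, 42}]
+//
+// Each iteration therefore gathers a strip of accumulators scattered across the interleaved
+// mma.sync quad-pair tiles rather than a contiguous range of the accumulator fragment.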
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Selects accumulators from Volta mma.sync.F16 layout
+template <typename WarpGemmShape_, typename WarpDelta_>
+struct Volta884SelectAccumulators<WarpGemmShape_, WarpDelta_, half> {
+ /// Shape of the warp-level matrix multiply operation
+ typedef WarpGemmShape_ WarpGemmShape;
+
+ /// Describes tiling of warp elements
+ typedef WarpDelta_ WarpDelta;
+
+ /// Data type of accumulator elements
+ typedef half Scalar;
+
+ //
+ // Derived types and constants
+ //
+
+ /// (Actual) number of accumulators held by each individual thread
+ static int const kAccumulatorsPerThread = (WarpGemmShape::kH * WarpGemmShape::kW) / kWarpSize;
+
+ /// Accumulators fragment
+  typedef Fragment<Scalar, kAccumulatorsPerThread> Accumulators;
+
+ /// Number of warps
+  static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
+
+ /// Interleaved mma.sync shape
+ typedef Shape<4, 32, 32> MmaTileShape;
+
+ /// Hard-coded for FP16 layouts
+ typedef Shape<1, WarpGemmShape::kW / MmaTileShape::kW, 2> Elements;
+
+ /// Number of elements
+  static int const kElements = ShapeCount<Elements>::kCount;
+
+ /// Slice of accumulators
+  typedef Fragment<Scalar, kElements> Fragment;
+
+ //
+ // Methods
+ //
+
+ /// Selects accumulators for a given iteration of the epilogue
+ CUTLASS_DEVICE
+ Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const {
+ Fragment frag;
+
+ static int const kAccumPerOp = 8;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Elements::kH; ++j) {
+
+ // selects the 32x32 tile
+ Coord<2> tile_32x32 = make_Coord(idx[0] / 16, j);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < Elements::kW; ++i) {
+
+ Coord<2> mma_op = make_Coord(((idx[0] >> 2) & 1), i & 1);
+
+ int element = (idx[0] & 3) | ((idx[0] >> 1) & 4);
+
+ int mma_op_idx = mma_op[1] + mma_op[0] * 2 + 4 * (tile_32x32[1] + 2 * tile_32x32[0]);
+
+ frag[i + j * Elements::kW] = accum[element + kAccumPerOp * mma_op_idx];
+ }
+ }
+
+ return frag;
+ }
+};
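+
+// Worked example of the F16 selection above (illustrative; assumes WarpGemmShape::kW == 64, so
+// Elements = Shape<1, 2, 2> and each epilogue iteration selects kElements = 4 accumulators):
+//
+//   idx = (0, 0):  j = 0, i = 0..1  ->  accum[{0, 8}];   j = 1, i = 0..1  ->  accum[{32, 40}]
+//   idx = (1, 0):  element = 1, same pattern shifted by one  ->  accum[{1, 9, 33, 41}]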
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+//
+//
+//
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+ /// The warp-level GEMM tile
+ typename WarpGemmTile_,
+ /// Tiling of warp accumulator elements
+ typename WarpDelta_,
+ /// Size of vector to load or store
+ int AccessSize,
+ /// The accumulators fragment type - implies accumulator layout
+ typename Accumulators_>
+struct Volta884EpilogueGlobalTileTraits;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Global tile traits specialized for Volta mma.sync.F32 layout
+template <
+ /// The warp-level GEMM tile
+ typename WarpGemmTile_,
+ /// Tiling of warp accumulator elements
+ typename WarpDelta_,
+ /// Size of vector to load or store
+ int AccessSize>
+struct Volta884EpilogueGlobalTileTraits<WarpGemmTile_, WarpDelta_, AccessSize, float> {
+ /// Shape of warp-scoped GEMM tile
+ typedef WarpGemmTile_ WarpGemmTile;
+
+ /// Structure of MMA
+ typedef WarpDelta_ WarpDelta;
+
+ /// Access size of input/output elements
+ static int const kAccessSize = AccessSize;
+
+ /// Scalar type of accumulators - used to imply accumulator layout, not the data
+ typedef float Accumulators;
+
+ /// Strides for immediate offset computation
+ typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
+
+ //typedef Shape<2, 2, 1, 1> Iterations;
+
+ /// Hard-coded pitch between Volta mma.sync Quad Pair tiles
+ static int const kMmaQuadPairWidth = 16;
+
+ /// Hard-coded pitch between warp tiles
+ static int const kInterleavedTileWidth = 32;
+
+ /// Number of actual threads
+ static int const kThreadCount = (WarpDelta::kH * WarpDelta::kW) * kWarpSize;
+
+ /// Shape of the tile
+ typedef Shape<2 * WarpDelta::kH, 2, WarpGemmTile::kW * WarpDelta::kW, 1> Tile;
+
+ /// Number of iterations
+ typedef Shape<2 * WarpDelta::kH,
+ (kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
+ (kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
+ 1> Iterations;
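+
+  // Example (illustrative): with a 2x2 arrangement of warps (WarpDelta::kH == WarpDelta::kW == 2)
+  // and WarpGemmTile::kW == 64, kThreadCount = 4 * 32 = 128, Tile = Shape<4, 2, 128, 1>, and
+  // Iterations = Shape<4, 2, 1, 1>, i.e. each of the 128 threads makes 8 accesses, one per
+  // (D, H) slice of the tile.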
+
+ /// Delta between accesses
+ typedef Shape Delta;
+
+ /// Number of warps in threadblock
+ static int const kWarpCount = ShapeCount