diff --git a/CHANGELOG.md b/CHANGELOG.md index b956c07351..9bf6d239ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ # NVIDIA CUTLASS Changelog -## [1.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.1) (2018-12-19) - * Resolved issue with sm50 and sm52 architectures +## [1.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.0) (2019-03-20) + * Efficient GEMM kernel targeting Volta Tensor Cores via `mma.sync` instruction added in CUDA 10.1. ## [1.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.0) (2018-10-26) * Parallelized reductions across threadblocks ("Split-K") diff --git a/CMakeLists.txt b/CMakeLists.txt index 2ec8cd7bba..25a967b881 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -20,7 +20,7 @@ # STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -cmake_minimum_required(VERSION 3.3.0) +cmake_minimum_required(VERSION 3.3.0 FATAL_ERROR) set(CUTLASS_LANGUAGES CXX) @@ -36,7 +36,8 @@ else() # FindCUDA fails to detect VS 2017 due to a changed directory format of the toolkits. # For this configuration we need CMake >= 3.9.0 to use the native CUDA support. if (WIN32 AND MSVC_VERSION GREATER 1800) - message(FATAL_ERROR "Please upgrade CMake to version >= 3.9.0 to support Visual Studio 2017 or higher") + message(SEND_ERROR "Please upgrade CMake to version >= 3.9.0 to support Visual Studio 2017 or higher") + cmake_minimum_required(VERSION 3.9.0 FATAL_ERROR) endif() # Fall back to the FindCUDA version to create an executable with CUDA files @@ -52,7 +53,11 @@ if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 ) message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!") endif() -find_package(CUDA) +find_package(CUDA REQUIRED) +include_directories(SYSTEM ${CUDA_INCLUDE_DIRS}) +# Some platforms (e.g. Visual Studio) don't add the CUDA include directories to the system include +# paths by default, so we add it explicitly here. 
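Annotation (not part of the patch): the CHANGELOG entry and the CMake changes above both hinge on detecting a toolkit new enough for the `mma.sync` path introduced in CUDA 10.1, which later drives `CUTLASS_ENABLE_TENSOR_CORE_MMA`. A minimal sketch of the equivalent compile-time gate in source form, assuming only the standard `CUDA_VERSION` macro from `cuda.h` and nvcc's `__CUDA_ARCH__`; the macro name `EXAMPLE_HAS_MMA_SYNC` is invented for illustration:

```
#include <cuda.h>  // defines CUDA_VERSION, e.g. 10010 for CUDA 10.1

// Illustrative guard only: the mma.sync PTX instruction requires CUDA 10.1 or newer
// and executes on Volta (sm_70) and Turing (sm_75) class devices.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (CUDA_VERSION >= 10010)
#define EXAMPLE_HAS_MMA_SYNC 1
#else
#define EXAMPLE_HAS_MMA_SYNC 0
#endif
```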
+ find_package(Doxygen QUIET) ################################################################################################### @@ -61,9 +66,18 @@ find_package(Doxygen QUIET) # ################################################################################################### -find_library(CUBLAS_LIBRARY cublas HINTS +# +# Conditionally enable cuBLAS +# +set(CUTLASS_ENABLE_CUBLAS ON CACHE BOOL "Enable CUTLASS Tests to build with cuBLAS library.") + +if(CUTLASS_ENABLE_CUBLAS) + + find_library(CUBLAS_LIBRARY cublas HINTS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64) +endif() + # By default we want to build in Release mode to ensure that we're getting best performance if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES)) @@ -78,26 +92,56 @@ if(WIN32) endif() if (WIN32) - # Enable more warnings and treat as errors - string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX") + # Enable more warnings and treat as errors + string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX") - # Disable warning on Unicode characters - string(APPEND NVCC_FLAGS " -Xcompiler /wd4819") + # Disable warning on Unicode characters + string(APPEND NVCC_FLAGS " -Xcompiler /wd4819") - # Disable excess x86 floating point precision that can lead to results being labeled incorrectly - string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict") + # Disable excess x86 floating point precision that can lead to results being labeled incorrectly + string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict") - # Verbose option - if (${CUTLASS_NVCC_VERBOSE}) - string(APPEND NVCC_FLAGS " -v") - endif() + # Verbose option + if (${CUTLASS_NVCC_VERBOSE}) + string(APPEND NVCC_FLAGS " -v") + endif() endif(WIN32) -set(CUTLASS_NVCC_ARCHS "50;60;61;70;75" CACHE STRING "The SM architectures to build code for.") +set(CUTLASS_NVCC_ARCHS_DEFAULT "") +if(NOT CUDA_VERSION VERSION_LESS 7.5) + list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 50) +endif() +if(NOT CUDA_VERSION VERSION_LESS 8.0) + list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 60 61) +endif() +if(NOT CUDA_VERSION VERSION_LESS 9.0) + list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 70) +endif() +if(NOT CUDA_VERSION VERSION_LESS 9.2) + list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 72) +endif() +if(NOT CUDA_VERSION VERSION_LESS 10.0) + list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 75) +endif() +set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_DEFAULT} CACHE STRING "The SM architectures to build code for.") + set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.") set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.") set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.") +# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations. +if (CUDA_VERSION VERSION_LESS 10.1) + set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT OFF) +else() + set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON) +endif() + +set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL + "Enable PTX mma instruction for collective matrix multiply operations.") + +set(CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST ${CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST} CACHE BOOL + "Enable more kernels instantiated in the perf suite. This might result in longer compiler time. 
") + # # NOTE: running with asan and CUDA requires the following environment variable: # @@ -131,6 +175,18 @@ foreach(ARCH ${CUTLASS_NVCC_ARCHS}) endif() endforeach() +if (CUTLASS_ENABLE_TENSOR_CORE_MMA) + string(APPEND NVCC_FLAGS " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1") +endif() + +if (CUTLASS_ENABLE_CUBLAS) + string(APPEND NVCC_FLAGS " -DCUTLASS_ENABLE_CUBLAS=1") +endif() + +if (CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST) + add_definitions(-DEXHAUSTIVE_PROF) +endif() + if (CUTLASS_NVCC_KEEP) string(APPEND NVCC_FLAGS " -keep") endif() @@ -174,6 +230,7 @@ file(GLOB CUTLASS_UTIL RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h) file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h) file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h) file(GLOB CUTLASS_REDUCTION RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/reduction/*.h ) +file(GLOB CUTLASS_LAYOUT_THREAD RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/layout/thread/*.h) ################################################################################################### # @@ -185,16 +242,24 @@ source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM}) source_group("cutlass\\util" FILES ${CUTLASS_UTIL}) source_group("cutlass\\device" FILES ${CUTLASS_DEVICE}) source_group("cutlass\\reduction" FILES ${CUTLASS_REDUCTION}) +source_group("cutlass\\layout\\thread" FILES ${CUTLASS_LAYOUT_THREAD}) source_group("cutlass" FILES ${CUTLASS_CORE}) add_library(CUTLASS INTERFACE) include_directories("${CMAKE_CURRENT_SOURCE_DIR}") + +# Special policy introduced in CMake 3.13 +if (POLICY CMP0076) + cmake_policy(SET CMP0076 NEW) +endif() + target_sources(CUTLASS INTERFACE ${CUTLASS_GEMM} ${CUTLASS_UTIL} ${CUTLASS_DEVICE} ${CUTLASS_CORE} ${CUTLASS_REDUCTION} + ${CUTLASS_LAYOUT_THREAD} ) target_include_directories(CUTLASS INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) @@ -206,6 +271,7 @@ add_custom_target(cutlass_ide SOURCES ${CUTLASS_DEVICE} ${CUTLASS_CORE} ${CUTLASS_REDUCTION} + ${CUTLASS_LAYOUT_THREAD} ) # Doxygen is available. Generate documentation if (DOXYGEN_FOUND) diff --git a/CUTLASS.md b/CUTLASS.md index eb9e25e5f6..b6553db4e9 100644 --- a/CUTLASS.md +++ b/CUTLASS.md @@ -14,7 +14,7 @@ CUTLASS core components, and to identify their role in implementing GEMM computa # 1. Design Patterns CUTLASS strives to achieve the highest performance possible on NVIDIA GPUs while also offering a -flexible composition that an be easily applied to solve new problems related to Deep Learning and +flexible composition that can be easily applied to solve new problems related to Deep Learning and linear algebra. Though we intend to make CUTLASS as simple and straightforward as possible, given a tradeoff between simplicity and performance, CUTLASS chooses performance. Consequently, several design patterns are necessary to yield a composable structure while also satisfying these performance @@ -31,7 +31,7 @@ CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvla ## Tiles and Iterators -Efficient dense linear algebra computations emphasize data movement to match the execution of mathemtical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) traverse a sequence of objects embedded in a higher level data structure. 
These subpartitions are typically defined by compile-time constants +Efficient dense linear algebra computations emphasize data movement to match the execution of mathematical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) traverse a sequence of objects embedded in a higher level data structure. These subpartitions are typically defined by compile-time constants specifying element type, size, and data layout. CUTLASS refers to subpartitions as _tiles_. _Iterators_ are familiar design patterns in C++ that provide an abstraction for accessing individual @@ -353,7 +353,7 @@ An example of splitK usage can be found [here](examples/06_splitK_gemm/splitK_ge # Copyright -Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/README.md b/README.md index c612b0d2f4..231eafbab3 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 1.2 +# CUTLASS 1.3 -_CUTLASS 1.2 - October 2018_ +_CUTLASS 1.3.0 - March 2019_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA. @@ -20,13 +20,18 @@ multiply-accumulate abstractions for 8-bit integer, half-precision floating point (FP16), single-precision floating point (FP32), and double-precision floating point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targeting the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture -and beyond. +and beyond. Even faster performance on Volta is possible via direct access to +Volta Tensor Cores via `mma.sync` (added in CUDA 10.1). -CUTLASS 1.2 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying +CUTLASS 1.3 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying [Doxygen documentation](https://nvidia.github.io/cutlass). We describe the structure of an efficient GEMM in our talk at the [GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf). +# What's New in CUTLASS 1.3 +_March 2019_ +* CUTLASS 1.3 includes an efficient GEMM implementation with the `mma.sync` instruction added in CUDA 10.1. + # What's New in CUTLASS 1.2 _October 2018_ * [Parallelized Reductions](CUTLASS.md#parallel-reductions-across-gemm-k) @@ -63,8 +68,8 @@ when compiled with CUDA 10.0. # Compatibility -CUTLASS performs best when compiled with the [CUDA 10.0 Toolkit](ttps://developer.nvidia.com/cuda-toolkit). -It is compatible with CUDA 9.0, 9.1, and 9.2, but these versions of the CUDA Toolkit do not support new Turing WMMA features. +CUTLASS performs best when compiled with the [CUDA 10.1 Toolkit](https://developer.nvidia.com/cuda-toolkit). +It is also compatible with CUDA 9.0, 9.1, 9.2, and 10.0. We have tested the following environments. @@ -77,7 +82,7 @@ We have tested the following environments.
| Ubuntu 18.04 | GCC 7.3.0 | CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on -any Maxwell-, Pascal-, or Volta-architecture NVIDIA GPU. +any Maxwell-, Pascal-, Volta-, and Turing-architecture NVIDIA GPUs. |**GPU**| |---| @@ -220,6 +225,9 @@ Program usage: # Varies GEMM K dimension for SGEMM and IGEMM with column-major multiplicands $ ./tools/test/perf/cutlass_perf_test --m=10240 --n=4096 --k=1024:8192:128 --kernels=sgemm_nn,igemm_nn + + # Executes GEMM kernel on Volta Tensor Cores + $ ./tools/test/perf/cutlass_perf_test --kernels=s884gemm_nt ``` # About @@ -230,7 +238,7 @@ CUTLASS is released by NVIDIA Corporation as Open Source software under the # Copyright -Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted @@ -253,4 +261,3 @@ Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ``` - diff --git a/cutlass/arch/mma.h b/cutlass/arch/mma.h new file mode 100644 index 0000000000..1b01df4663 --- /dev/null +++ b/cutlass/arch/mma.h @@ -0,0 +1,380 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates wrapping direct issue of MMA instructions to Tensor Cores. 
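+
+    These wrappers are compile-time guarded: the specializations emit mma.sync only when
+    compiled for compute capability 7.0 through 7.5 with CUTLASS_ENABLE_TENSOR_CORE_MMA
+    defined to 1 (CUDA 10.1 or newer); otherwise they fall through to CUTLASS_ASSERT(0).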
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/shape.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specifies internal data type for computation +struct ComputeType { + enum Kind { + kBegin, + kDefault, /// Compute type implied by operand and accumulator types + kEnd + }; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Direct wrapper for native MMA instruction +template < + /// Warp-level matrix multiply-accumulate operation + typename WmmaTile, + /// Layout of A multiplicand + MatrixLayout::Kind LayoutA, + /// Data type of A multiplicand + typename ScalarA, + /// Layout of B multiplicand + MatrixLayout::Kind LayoutB, + /// Data type of A multiplicand + typename ScalarB, + /// Data type of accumulators + typename ScalarC, + /// Specifies particular compute type, overriding data types of operands + ComputeType::Kind ComputeTy> +inline __device__ void mma(ScalarA const A[], ScalarB const B[], ScalarC const C[], ScalarC D[]); + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// 16x16x4 +// + +// +// FP16 accumulation +// + +/// Volta mma.sync instruction +template <> +inline __device__ void mma, + MatrixLayout::kRowMajor, + half, + MatrixLayout::kColumnMajor, + half, + half, + ComputeType::kDefault>(half const a[], + half const b[], + half const c[], + half d[]) { +#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA) + + unsigned const *A = reinterpret_cast(a); + unsigned const *B = reinterpret_cast(b); + unsigned const *C = reinterpret_cast(c); + unsigned *D = reinterpret_cast(d); + + asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1 +#endif +} + +/// Volta mma.sync instruction +template <> +inline __device__ void mma, + MatrixLayout::kColumnMajor, + half, + MatrixLayout::kColumnMajor, + half, + half, + ComputeType::kDefault>(half const a[], + half const b[], + half const c[], + half d[]) { +#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA) + + unsigned const *A = reinterpret_cast(a); + unsigned const *B = reinterpret_cast(b); + unsigned const *C = reinterpret_cast(c); + unsigned *D = reinterpret_cast(d); + + asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1 +#endif +} + +/// Volta mma.sync instruction +template <> +inline __device__ void mma, + MatrixLayout::kRowMajor, + half, + MatrixLayout::kRowMajor, + half, + half, + ComputeType::kDefault>(half const 
a[], + half const b[], + half const c[], + half d[]) { +#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA) + + unsigned const *A = reinterpret_cast(a); + unsigned const *B = reinterpret_cast(b); + unsigned const *C = reinterpret_cast(c); + unsigned *D = reinterpret_cast(d); + + asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1 +#endif +} + +/// Volta mma.sync instruction +template <> +inline __device__ void mma, + MatrixLayout::kColumnMajor, + half, + MatrixLayout::kRowMajor, + half, + half, + ComputeType::kDefault>(half const a[], + half const b[], + half const c[], + half d[]) { +#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA) + + unsigned const *A = reinterpret_cast(a); + unsigned const *B = reinterpret_cast(b); + unsigned const *C = reinterpret_cast(c); + unsigned *D = reinterpret_cast(d); + + asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1 +#endif +} + +// +// FP32 accumulation +// + +/// Volta mma.sync instruction +template <> +inline __device__ void mma, + MatrixLayout::kRowMajor, + half, + MatrixLayout::kColumnMajor, + half, + float, + ComputeType::kDefault>(half const a[], + half const b[], + float const C[], + float D[]) { +#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA) + + unsigned const *A = reinterpret_cast(a); + unsigned const *B = reinterpret_cast(b); + + asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};" + : "=f"(D[0]), + "=f"(D[1]), + "=f"(D[2]), + "=f"(D[3]), + "=f"(D[4]), + "=f"(D[5]), + "=f"(D[6]), + "=f"(D[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7])); + +#else + CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1 +#endif +} + +/// Volta mma.sync instruction +template <> +inline __device__ void mma, + MatrixLayout::kColumnMajor, + half, + MatrixLayout::kColumnMajor, + half, + float, + ComputeType::kDefault>(half const a[], + half const b[], + float const C[], + float D[]) { + +#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA) + + unsigned const *A = reinterpret_cast(a); + unsigned const *B = reinterpret_cast(b); + + asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};" + : "=f"(D[0]), + "=f"(D[1]), + "=f"(D[2]), + "=f"(D[3]), + "=f"(D[4]), + "=f"(D[5]), + "=f"(D[6]), + "=f"(D[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7])); + +#else + CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires 
CUTLASS_ENABLE_TENSOR_CORE_MMA=1 +#endif +} + +/// Volta mma.sync instruction +template <> +inline __device__ void mma, + MatrixLayout::kRowMajor, + half, + MatrixLayout::kRowMajor, + half, + float, + ComputeType::kDefault>(half const a[], + half const b[], + float const C[], + float D[]) { +#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA) + + unsigned const *A = reinterpret_cast(a); + unsigned const *B = reinterpret_cast(b); + + asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};" + : "=f"(D[0]), + "=f"(D[1]), + "=f"(D[2]), + "=f"(D[3]), + "=f"(D[4]), + "=f"(D[5]), + "=f"(D[6]), + "=f"(D[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7])); + +#else + CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1 +#endif +} + +/// Volta mma.sync instruction +template <> +inline __device__ void mma, + MatrixLayout::kColumnMajor, + half, + MatrixLayout::kRowMajor, + half, + float, + ComputeType::kDefault>(half const a[], + half const b[], + float const C[], + float D[]) { +#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA) + + unsigned const *A = reinterpret_cast(a); + unsigned const *B = reinterpret_cast(b); + + asm volatile ("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};" + : "=f"(D[0]), + "=f"(D[1]), + "=f"(D[2]), + "=f"(D[3]), + "=f"(D[4]), + "=f"(D[5]), + "=f"(D[6]), + "=f"(D[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7])); + +#else + CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1 +#endif +} + +} // namespace arch +} // namespace cutlass diff --git a/cutlass/convert.h b/cutlass/convert.h index b4d0f8eddb..23fa5d560f 100644 --- a/cutlass/convert.h +++ b/cutlass/convert.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/coord.h b/cutlass/coord.h index e90af8a1b8..7e91d6e99d 100644 --- a/cutlass/coord.h +++ b/cutlass/coord.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/core_io.h b/cutlass/core_io.h index 849a7613f4..edfc1ec803 100644 --- a/cutlass/core_io.h +++ b/cutlass/core_io.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
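Annotation: the `cutlass/arch/mma.h` specializations introduced above are warp-scoped wrappers around inline PTX. Below is a minimal sketch of invoking the row-major × column-major, FP32-accumulator overload; it is not code from this diff. The tile type `cutlass::Shape<4, 16, 16>` and the per-thread fragment sizes (four halves per operand, eight float accumulators) are assumptions inferred from the m8n8k4 PTX shown above, and the function name is invented:

```
#include <cuda_fp16.h>
#include "cutlass/shape.h"
#include "cutlass/arch/mma.h"

// Each thread of the warp supplies its share of the operands; distributing A and B
// across lanes according to the mma.sync thread mapping is omitted here.
__device__ void example_warp_mma(half const (&a)[4],
                                 half const (&b)[4],
                                 float const (&c)[8],
                                 float (&d)[8]) {
  cutlass::arch::mma<
      cutlass::Shape<4, 16, 16>,                  // assumed warp tile shape (D=4, H=16, W=16)
      cutlass::MatrixLayout::kRowMajor, half,     // A: row-major fp16
      cutlass::MatrixLayout::kColumnMajor, half,  // B: column-major fp16
      float,                                      // accumulator type
      cutlass::arch::ComputeType::kDefault>(a, b, c, d);
}
```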
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/cutlass.h b/cutlass/cutlass.h index ac5420d724..26de6c0278 100644 --- a/cutlass/cutlass.h +++ b/cutlass/cutlass.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -29,11 +29,12 @@ #pragma once + //////////////////////////////////////////////////////////////////////////////////////////////////// #define CUTLASS_MAJOR 1 -#define CUTLASS_MINOR 2 -#define CUTLASS_PATCH 1 +#define CUTLASS_MINOR 3 +#define CUTLASS_PATCH 0 #define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH) #ifdef __NVCC__ @@ -47,9 +48,31 @@ // CUTLASS_DEVICE is an error if not compiling device code #endif +// CUDA 10.1 introduces the mma instruction +#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) +#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0 +#endif + +// CUTLASS assert #define CUTLASS_ASSERT(x) assert(x) -#include "cutlass/util/performance_tuning.h" +// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler. +#if defined(__CUDA_ARCH__) + #define CUTLASS_PRAGMA_UNROLL #pragma unroll + #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1 + + #define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL + + #define CUTLASS_GEMM_LOOP_HEADER \ + asm volatile (".pragma \"nounroll\";\n"); +#else + + #define CUTLASS_PRAGMA_UNROLL + #define CUTLASS_PRAGMA_NO_UNROLL + #define CUTLASS_GEMM_LOOP_HEADER + #define CUTLASS_GEMM_LOOP + +#endif // A small helper class to dump a type at compile time // Usage:: DumpType::Class diff --git a/cutlass/fragment.h b/cutlass/fragment.h index 6a93d779c4..e048c525b1 100644 --- a/cutlass/fragment.h +++ b/cutlass/fragment.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -98,9 +98,9 @@ struct StorageType<1> { template struct Fragment : public AlignedStruct { /// Make sure the alignment makes sense wrt the size of elements. - static_assert(kAlignment_ == 16 || kAlignment_ >= sizeof(Element_), "Alignment is too small"); + static_assert(int(kAlignment_) == 16 || int(kAlignment_) >= sizeof(Element_), "Alignment is too small"); /// Alignment must be a power of two - static_assert(is_pow2::value, "Alignment must be a power of two"); + static_assert(is_pow2::value, "Alignment must be a power of two"); /// This class. typedef Fragment This_; @@ -109,27 +109,31 @@ struct Fragment : public AlignedStruct { /// The number of elements. static int const kElements = kElements_; /// Alignment - static int const kAlignment = kAlignment_; + static int const kAlignment = int(kAlignment_); /// Clear a fragment. 
CUTLASS_HOST_DEVICE void clear() { // Avoid element-wise access for sub 32b element type if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) { uint64_t* ptr = reinterpret_cast(storage); + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) { ptr[i] = uint64_t(0); } } else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) { uint32_t* ptr = reinterpret_cast(storage); + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) { ptr[i] = uint32_t(0); } } else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) { uint16_t* ptr = reinterpret_cast(storage); + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) { ptr[i] = uint16_t(0); } } else { + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElements; ++i) { storage[i] = 0; } @@ -146,7 +150,7 @@ struct Fragment : public AlignedStruct { private: /// Storage type to use for Elements - typedef typename StorageType::Type StorageType; + typedef typename StorageType::Type StorageType; /// Number of elements in the storage static int const kStorageCount = diff --git a/cutlass/fragment_multiply_add.h b/cutlass/fragment_multiply_add.h index 8bcf81209a..f29b08e7ca 100644 --- a/cutlass/fragment_multiply_add.h +++ b/cutlass/fragment_multiply_add.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -68,7 +68,6 @@ struct FragmentMultiplyAdd { FragmentB_ const& b, FragmentCd_ const& c, FragmentCd_& d) { - int const kReduction = FragmentB_::kElements / FragmentCd_::kElements; for (int j = 0; j < FragmentCd_::kElements; ++j) { d[j] = b[j * kReduction + 0]; diff --git a/cutlass/gemm/clear_accumulators.h b/cutlass/gemm/clear_accumulators.h index 3a2f337525..9e336cb6f3 100644 --- a/cutlass/gemm/clear_accumulators.h +++ b/cutlass/gemm/clear_accumulators.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/device_gemm.h b/cutlass/gemm/device_gemm.h index aaf4bfe783..1380f90efc 100644 --- a/cutlass/gemm/device_gemm.h +++ b/cutlass/gemm/device_gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
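Annotation: the `fragment.h` hunk above adds `CUTLASS_PRAGMA_UNROLL` (defined in the `cutlass/cutlass.h` hunk) to the fragment-clearing loops. A small hypothetical device function showing the same macro and the `Fragment` interface in user code; the element count of 8 and the function name are arbitrary:

```
#include "cutlass/cutlass.h"
#include "cutlass/fragment.h"

__device__ void example_fill_fragment(float value) {
  cutlass::Fragment<float, 8> frag;  // register-backed fragment of 8 floats
  frag.clear();                      // zeroed via the unrolled loops shown above

  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < cutlass::Fragment<float, 8>::kElements; ++i) {
    frag[i] = value;
  }
}
```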
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -33,6 +33,8 @@ #include "cutlass/coord.h" #include "cutlass/util/platform.h" +#include "cutlass/gemm/gemm.h" + namespace cutlass { namespace gemm { @@ -47,7 +49,8 @@ struct DeviceGemm { #if !defined(__CUDACC_RTC__) /// Launch the kernels in order static __host__ cudaError_t launch(Params const& params) { - Traits::GemmTraits::KernelClass::launch(params.GemmParams); + //Traits::GemmTraits::KernelClass::launch(params.GemmParams); + Gemm::launch(params.GemmParams); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) return err; diff --git a/cutlass/gemm/device_gemm_traits.h b/cutlass/gemm/device_gemm_traits.h index fbcfef3e1f..1ff2a11cc9 100644 --- a/cutlass/gemm/device_gemm_traits.h +++ b/cutlass/gemm/device_gemm_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -73,7 +73,7 @@ struct SplitkPIGemmTraits { /// The pointer to workspace memory ScalarAccum *workspace_ptr; /// - int workspace_size; + size_t workspace_size; /// The Params for the first kernel typename GemmTraits::Params GemmParams; /// The Params for the second kernel @@ -112,7 +112,8 @@ struct SplitkPIGemmTraits { Index ldc_, ScalarD* d_d_, Index ldd_, - ScalarAccum *workspace_ptr_) { + ScalarAccum *workspace_ptr_, + Index partitionK_multiple = 1) { workspace_ptr = workspace_ptr_; @@ -133,7 +134,7 @@ struct SplitkPIGemmTraits { TensorRef(workspace_ptr, problem_size.m()), /*m = ldc, workspace is not transposed and is packed*/ TensorRef(workspace_ptr, problem_size.m()) /*m = ldd, workspace is not transposed and is packed*/ ); - GemmParams.initialize(desc, ReductionTraits::ReductionSize); + GemmParams.initialize(desc, ReductionTraits::ReductionSize, partitionK_multiple); //call batched reduction (second kernel) param @@ -155,9 +156,12 @@ struct SplitkPIGemmTraits { // workspace will be used to store D (output) from the first gemm kernel (not D of the entire gemm) // note typedef typename GemmTraits::ScalarD ScalarAccum; // workspace of size of M * N * Reduction - int required_workspace_memory_in_byte(){ + size_t required_workspace_memory_in_byte(){ assert(problem_size_initialized == true); - workspace_size = problem_size.n() * problem_size.m() * ReductionTraits::ReductionSize * static_cast(sizeof(ScalarAccum)); + workspace_size = static_cast(problem_size.n()) * + static_cast(problem_size.m()) * + static_cast(ReductionTraits::ReductionSize) * + sizeof(ScalarAccum); return workspace_size; } diff --git a/cutlass/gemm/dgemm_traits.h b/cutlass/gemm/dgemm_traits.h index 5c05590207..4005bf9647 100644 --- a/cutlass/gemm/dgemm_traits.h +++ b/cutlass/gemm/dgemm_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
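Annotation: the `device_gemm_traits.h` change above widens the split-K `workspace_size` from `int` to `size_t` because the workspace holds an M × N accumulator image for every K-partition. A standalone illustration of why 32-bit arithmetic is not enough; the helper name is invented:

```
#include <cstddef>

// Workspace bytes = M * N * (number of K partitions) * sizeof(accumulator).
inline size_t splitk_workspace_bytes(int m, int n, int reduction_size,
                                     size_t bytes_per_accum) {
  return static_cast<size_t>(m) * static_cast<size_t>(n) *
         static_cast<size_t>(reduction_size) * bytes_per_accum;
}

// Example: m = n = 16384 with 16 partitions of float accumulators gives
// 16384 * 16384 * 16 * 4 bytes = 16 GiB, which overflows a 32-bit int.
```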
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/fp16_sgemm_multiply_add.h b/cutlass/gemm/fp16_sgemm_multiply_add.h index 534b8c8998..b45ae8cfbb 100644 --- a/cutlass/gemm/fp16_sgemm_multiply_add.h +++ b/cutlass/gemm/fp16_sgemm_multiply_add.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -29,6 +29,7 @@ #include "cutlass/fragment.h" #include "cutlass/gemm/thread_multiply_add.h" + namespace cutlass { namespace gemm { @@ -69,8 +70,10 @@ struct ThreadMultiplyAdd { FragmentB const& b, Accumulators const& c, Accumulators& d) { + for (int j = 0; j < AccumulatorsPerThread::kH; ++j) { for (int i = 0; i < AccumulatorsPerThread::kW; ++i) { + d[j * AccumulatorsPerThread::kW + i] = static_cast(a[i]) * static_cast(b[j]) + c[j * AccumulatorsPerThread::kW + i]; } } diff --git a/cutlass/gemm/fp16_sgemm_traits.h b/cutlass/gemm/fp16_sgemm_traits.h index 361186455b..6ad6d96385 100644 --- a/cutlass/gemm/fp16_sgemm_traits.h +++ b/cutlass/gemm/fp16_sgemm_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/gemm.h b/cutlass/gemm/gemm.h index 3aec792866..0d91919931 100644 --- a/cutlass/gemm/gemm.h +++ b/cutlass/gemm/gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -33,20 +33,30 @@ #include "cutlass/coord.h" #include "cutlass/util/platform.h" +#include namespace cutlass { namespace gemm { +//////////////////////////////////////////////////////////////////////////////////////////////////// + + //////////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM kernel with launch bounds specified template __global__ __launch_bounds__(Gemm_::kThreads) void gemm_kernel(typename Gemm_::Params params) { - // Declare shared memory. - __shared__ typename Gemm_::SharedStorage shared_storage; + + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + typename Gemm_::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); // Construct the GEMM object. - Gemm_ gemm(params, shared_storage); + Gemm_ gemm(params, *shared_storage); + // Run GEMM. 
gemm.multiply_add(); } @@ -57,11 +67,17 @@ void gemm_kernel(typename Gemm_::Params params) { template __global__ /* __launch_bounds__(Gemm_::kThreads) */ void gemm_kernel_nolb(typename Gemm_::Params params) { - // Declare shared memory. - __shared__ typename Gemm_::SharedStorage shared_storage; + + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + typename Gemm_::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); // Construct the GEMM object. - Gemm_ gemm(params, shared_storage); + Gemm_ gemm(params, *shared_storage); + // Run GEMM. gemm.multiply_add(); } @@ -72,7 +88,31 @@ void gemm_kernel_nolb(typename Gemm_::Params params) { template struct Launch { Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) { - gemm_kernel<<< grid, block, 0, stream >>>(params); + + int smem_size = int(sizeof(typename Gemm::SharedStorage)); + if (smem_size >= (48 << 10)) { + + cudaError_t result = cudaFuncSetAttribute( + gemm_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size + ); + + if (result != cudaSuccess) { + return; + } + + result = cudaFuncSetAttribute( + gemm_kernel_nolb, + cudaFuncAttributePreferredSharedMemoryCarveout, + 100); + + if (result != cudaSuccess) { + return; + } + } + + gemm_kernel<<< grid, block, sizeof(typename Gemm::SharedStorage), stream >>>(params); } }; @@ -82,50 +122,51 @@ struct Launch { template struct Launch { Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) { - gemm_kernel_nolb<<< grid, block, 0, stream >>>(params); + int smem_size = int(sizeof(typename Gemm::SharedStorage)); + if (smem_size >= (48 << 10)) { + + cudaError_t result = cudaFuncSetAttribute( + gemm_kernel_nolb, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size + ); + + if (result != cudaSuccess) { + return; + } + + result = cudaFuncSetAttribute( + gemm_kernel_nolb, + cudaFuncAttributePreferredSharedMemoryCarveout, + 100); + + if (result != cudaSuccess) { + // throw exception? + return; + } + } + + gemm_kernel_nolb<<< + grid, + block, + smem_size, + stream >>>(params); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Gemm { - /// This class. - typedef Gemm This_; + /// The traits. - typedef GemmTraits_ Traits; - /// The shared storage. - typedef typename Traits::SharedStorage SharedStorage; - - /// The scalar for A. - typedef typename Traits::ScalarA ScalarA; - /// The scalar for B. - typedef typename Traits::ScalarB ScalarB; - /// The scalar in the epilogue. - typedef typename Traits::Epilogue::Scalar ScalarEpilogue; - /// The scalar for C. - typedef typename Traits::Epilogue::ScalarC ScalarC; - /// The scalar for D. - typedef typename Traits::Epilogue::ScalarD ScalarD; - /// The index. - typedef typename Traits::Index Index; - - /// Define the mainloop iteration size - typedef typename Traits::MultiplyAdd MultiplyAdd; - - /// The number of threads. - static int const kThreads = Traits::GemmConfig::kThreads; - - // Number of warp-level multiply-accumulate steps executed by each warp. - static Index const kWarpGemmSteps = - Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD; - - // Make sure we have at least 2 unrolling steps or our pipeling is not going to work. 
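Annotation: the `Launch` specializations above move the kernels to dynamic shared memory and, when a tile needs more than 48 KB, opt in through `cudaFuncSetAttribute` before launching. The same pattern in a self-contained form; the kernel name and launch shape are invented, and error handling is kept minimal:

```
#include <cuda_runtime.h>

__global__ void example_kernel(int *out) {
  extern __shared__ int smem[];  // dynamic shared memory
  smem[threadIdx.x] = static_cast<int>(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = smem[threadIdx.x];
}

// Assumes smem_size >= blockDim.x * sizeof(int) for the launch below.
cudaError_t launch_with_large_smem(int *out, int smem_size, cudaStream_t stream) {
  if (smem_size >= (48 << 10)) {
    cudaError_t result = cudaFuncSetAttribute(
        example_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    if (result != cudaSuccess) {
      return result;
    }
  }
  example_kernel<<<1, 256, smem_size, stream>>>(out);
  return cudaGetLastError();
}
```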
- static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps"); + typedef Traits_ Traits; /// Use the params object defined in traits typedef typename Traits::Params Params; + typedef typename Traits::KernelClass KernelClass; + // // Static function members // @@ -137,7 +178,7 @@ struct Gemm { cudaStream_t stream = cudaStreamDefault) { // Launch the kernel. - Launch( + Launch( params, params.grid, params.block, stream); return cudaGetLastError(); @@ -164,189 +205,6 @@ struct Gemm { } #endif - - // - // Methods - // - - /// Ctor. - CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_) - : params(params_), shared_storage(shared_storage_) {} - - /// Computes a warp-level GEMM on data held in shared memory - template - CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream, - typename Traits::SharedStream& shared_load_stream, - typename MultiplyAdd::Accumulators& accumulators, - Index outer_k) { - // If residue portion and not calculating residue in prolog, update residue predicates now. - if (Residue && outer_k <= Traits::OutputTile::kD) { - global_to_shared_stream.residue(outer_k); - } - - // Load data for the next iteration of the main loop (unless it's the last iteration). - if (!LastIteration) { - global_to_shared_stream.copy(); - } - - CUTLASS_PRAGMA_UNROLL - for (int step = 0; step < kWarpGemmSteps - 1; ++step) { - // Trigger the copy from shared memory for the next A/B values. - shared_load_stream.copy(step + 1); - - // Make sure the values are available for the current iteration to do the multiply-add. - shared_load_stream.commit(step); - - MultiplyAdd multiply_add; - - // Do the math on the fragments of the current iteration. - multiply_add.multiply_add(shared_load_stream.fragment_a(step), - shared_load_stream.fragment_b(step), - accumulators, - accumulators); - } - - // Make sure the data from shared memory has been entirely consumed. - Traits::shared_load_fence(true); - - // Commit the data in shared memory for A/B. - if (!LastIteration) { - global_to_shared_stream.commit(); - } - // Make sure the data is in shared memory. - Traits::shared_store_fence(true); - - if (!LastIteration) { - // Move to the next stage for the load (if it makes sense). - shared_load_stream.inc_stage(); - // Trigger the copy from shared memory for the next loop iteration. - shared_load_stream.copy(0); - } - // Make sure the values are available for the current iteration to do the multiply-add. - shared_load_stream.commit(kWarpGemmSteps - 1); - - // Do the math on the fragments of the current iteration. - MultiplyAdd multiply_add; - multiply_add.multiply_add(shared_load_stream.fragment_a(kWarpGemmSteps - 1), - shared_load_stream.fragment_b(kWarpGemmSteps - 1), - accumulators, - accumulators); - } - - /// Do the GEMM. - CUTLASS_DEVICE void multiply_add() { - // Swizzle the IDs of the block (to enable better cache behavior). - typename Traits::BlockSwizzle block_swizzle; - Coord<3> threadblock_offset = - block_swizzle.get_threadblock_offset(make_Coord_from_shape()); - - // We may want to use shared memory to clear the registers. - typedef typename Traits::ClearAccumulators ClearAccumulators; - - // Get the bounds for each thread, it maybe different than problem_size - Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size, - params.partitionK_range); - - // The streams to read A/B from global memory to shared memory. 
- typename Traits::GlobalLoadStream global_to_shared_stream( - params.global_to_shared_stream, - shared_storage.main_loop.global_to_shared_stream, - shared_storage.main_loop.threadblock_tile.reference(), - bounds, - threadblock_offset); - - // update A and B pointer offset based on batch_id and batch_stride_offset - global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id()); - - // Create the accumulator clear. - ClearAccumulators clear; - - // Deal with residue in prolog. - // global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD); - global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD); - - // Fetch the fragments for A and B from global memory. - global_to_shared_stream.copy(); - - // Copy the elements to shared memory (after transformation if needed). - global_to_shared_stream.commit(); - - // Make sure the data is in shared memory. - Traits::shared_store_fence(false); - - // Rollback to the beginning of the first tile (if residue exists). - // global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD); - global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD); - - // The stream of data from shared memory to fragments. - typename Traits::SharedStream shared_load_stream( - params.shared_stream, - shared_storage.main_loop.threadblock_tile.reference()); - - // Trigger the copy from shared memory for the 1st stream. - shared_load_stream.copy(0); - - // Allocate the accumulators. - typename MultiplyAdd::Accumulators accumulators; - - // Clear the accumulators. - clear.clear(accumulators); - - // Initial index - // Index outer_k = params.problem_size[0] - Traits::OutputTile::kD; - // problem_size[0] might be bigger than bounds[0] - Index outer_k = bounds[0] - Traits::OutputTile::kD; - // Check if we are computing residue in prolog or not. - if (Traits::GemmConfig::kResidueInProlog) { - // Execute all mainloop iterations but the last one. - - CUTLASS_GEMM_LOOP - for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) { - consume_tile( - global_to_shared_stream, shared_load_stream, accumulators, outer_k); - } - - // Don't load data for the last "residue" portion since we've already computed the residue. - CUTLASS_GEMM_LOOP - for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) { - consume_tile( - global_to_shared_stream, shared_load_stream, accumulators, outer_k); - } - } else { - // When kResidueSeparate = true, execute all mainloop iterations but the last two without any - // consideration for K-residue or predicate updates. This improves the steady state of some - // kernels. - if (Traits::GemmConfig::kResidueSeparate) { - - CUTLASS_GEMM_LOOP - for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) { - consume_tile( - global_to_shared_stream, shared_load_stream, accumulators, outer_k); - } - } - - // Execute remaining tiles with K-residue predicate updates enabled. - CUTLASS_GEMM_LOOP - for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) { - consume_tile( - global_to_shared_stream, shared_load_stream, accumulators, outer_k); - } - } - - // Epilogue. - typedef typename Traits::Epilogue Epilogue; - Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm()); - epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id()); - } - - // - // Data members - // - - /// The params. - Params const& params; - /// The shared storage. 
- SharedStorage& shared_storage; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass/gemm/gemm_config.h b/cutlass/gemm/gemm_config.h index 76df0add62..1153cc470e 100644 --- a/cutlass/gemm/gemm_config.h +++ b/cutlass/gemm/gemm_config.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/gemm_coord.h b/cutlass/gemm/gemm_coord.h index e029af3522..4d64eb96c8 100644 --- a/cutlass/gemm/gemm_coord.h +++ b/cutlass/gemm/gemm_coord.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/gemm_desc.h b/cutlass/gemm/gemm_desc.h index 80f4b36557..597e46d5c7 100644 --- a/cutlass/gemm/gemm_desc.h +++ b/cutlass/gemm/gemm_desc.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/gemm_epilogue.h b/cutlass/gemm/gemm_epilogue.h index 0e0cfc5374..086b50901a 100644 --- a/cutlass/gemm/gemm_epilogue.h +++ b/cutlass/gemm/gemm_epilogue.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/gemm_epilogue_traits.h b/cutlass/gemm/gemm_epilogue_traits.h index bffd5e516c..7f0f788e66 100644 --- a/cutlass/gemm/gemm_epilogue_traits.h +++ b/cutlass/gemm/gemm_epilogue_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/gemm_global_stream.h b/cutlass/gemm/gemm_global_stream.h index 1ae2963c00..6a6d75b36d 100644 --- a/cutlass/gemm/gemm_global_stream.h +++ b/cutlass/gemm/gemm_global_stream.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/gemm_global_tile.h b/cutlass/gemm/gemm_global_tile.h index 5174ce67fe..154bc24dc0 100644 --- a/cutlass/gemm/gemm_global_tile.h +++ b/cutlass/gemm/gemm_global_tile.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/gemm_mainloop.h b/cutlass/gemm/gemm_mainloop.h new file mode 100644 index 0000000000..a65cb3ae21 --- /dev/null +++ b/cutlass/gemm/gemm_mainloop.h @@ -0,0 +1,274 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements a software-pipelined efficient GEMM. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" + +namespace cutlass { +namespace gemm { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmMainloop { + + // + // Type definitions + // + + /// The traits. + typedef Traits_ Traits; + + /// The GEMM mainloop + typedef typename Traits::KernelClass KernelClass; + + /// The shared storage. + typedef typename Traits::SharedStorage SharedStorage; + + /// The scalar for A. + typedef typename Traits::ScalarA ScalarA; + /// The scalar for B. + typedef typename Traits::ScalarB ScalarB; + /// The scalar in the epilogue. + typedef typename Traits::Epilogue::Scalar ScalarEpilogue; + /// The scalar for C. 
+ typedef typename Traits::Epilogue::ScalarC ScalarC; + /// The scalar for D. + typedef typename Traits::Epilogue::ScalarD ScalarD; + /// The index. + typedef typename Traits::Index Index; + + /// Define the mainloop iteration size + typedef typename Traits::MultiplyAdd MultiplyAdd; + + /// The number of threads. + static int const kThreads = Traits::GemmConfig::kThreads; + + // Number of warp-level multiply-accumulate steps executed by each warp. + static Index const kWarpGemmSteps = + Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD; + + /* + // Make sure we have at least 2 unrolling steps or our pipeling is not going to work. + static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps"); + */ + + /// Use the params object defined in traits + typedef typename Traits::Params Params; + + // + // Data members + // + + /// The params. + Params const& params; + + /// SharedStorage object + SharedStorage& shared_storage; + + // + // Methods + // + + /// Ctor. + CUTLASS_DEVICE GemmMainloop(Params const& params_, SharedStorage& shared_storage_) + : params(params_), shared_storage(shared_storage_) {} + + /// Fetches global stream pair + template + CUTLASS_DEVICE void fetch_global(typename Traits::GlobalLoadStream& global_to_shared_stream, + Index outer_k) { + // If residue portion and not calculating residue in prolog, update residue predicates now. + if (Residue) { + global_to_shared_stream.residue(outer_k); + } + global_to_shared_stream.copy(); + } + + /// Computes a warp-level GEMM on data held in shared memory + template + CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream, + typename Traits::SharedStream& shared_load_stream, + typename MultiplyAdd::Accumulators& accumulators, + Index outer_k) { + + // Whether to load global stream before loading shared stream + const bool kGlobalStreamFirst = (kWarpGemmSteps <= 4); + + // Load data for the next iteration of the main loop (unless it's the last iteration). + if (kGlobalStreamFirst && !LastIteration) { + fetch_global(global_to_shared_stream, outer_k); + } + + CUTLASS_PRAGMA_UNROLL + for (int step = 0; step < kWarpGemmSteps; ++step) { + + // Trigger the copy from shared memory for the next A/B values. + shared_load_stream.copy((step + 1) % kWarpGemmSteps); + + // Load data for the next iteration of the main loop (unless it's the last iteration). + if (!kGlobalStreamFirst && (step == 0) && !LastIteration) { + fetch_global(global_to_shared_stream, outer_k); + } + + if (step == kWarpGemmSteps - 2) { + // Make sure the data from shared memory has been entirely consumed. + Traits::shared_load_fence(true); + + global_to_shared_stream.commit(); + + // Make sure the data is in shared memory. + Traits::shared_store_fence(true); + + // Move to the next stage for the load (if it makes sense). + shared_load_stream.inc_stage(); + } + + // Make sure the values are available for the current iteration to do the multiply-add. + shared_load_stream.commit(step); + + // Do the math on the fragments of the current iteration. + MultiplyAdd multiply_add; + multiply_add.multiply_add(shared_load_stream.fragment_a(step), + shared_load_stream.fragment_b(step), + accumulators, + accumulators); + } + } + + /// Do the GEMM. + CUTLASS_DEVICE void multiply_add() { + // Swizzle the IDs of the block (to enable better cache behavior). 
+ typename Traits::BlockSwizzle block_swizzle; + Coord<3> threadblock_offset = + block_swizzle.get_threadblock_offset(make_Coord_from_shape()); + + // We may want to use shared memory to clear the registers. + typedef typename Traits::ClearAccumulators ClearAccumulators; + + // Get the bounds for each thread, it maybe different than problem_size + Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size, + params.partitionK_range); + + // The streams to read A/B from global memory to shared memory. + typename Traits::GlobalLoadStream global_to_shared_stream( + params.global_to_shared_stream, + shared_storage.main_loop.global_to_shared_stream, + shared_storage.main_loop.threadblock_tile.reference(), + bounds, + threadblock_offset); + + // update A and B pointer offset based on batch_id and batch_stride_offset + global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id()); + + // Create the accumulator clear. + ClearAccumulators clear; + + // Deal with residue in prolog. + // global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD); + global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD); + + // Fetch the fragments for A and B from global memory. + global_to_shared_stream.copy(); + + // Copy the elements to shared memory (after transformation if needed). + global_to_shared_stream.commit(); + + // Make sure the data is in shared memory. + Traits::shared_store_fence(false); + + // Rollback to the beginning of the first tile (if residue exists). + // global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD); + global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD); + + // The stream of data from shared memory to fragments. + typename Traits::SharedStream shared_load_stream( + params.shared_stream, + shared_storage.main_loop.threadblock_tile.reference()); + + // Trigger the copy from shared memory for the 1st stream. + shared_load_stream.copy(0); + + // Allocate the accumulators. + typename MultiplyAdd::Accumulators accumulators; + + // Clear the accumulators. + clear.clear(accumulators); + + // Initial index + // Index outer_k = params.problem_size[0] - Traits::OutputTile::kD; + // problem_size[0] might be bigger than bounds[0] + Index outer_k = bounds[0] - Traits::OutputTile::kD; + // Check if we are computing residue in prolog or not. + if (Traits::GemmConfig::kResidueInProlog) { + // Execute all mainloop iterations but the last one. + + CUTLASS_GEMM_LOOP + for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) { + CUTLASS_GEMM_LOOP_HEADER + consume_tile( + global_to_shared_stream, shared_load_stream, accumulators, outer_k); + } + + consume_tile( + global_to_shared_stream, shared_load_stream, accumulators, outer_k); + + } else { + // When kResidueSeparate = true, execute all mainloop iterations but the last two without any + // consideration for K-residue or predicate updates. This improves the steady state of some + // kernels. + if (Traits::GemmConfig::kResidueSeparate) { + + CUTLASS_GEMM_LOOP + for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) { + CUTLASS_GEMM_LOOP_HEADER + consume_tile( + global_to_shared_stream, shared_load_stream, accumulators, outer_k); + } + } + + // Execute remaining tiles with K-residue predicate updates enabled. 
+ CUTLASS_GEMM_LOOP + for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) { + CUTLASS_GEMM_LOOP_HEADER + consume_tile( + global_to_shared_stream, shared_load_stream, accumulators, outer_k); + } + } + + typedef typename Traits::Epilogue Epilogue; + Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm()); + epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id()); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/gemm_operand.h b/cutlass/gemm/gemm_operand.h index 2b4dcdc916..eef0e50234 100644 --- a/cutlass/gemm/gemm_operand.h +++ b/cutlass/gemm/gemm_operand.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/gemm_shared_stream.h b/cutlass/gemm/gemm_shared_stream.h index ed158d6b23..beb214ab7c 100644 --- a/cutlass/gemm/gemm_shared_stream.h +++ b/cutlass/gemm/gemm_shared_stream.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -91,8 +91,16 @@ struct SharedLoadStream { transformer = Transformer(); } + /// Clears the fragment + CUTLASS_DEVICE void clear() { + fetched[0].clear(); + fetched[1].clear(); + transformed[0].clear(); + transformed[1].clear(); + } + /// Load the data from shared memory to the fetch fragment. - CUTLASS_DEVICE void copy() { + CUTLASS_DEVICE void copy() { iterator.load_post_increment(fetched[0]); } diff --git a/cutlass/gemm/gemm_shared_tile.h b/cutlass/gemm/gemm_shared_tile.h index 78fb1f2054..6882ca9b88 100644 --- a/cutlass/gemm/gemm_shared_tile.h +++ b/cutlass/gemm/gemm_shared_tile.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/gemm_stream_pair.h b/cutlass/gemm/gemm_stream_pair.h index f1c22edfc0..5690afa289 100644 --- a/cutlass/gemm/gemm_stream_pair.h +++ b/cutlass/gemm/gemm_stream_pair.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -140,14 +140,19 @@ struct GlobalLoadStreamPair { /// Trigger the copies from shared memory to registers. 
CUTLASS_DEVICE void copy() { + stream_a.copy(); + stream_b.copy(); + } /// Commit the data. CUTLASS_DEVICE void commit() { stream_a.commit(); + stream_b.commit(); + } /// Execute the residue code. @@ -233,6 +238,13 @@ struct SharedStreamPair { stream_b.commit(step); } + /// Clears all fragments + CUTLASS_DEVICE + void clear() { + stream_a.clear(); + stream_b.clear(); + } + /// The fragment A. CUTLASS_DEVICE typename StreamA::TransformedFragment const &fragment_a(int step) const { diff --git a/cutlass/gemm/gemm_traits.h b/cutlass/gemm/gemm_traits.h index b588de0a98..40194487c6 100644 --- a/cutlass/gemm/gemm_traits.h +++ b/cutlass/gemm/gemm_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -42,7 +42,7 @@ #include "cutlass/gemm/gemm_operand.h" #include "cutlass/gemm/gemm_shared_stream.h" #include "cutlass/gemm/threadblock_swizzle.h" -#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_mainloop.h" namespace cutlass { namespace gemm { @@ -359,7 +359,7 @@ struct GemmTraits { ClearAccumulators_> This_; /// The struct that consumes this Traits - typedef typename cutlass::gemm::Gemm KernelClass; + typedef typename cutlass::gemm::GemmMainloop KernelClass; /// The configuration. typedef GemmConfig_ GemmConfig; @@ -544,16 +544,26 @@ struct GemmTraits { /// Helper to construct a partitionedK GEMM params template - CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc, Index partitionK_count_) { + CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc, + Index partitionK_count_, + Index partitionK_multiple_ = 1 // each partition will be a multiple of partitionK_multiple_ + ) { // partitionK GEMM is a specialized batched strided gemm with different K ranges per batch // the problem_size of each batch is (lastK_size, n, m) // add more comments here // the k range for every batch except the last one //assert(partitionK_count_ > 0); partitionK_range = partitonK_desc.problem_size.k() / partitionK_count_; + partitionK_range = partitionK_range - (partitionK_range % partitionK_multiple_); // the k range of the last batch // int lastK_range = (partitonK_desc.problem_size.k() % partitionK_range) + partitionK_range; int lastK_range = partitonK_desc.problem_size.k() - partitionK_range * (partitionK_count_ - 1); + + assert((partitionK_range % partitionK_multiple_) == 0); + assert(partitionK_range > 0); + assert((lastK_range % partitionK_multiple_) == 0); + assert(lastK_range > 0); + int k_size = lastK_range; int lda = partitonK_desc.A.stride(0); int ldb = partitonK_desc.B.stride(0); @@ -641,7 +651,8 @@ struct GemmTraits { Index ldc, ScalarD* d_d, Index ldd, - Index partitionK_count_) { + Index partitionK_count_, + Index partitionK_multiple_ = 1) { GemmDesc desc( GemmCoord(k, n, m, 1), @@ -654,7 +665,7 @@ struct GemmTraits { ); - return this->initialize(desc, partitionK_count_); + return this->initialize(desc, partitionK_count_, partitionK_multiple_); } }; diff --git a/cutlass/gemm/hgemm_global_tile.h b/cutlass/gemm/hgemm_global_tile.h index 9d5ffe8508..a3b38151b6 100644 --- a/cutlass/gemm/hgemm_global_tile.h +++ b/cutlass/gemm/hgemm_global_tile.h @@ -1,5 +1,5 @@
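
The gemm_traits.h hunk above retargets KernelClass from cutlass::gemm::Gemm to the new GemmMainloop, whose consume_tile() overlaps the shared-memory fetch for warp-GEMM step s+1 with the math for step s, and overlaps the global-memory fetch of the next threadblock tile with the whole warp-GEMM sequence. A compact, self-contained sketch of that double-buffering idea follows; the kernel, names, and tile size are illustrative stand-ins, not the CUTLASS types, and it assumes k is a positive multiple of the tile size.

    #include <cuda_runtime.h>

    // Double-buffered mainloop in miniature: while the math drains the tile resident in
    // shared memory, the copy of the next tile fills the other stage. This mirrors the
    // structure of GemmMainloop::consume_tile(), with plain floats instead of fragments.
    __global__ void pipelined_dot_sketch(const float* a, const float* b, float* out, int k) {
      constexpr int kTile = 64;             // elements consumed per mainloop iteration
      __shared__ float smem_a[2][kTile];    // two stages: fill one while draining the other
      __shared__ float smem_b[2][kTile];

      // Prologue: stage the first tile (GemmMainloop also does this before its main loop).
      for (int i = threadIdx.x; i < kTile; i += blockDim.x) {
        smem_a[0][i] = a[i];
        smem_b[0][i] = b[i];
      }
      __syncthreads();

      float acc = 0.f;
      int stage = 0;
      int num_tiles = k / kTile;

      for (int tile = 0; tile < num_tiles; ++tile) {
        // Issue the copy of the next tile into the idle stage (unless this is the last tile).
        if (tile + 1 < num_tiles) {
          for (int i = threadIdx.x; i < kTile; i += blockDim.x) {
            smem_a[stage ^ 1][i] = a[(tile + 1) * kTile + i];
            smem_b[stage ^ 1][i] = b[(tile + 1) * kTile + i];
          }
        }
        // Consume the tile that is already resident.
        for (int i = 0; i < kTile; ++i) {
          acc += smem_a[stage][i] * smem_b[stage][i];
        }
        __syncthreads();   // analogue of shared_store_fence(): the next stage is now visible
        stage ^= 1;
      }

      if (threadIdx.x == 0 && blockIdx.x == 0) {
        *out = acc;        // every thread holds the same dot product in this toy example
      }
    }
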
/*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/hgemm_multiply_add.h b/cutlass/gemm/hgemm_multiply_add.h index 7217d82c58..528e9c9e6f 100644 --- a/cutlass/gemm/hgemm_multiply_add.h +++ b/cutlass/gemm/hgemm_multiply_add.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -29,7 +29,6 @@ #pragma once #include "cutlass/fragment.h" - #include "cutlass/gemm/thread_multiply_add.h" namespace cutlass { @@ -66,6 +65,8 @@ struct ThreadMultiplyAdd { /// Make sure there's an even number of elements in both dimensions. static_assert(AccumulatorsPerThread::kH % 2 == 0, "Invalid size"); static_assert(AccumulatorsPerThread::kW % 2 == 0, "Invalid size"); + static_assert(AccumulatorsPerThread::kH >= 2 && AccumulatorsPerThread::kW >= 2, + "HGEMM expects at least 2x2 accmulator tiles per thread."); /// Ctor. CUTLASS_DEVICE ThreadMultiplyAdd() {} @@ -84,7 +85,10 @@ struct ThreadMultiplyAdd { // The output. __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]); + CUTLASS_PRAGMA_UNROLL for (int j = 0; j < AccumulatorsPerThread::kH / 2; ++j) { + + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < AccumulatorsPerThread::kW / 2; ++i) { // The offsets in the output fragment. int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i; diff --git a/cutlass/gemm/hgemm_swizzle.h b/cutlass/gemm/hgemm_swizzle.h index 2ecd00881e..527d28b48c 100644 --- a/cutlass/gemm/hgemm_swizzle.h +++ b/cutlass/gemm/hgemm_swizzle.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/hgemm_traits.h b/cutlass/gemm/hgemm_traits.h index 2261bb4b3e..bf4b01d26c 100644 --- a/cutlass/gemm/hgemm_traits.h +++ b/cutlass/gemm/hgemm_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
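
In the hgemm_multiply_add.h hunk above, the new static_assert and the CUTLASS_PRAGMA_UNROLL annotations both follow from HGEMM issuing its FP16 math on __half2 pairs, so each thread must own an even number of accumulators in each dimension. A minimal stand-alone illustration of the packed multiply-add (illustrative only, not the CUTLASS code path; native FP16 arithmetic requires sm_53 or newer):

    #include <cuda_fp16.h>

    // One __hfma2 performs two half-precision fused multiply-adds at once:
    // d = a * b + d, applied lane-wise to the two packed halves.
    __global__ void half2_fma_sketch(const __half2* a, const __half2* b, __half2* d, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        d[i] = __hfma2(a[i], b[i], d[i]);
      }
    }
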
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -38,7 +38,7 @@ #include "cutlass/gemm/gemm_traits.h" #include "cutlass/gemm/hgemm_global_tile.h" #include "cutlass/gemm/hgemm_multiply_add.h" -#include "cutlass/gemm/hgemm_swizzle.h" +#include "cutlass/layout/thread/transform.h" namespace cutlass { namespace gemm { @@ -107,7 +107,8 @@ struct HgemmTransformerA { template struct HgemmTransformerA { - typedef HgemmSwizzle Transformer; + typedef typename Iterator_::FragmentShape FragmentShape; + typedef cutlass::layout::thread::Transform Transformer; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -122,7 +123,8 @@ struct HgemmTransformerB { template struct HgemmTransformerB { - typedef HgemmSwizzle Transformer; + typedef typename Iterator_::FragmentShape FragmentShape; + typedef cutlass::layout::thread::Transform Transformer; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass/gemm/igemm_epilogue.h b/cutlass/gemm/igemm_epilogue.h index 2ad24f32cc..5bb1aa0888 100644 --- a/cutlass/gemm/igemm_epilogue.h +++ b/cutlass/gemm/igemm_epilogue.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/igemm_global_tile.h b/cutlass/gemm/igemm_global_tile.h index 845678a82a..e169933b01 100644 --- a/cutlass/gemm/igemm_global_tile.h +++ b/cutlass/gemm/igemm_global_tile.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/igemm_multiply_add.h b/cutlass/gemm/igemm_multiply_add.h index 2b09cba20e..7892850d0f 100644 --- a/cutlass/gemm/igemm_multiply_add.h +++ b/cutlass/gemm/igemm_multiply_add.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -28,8 +28,9 @@ */ #pragma once -#include "cutlass/fragment.h" +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) +#include "cutlass/fragment.h" #include "cutlass/gemm/thread_multiply_add.h" namespace cutlass { @@ -44,6 +45,11 @@ struct ThreadMultiplyAdd typedef Shape<4, 1, 1> InstructionShape; /// Shape of the thread-level GEMM (K-by-N-by-M) typedef ThreadGemmShape_ ThreadGemmShape; + + /// Thread-level GEMM (N-by-M) must be a multiple of 32. + static_assert((ThreadGemmShape::kH * ThreadGemmShape::kW) % 32 == 0, + "Thread-level GEMM (N-by-M) must be multiple of 32"); + /// Aliased for compatibility. 
Will be removed in CUTLASS v2.0 typedef ThreadGemmShape AccumulatorsPerThread; /// The number of threads per warp. @@ -72,19 +78,18 @@ struct ThreadMultiplyAdd Accumulators const& c, Accumulators& d) { - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610) // The inputs. int const* a_int = reinterpret_cast(&a[0]); int const* b_int = reinterpret_cast(&b[0]); for (int j = 0; j < AccumulatorsPerThread::kH; ++j) { for (int i = 0; i < AccumulatorsPerThread::kW; ++i) { + asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" : "=r"(d[j * AccumulatorsPerThread::kW + i]) : "r"(a_int[i]), "r"(b_int[j]), "r"(c[j * AccumulatorsPerThread::kW + i])); } } - #endif } }; @@ -92,3 +97,5 @@ struct ThreadMultiplyAdd } // namespace gemm } // namespace cutlass + +#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) diff --git a/cutlass/gemm/igemm_swizzle.h b/cutlass/gemm/igemm_swizzle.h index fbb68d1434..0c6a2ffa51 100644 --- a/cutlass/gemm/igemm_swizzle.h +++ b/cutlass/gemm/igemm_swizzle.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -60,6 +60,7 @@ struct IgemmSwizzle { /// Transform a fragment. CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) { + // Expose src/dst as int arrays. int const* src_int = reinterpret_cast(&src[0]); int* dst_int = reinterpret_cast(&dst[0]); diff --git a/cutlass/gemm/igemm_traits.h b/cutlass/gemm/igemm_traits.h index 5bceeda92e..3f2039afe5 100644 --- a/cutlass/gemm/igemm_traits.h +++ b/cutlass/gemm/igemm_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
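
The inline PTX added to igemm_multiply_add.h above maps each accumulator update onto dp4a, which multiplies four signed 8-bit values packed into one 32-bit register of A by four packed into one register of B and adds the sum of products to a 32-bit accumulator; the guard that now wraps the whole header matches the instruction's availability on sm_61 and newer. A sketch of the same operation using the __dp4a intrinsic, with a scalar reference alongside (illustrative only, not the CUTLASS kernel):

    #include <cstdint>

    // Reference for one dp4a step: d = c + sum over k of a_byte[k] * b_byte[k].
    __device__ int dp4a_reference(int a_packed, int b_packed, int c) {
      int d = c;
      for (int k = 0; k < 4; ++k) {
        int8_t a_k = static_cast<int8_t>((a_packed >> (8 * k)) & 0xff);
        int8_t b_k = static_cast<int8_t>((b_packed >> (8 * k)) & 0xff);
        d += static_cast<int>(a_k) * static_cast<int>(b_k);
      }
      return d;
    }

    __global__ void dp4a_sketch(const int* a, const int* b, int* d, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
    #if __CUDA_ARCH__ >= 610
        d[i] = __dp4a(a[i], b[i], d[i]);          // hardware path, as used by IGEMM
    #else
        d[i] = dp4a_reference(a[i], b[i], d[i]);  // fallback for older architectures
    #endif
      }
    }
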
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -39,7 +39,7 @@ #include "cutlass/gemm/igemm_epilogue.h" #include "cutlass/gemm/igemm_global_tile.h" #include "cutlass/gemm/igemm_multiply_add.h" -#include "cutlass/gemm/igemm_swizzle.h" +#include "cutlass/layout/thread/transform.h" #include "cutlass/reshape_tile.h" namespace cutlass { @@ -90,9 +90,10 @@ struct IgemmConfig : public GemmConfig< /// kResidueSeparate false, /// kResidueInPrologue - false, + true, /// kLaunchBounds - false> {}; + false> +{}; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -380,7 +381,8 @@ struct IgemmTransformerA { template struct IgemmTransformerA { - typedef IgemmSwizzle Transformer; + typedef typename Iterator_::FragmentShape FragmentShape; + typedef cutlass::layout::thread::Transform Transformer; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -395,7 +397,8 @@ struct IgemmTransformerB { template struct IgemmTransformerB { - typedef IgemmSwizzle Transformer; + typedef typename Iterator_::FragmentShape FragmentShape; + typedef cutlass::layout::thread::Transform Transformer; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass/gemm/linear_scaling.h b/cutlass/gemm/linear_scaling.h index a12fc5f19f..e747b218c4 100644 --- a/cutlass/gemm/linear_scaling.h +++ b/cutlass/gemm/linear_scaling.h @@ -1,6 +1,5 @@ - /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/linear_scaling_device_ptr.h b/cutlass/gemm/linear_scaling_device_ptr.h index 5dc845da4a..928b9700ec 100644 --- a/cutlass/gemm/linear_scaling_device_ptr.h +++ b/cutlass/gemm/linear_scaling_device_ptr.h @@ -1,5 +1,6 @@ + /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/mma_epilogue.h b/cutlass/gemm/mma_epilogue.h new file mode 100644 index 0000000000..0e7004d73f --- /dev/null +++ b/cutlass/gemm/mma_epilogue.h @@ -0,0 +1,284 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory + with + the computed matrix product. +*/ + +#pragma once + +// clang-format off + +#include "cutlass/coord.h" + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMAEpilogue { + /// The traits class. + typedef EpilogueTraits_ Traits; + + /// The params. + typedef typename Traits::Params Params; + + /// The shared storage. + typedef typename Traits::SharedStorage SharedStorage; + + /// Defines a tiling of the EpilogueTile over the entire threadblock GEMM tile + typedef typename Traits::Iterations Iterations; + + /// The output tile. + typedef typename Traits::OutputTile OutputTile; + + /// Accumulators to store in the epilogue + typedef typename Traits::Accumulators Accumulators; + + /// A functor to copy a slice of accumulators for a given epilogue iteration + typedef typename Traits::SelectAccumulators SelectAccumulators; + + /// The iterator to load source matrix from global memory. + typedef typename Traits::GlobalLoadStreamC GlobalLoadStreamC; + + /// The iterator to store the final GEMM computation to global memory. + typedef typename Traits::GlobalStoreStreamD GlobalStoreStreamD; + + /// The stream to store matrix product to shared memory + typedef typename Traits::SharedStoreStreamD SharedStoreStreamD; + + /// The stream to load the matrix product from shared memory + typedef typename Traits::SharedLoadStreamD SharedLoadStreamD; + + /// The functor in charge of the math. + typedef typename Traits::Functor Functor; + + /// The scalar type used by the epilogue functor. + typedef typename Functor::Scalar Scalar; + + /// The scalar type of the source accumulator matrix. + typedef typename Traits::ScalarC ScalarC; + + /// The scalar type of the destination accumulator matrix. + typedef typename Traits::ScalarD ScalarD; + + /// The index type. + typedef typename Traits::Index Index; + + /// Functor computing the offset from the threadblock origin per iteration of + /// the epilogue. + typedef typename Traits::GlobalOffset GlobalOffset; + + /// + typedef typename Traits::GlobalDataLayout GlobalDataLayout; + + // + // Data members + // + + /// The params. + Params const& params; + + /// The shared storage. + SharedStorage& shared_storage; + + /// The dimensions of the GEMM. 
+ gemm::GemmCoord problem_size; + + /// Epilogue functor + Functor functor; + + // Functor to select a set of accumulators + SelectAccumulators select_accumulators; + + + // Functor to compute the global offset relative to the threadblock for each iteration + // of the epilogue. + GlobalOffset global_offset; + + // + // Methods + // + + /// Ctor. + CUTLASS_DEVICE MMAEpilogue( + Params const& params_, + SharedStorage& shared_storage_, + Coord<3> const& _problem_size, + SelectAccumulators _select_accumulators = SelectAccumulators(), + GlobalOffset _global_offset = GlobalOffset() + ): + params(params_), + shared_storage(shared_storage_), + problem_size(_problem_size), + functor(params_.functor), + select_accumulators(_select_accumulators), + global_offset(_global_offset) {} + + /// Execute the epilogue. + CUTLASS_DEVICE void epilogue( + Accumulators& accumulators, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0), + int batch_id = 0) { + + if (functor.source_required()) { + epilogue_with_or_without_beta(accumulators, threadblock_offset, batch_id); + } + else { + epilogue_with_or_without_beta(accumulators, threadblock_offset, batch_id); + } + } + + /// + + /// Execute the epilogue. + template + CUTLASS_DEVICE void epilogue_with_or_without_beta( + Accumulators& accumulators, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0), + int batch_id = 0) { + + /// Global memory mapping function + GlobalDataLayout gmem_map_func; + + // Construct shared memory streams + SharedStoreStreamD shared_store_stream( + params.shared_store_stream_d, + shared_storage.reference()); + + SharedLoadStreamD shared_load_stream( + params.shared_load_stream_d, + shared_storage.reference()); + + // Map the GEMM problem dimensions into the coordinate system of the output memory + Coord<2> gmem_bounds = gmem_map_func(make_Coord( + problem_size.m(), // GEMM M - rows + problem_size.n())); // GEMM N - columns + + Coord<3> gmem_tile_bounds = make_Coord( + problem_size.k(), // GEMM K + gmem_bounds[0], // strided + gmem_bounds[1]); // contiguous + + // Iterate over the entire Threadblock tile + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + if (!(h == 0)) { + //continue; + } + + // Offset in GEMM coordinates + gemm::GemmCoord offset_in_gemm = threadblock_offset + global_offset(make_Coord(h, w)); + + Coord<2> offset_in_memory = gmem_map_func( + make_Coord( + offset_in_gemm.m(), // GEMM M - rows + offset_in_gemm.n())); // GEMM N - columns + + // Offset in + Coord<3> global_tile_offset = make_Coord( + offset_in_gemm.k(), // GEMM K + offset_in_memory[0], // strided + offset_in_memory[1]); // contiguous + + GlobalLoadStreamC global_load_stream( + params.load_stream_c, + gmem_tile_bounds, + global_tile_offset); + + GlobalStoreStreamD global_store_stream( + params.store_stream_d, + gmem_tile_bounds, + global_tile_offset); + + // update C pointer offset based on batch_id and batch_stride_offset + global_load_stream.iterator.add_pointer_offset(batch_id * params.batch_stride_C); + + // update D pointer offset based on batch_id and batch_stride_offset + global_store_stream.iterator.add_pointer_offset(batch_id * params.batch_stride_D); + + // Load the C matrix into fragment. + if (kSourceRequired) { + global_load_stream.copy(); + } + + // Make sure we can write to shared memory. 
+ shared_load_fence(); + + // Store accumulator tile to shared memory + shared_store_stream.copy( + select_accumulators(accumulators, make_Coord(h, w))); + + shared_store_stream.commit(); + + // Make sure the data is in shared memory. + shared_store_fence(); + + // Load the accumulators back to registers from shared memory. + shared_load_stream.copy(); + shared_load_stream.commit(); + // Commit the C matrix fragment + if (kSourceRequired) { + global_load_stream.commit(); + } + + // Apply epilogue functor + if (kSourceRequired) { + + functor.evaluate(shared_load_stream.fragment(), + global_load_stream.fragment(), + global_store_stream.fragment()); + } + else { + + functor.evaluate( + shared_load_stream.fragment(), + global_store_stream.fragment()); + } + + global_store_stream.copy(); + global_store_stream.commit(); + } + } + } + + /// The memory fence for shared loads. + CUTLASS_DEVICE void shared_load_fence() { __syncthreads(); } + + /// The memory fence for shared stores. + CUTLASS_DEVICE void shared_store_fence() { __syncthreads(); } + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // gemm +} // namespace cutlass + +// clang-format on diff --git a/cutlass/gemm/mma_global_stream.h b/cutlass/gemm/mma_global_stream.h new file mode 100644 index 0000000000..c83c154a5b --- /dev/null +++ b/cutlass/gemm/mma_global_stream.h @@ -0,0 +1,360 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements efficient loading of the thread block-level tile from global memory and + storing to shared memory. 
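
MMAEpilogue above round-trips each accumulator tile through shared memory (store, fence, reload) so that threads end up holding output elements in an order that allows coalesced global stores, and only then applies the epilogue functor; the source_required() test lets the beta == 0 case skip loading C entirely. The functor itself is ordinary linear scaling. A minimal element-wise statement of that update (an illustrative kernel, not the CUTLASS functor):

    // D = alpha * accumulator + beta * C, evaluated element-wise.
    __global__ void linear_scaling_sketch(float alpha, const float* accum,
                                          float beta, const float* c,
                                          float* d, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        float source = (beta != 0.f) ? beta * c[i] : 0.f;   // skip C when beta == 0
        d[i] = alpha * accum[i] + source;
      }
    }
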
+*/ + +#pragma once + +// clang-format off + +#include "cutlass/convert.h" +#include "cutlass/gemm/gemm_operand.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tile_allocation.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +///! Stream adapter for loading threadblock-scoped GEMM tiles and storing to shared memory +template < + /// Identifies multiplicand + GemmOperand::Kind Operand, + /// Layout of source matrix in global memory + MatrixLayout::Kind Layout, + /// Iterator for loading threadblock-scoped tiles + typename LoadIterator_, + /// Transformation functor for transforming fragments + typename Transformer_, + /// Iterator for storing threadblock-scoped tiles to shared memory + typename StoreIterator_, + /// Number of stores before iterator wraps - zero indicates no wrapping + int StageCount> +struct MMAGlobalLoadStream { + // + // Type definitions + // + + /// Identifies the operand + static GemmOperand::Kind const kOperand = Operand; + /// The layout. + static MatrixLayout::Kind const kLayout = Layout; + /// The load iterator. + typedef LoadIterator_ LoadIterator; + /// The transformer. + typedef Transformer_ Transformer; + /// The store iterator to write to shared memory. + typedef StoreIterator_ StoreIterator; + /// Number of stages + static int const kStageCount = StageCount; + + /// Predicate vector + typedef typename LoadIterator::PredicateVector PredicateVector; + /// The fragment that is copied from shared memory. + typedef typename LoadIterator::Fragment FetchedFragment; + /// The fragment that is obtained after the transformation by the transformer. + typedef typename Transformer::OutputFragment TransformedFragment; + /// Make sure the fragments match. + static_assert((platform::is_same::value), + ""); + /// The output fragment. + typedef TransformedFragment Fragment; + /// Make sure the transformed fragment is the same as the store fragment. + static_assert((platform::is_same::value), + ""); + + /// The scalar type of the iterator. + typedef typename LoadIterator::Scalar Scalar; + /// The pointer. + typedef typename LoadIterator::Pointer Pointer; + /// The index. + typedef typename LoadIterator::Index Index; + /// The index. + typedef typename LoadIterator::LongIndex LongIndex; + /// The tile. + typedef typename LoadIterator::Tile Tile; + + /// The params. + struct Params { + + /// Helper + static int const kElementsPerLdg = LoadIterator::Tile::kC; + + // + // Data members + // + + /// The load iterator. + typename LoadIterator::Params load_iterator; + + /// Stride within a batch of matrix operands + LongIndex batch_stride; + + // Offset to residue. 
+ Index offset_to_residue; + + // Offset to residue for the last partition + Index offset_to_residue_last_partition; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): batch_stride(0), offset_to_residue(0), offset_to_residue_last_partition(0) {} + + /// Constructor + CUTLASS_HOST_DEVICE + Params( + TensorRef const &ref, + Index _offset_to_residue + ): + batch_stride(0), + offset_to_residue(_offset_to_residue), + offset_to_residue_last_partition(0), + load_iterator( + TensorRef( + ref.data(), + make_Coord(ref.stride(0) * kElementsPerLdg, ref.stride(0), kElementsPerLdg, 1) + ) + ) {} + + /// Initializer + CUTLASS_HOST_DEVICE + int initialize( + TensorRef const &ref, + LongIndex batch_stride_, + Index offset_to_residue_, + Index offset_to_residue_last_partition_) { + + batch_stride = batch_stride_; + offset_to_residue = offset_to_residue_; + offset_to_residue_last_partition = offset_to_residue_last_partition_; + + return load_iterator.initialize( + TensorRef( + ref.data(), + make_Coord(static_cast(batch_stride), ref.stride(0), kElementsPerLdg, 1) + ) + ); + } + + CUTLASS_HOST_DEVICE + int initialize( + TensorRef const &ref, + Index offset_to_residue_) { + + offset_to_residue = offset_to_residue_; + return load_iterator.initialize( + TensorRef( + ref.data(), + make_Coord(ref.stride(0) * kElementsPerLdg, ref.stride(0), kElementsPerLdg, 1) + ) + ); + } + + CUTLASS_HOST_DEVICE int initialize(Index offset_to_residue_) { + offset_to_residue = offset_to_residue_; + return 0; + } + + CUTLASS_DEVICE Index get_offset_to_residue() { + if (blockIdx.z == gridDim.z - 1) { //last partition + return offset_to_residue_last_partition; + } + else { + return offset_to_residue; + } + } + }; + + /// Empty shared storage + struct SharedStorage {}; + + /// Shared memory allocation for the tile + typedef TileAllocation< + typename StoreIterator::Scalar, + typename ShapeMul< + typename StoreIterator::OperandShape, + Shape + >::Shape + > ThreadblockTileStorage; + + /// ZipTensorRef to threadblock tiles + typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef; + + // + // Data members + // + + ///! The parameters + Params params; + + ///! Dimensions of global memory tile + Coord<3> threadblock_offset; + + ///! Dimensions of multiplicand bounds + Coord<3> multiplicand_bounds; + + ///! Iterator to load threadblock tiles from global memory + LoadIterator load_iterator; + + ///! Predicate vector + PredicateVector predicates; + + ///! The fragment to fetch from shared memory. + FetchedFragment fetched_fragment; + + ///! Functor to transform fragments after they have been loaded + Transformer transformer; + + ///! The fragment to convert the data after it has been fetched from shared memory. + TransformedFragment transformed_fragment; + + ///! Iterator to store threadblock tiles to shared memory + StoreIterator store_iterator; + + ///! 
Counter + int stage_index; + + // + // Static member functions + // + + /// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory + CUTLASS_HOST_DEVICE + static Coord<3> project_coordinate(Coord<3> const &coord, Index d_offset = 0) { + bool const kKstrided = + gemm::GemmMultiplicandTraits::kKstrided; + + Coord<3> tile_coord = gemm::ProjectOperand::project(coord); + + return make_Coord( + tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC); + } + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE MMAGlobalLoadStream(Params const &_params, + SharedStorage &shared_storage, + ThreadblockTileRef const &threadblock_tile_ref, + Coord<3> const bounds, + Coord<3> const &block) + : params(_params), + threadblock_offset(project_coordinate(block)), + multiplicand_bounds(project_coordinate(bounds, 1)), + load_iterator(params.load_iterator, threadblock_offset), + transformer(), + store_iterator(threadblock_tile_ref.data()), + stage_index(0) { + load_iterator.initialize_predicates( + predicates.begin(), multiplicand_bounds, threadblock_offset); + } + + /// Loads the data from global memory + CUTLASS_DEVICE void copy() { + load_iterator.load_post_increment(fetched_fragment, predicates.begin()); + } + + /// Transform and commit the data to shared memory + CUTLASS_DEVICE void commit() { + transformer.transform(fetched_fragment, transformed_fragment); + store_iterator.store_post_increment(transformed_fragment); + + ++stage_index; + if (kStageCount && stage_index == kStageCount) { + store_iterator -= kStageCount; + stage_index = 0; + } + } + + /// Computes a predicate mask for loads during final threadblock tile load iteration + CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) { + // That's the residue! + Coord<3> _block_offset = threadblock_offset; + if (kOperand == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor) { + // K-strided + _block_offset = + make_Coord(threadblock_offset[0], multiplicand_bounds[1] - k, threadblock_offset[2]); + } else { + // K-contiguous + _block_offset = make_Coord(threadblock_offset[0], + threadblock_offset[1], + multiplicand_bounds[2] - k / LoadIterator::Tile::kC); + } + + load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, _block_offset); + fetched_fragment.clear(); + } + + /// Move to the residue portion. + CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) { + Index kResidue = k % kTileK; + if (kResidue) { + residue(kResidue); + Index this_offset_residue = params.get_offset_to_residue(); + load_iterator.add_pointer_offset(this_offset_residue * load_iterator.stride_advance()); + } + } + + /// Rollback to the beginning of the first tile + CUTLASS_DEVICE void rollback(void) { + load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, threadblock_offset); + + int const kBlock = kOperand == GemmOperand::kA + ? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW) + : (kLayout == MatrixLayout::kRowMajor ? 
Tile::kH : Tile::kW); + Index this_offset_residue = params.get_offset_to_residue(); + load_iterator.add_pointer_offset(-(this_offset_residue + kBlock) * + load_iterator.stride_advance()); + } + + /// Adds a Coord<3> to the underlying global load iterator + CUTLASS_DEVICE MMAGlobalLoadStream &operator+=(Coord<3> const &offset) { + load_iterator += offset; + return *this; + } + + /// Adds an offset based on batch stride + CUTLASS_DEVICE MMAGlobalLoadStream &add_batch_offset(int batch_id) { + load_iterator.add_pointer_offset(batch_id * params.batch_stride); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // gemm +} // namespace cutlass + +// clang-format on diff --git a/cutlass/gemm/mma_global_tile.h b/cutlass/gemm/mma_global_tile.h new file mode 100644 index 0000000000..ae3a91250c --- /dev/null +++ b/cutlass/gemm/mma_global_tile.h @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/gemm/gemm_operand.h" +#include "cutlass/reshape_tile.h" +#include "cutlass/tile_iterator.h" +#include "cutlass/util/platform.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Iterators used to load multiplicands from global memory specialized for Volta884 access patterns +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterator for loading data for congruous access patterns +template +struct MMAThreadblockCongruousLoad { + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = Operand; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = + (Operand == GemmOperand::kA ? MatrixLayout::kColumnMajor : MatrixLayout::kRowMajor); + + /// Shape of thread-block multiplicand + typedef Tile_ Tile; + + /// Number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp accumulator tiles along the outer dimension + static int const kWarpDelta = WarpDelta; + + /// This implementation is specialized for 128b loads + static int const kAccessSize = 8; + + /// Projects the threadblock tile + typedef typename gemm::GemmMultiplicandTraits::Shape OperandShape; + + /// Reshapes the threadblock tile by access size + typedef typename ReshapeTile::Tile VectorizedShape; + + /// Shape of tile + typedef Shape<1, 4, 8> WarpStoreCoverage; + + /// Shape of tile loaded by each warp per load operation + typedef Shape<1, 4, 8> WarpLoadShape; + + // + // Load iterator + // + + /// + typedef Shape<1, WarpLoadShape::kH * kWarpCount, WarpLoadShape::kW> Delta; + + typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides; + + /// Rakes warps along contiguous dimensions and strip-mines strided + /// dimension. + typedef Shape<1, + VectorizedShape::kH / WarpStoreCoverage::kH / WarpCount, + VectorizedShape::kW / WarpStoreCoverage::kW, + 1> + Iterations; + + /// Functor computing starting offset for each thread + struct ThreadOffset { + __device__ Coord<4> operator()() const { + int warp_id = (threadIdx.x >> 5); + int lane_id = (threadIdx.x & 0x1f); + + int lane_k = lane_id / WarpLoadShape::kW; + int lane_outer = lane_id % WarpLoadShape::kW; + + Coord<4> offset = make_Coord(0, warp_id * WarpLoadShape::kH + lane_k, lane_outer, 0); + + return offset; + } + }; + + /// Source tile traits + typedef TileTraits LoadTileTraits; + + /// Load iterator + typedef TileLoadIterator Iterator; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterator for loading data for congruous access patterns +template +struct MMAThreadblockCrosswiseLoad { + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = Operand; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = + (Operand == GemmOperand::kA ? 
MatrixLayout::kRowMajor : MatrixLayout::kColumnMajor); + + /// Shape of thread-block multiplicand + typedef Tile_ Tile; + + /// Number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp accumulator tiles along the outer dimension + static int const kWarpDelta = WarpDelta; + + /// This implementation is specialized for 128b loads + static int const kAccessSize = 8; + + /// Projects the threadblock tile + typedef typename gemm::GemmMultiplicandTraits::Shape OperandShape; + + /// Reshapes the threadblock tile by access size + typedef typename ReshapeTile::Tile VectorizedShape; + + /// Shape of tile + typedef Shape<1, 8, 4> WarpStoreCoverage; + + /// Shape of tile loaded by each warp per load operation + typedef Shape<1, 8, 4> WarpLoadShape; + + // + // Load iterator + // + + /// + typedef Shape<1, WarpLoadShape::kH, WarpLoadShape::kW> Delta; + + typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides; + + /// Rakes warps along contiguous dimensions and strip-mines strided + /// dimension. + typedef Shape<1, + VectorizedShape::kH / WarpStoreCoverage::kH / WarpCount, + VectorizedShape::kW / WarpStoreCoverage::kW, + 1> + Iterations; + + /// Functor computing starting offset for each thread + struct ThreadOffset { + __device__ Coord<4> operator()() const { + + int warp_id = (threadIdx.x >> 5); + int lane_id = (threadIdx.x & 0x1f); + + int lane_k = lane_id % WarpLoadShape::kW; + int lane_outer = lane_id / WarpLoadShape::kW; + + Coord<4> offset = + make_Coord(0, warp_id * Iterations::kH * WarpLoadShape::kH + lane_outer, lane_k, 0); + + return offset; + } + }; + + /// Source tile traits + typedef TileTraits LoadTileTraits; + + /// Load iterator + typedef TileLoadIterator Iterator; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // gemm +} // namespace cutlass diff --git a/cutlass/gemm/mma_shared_stream.h b/cutlass/gemm/mma_shared_stream.h new file mode 100644 index 0000000000..af11ebab1e --- /dev/null +++ b/cutlass/gemm/mma_shared_stream.h @@ -0,0 +1,155 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements efficient loading of the thread block-level tile from global memory and + storing to shared memory. +*/ + +#pragma once + +#include "cutlass/convert.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Stream from shared memory to fragments for warp-level matrix multiply-accumulate +template < + /// The load iterator. + typename Iterator_, + /// The transformer to be applied after the data has been copied from shared memory. + typename Transformer_ = Copy, + /// Number of increments before iterator wraps - zero indicates no wrapping + int StageCount = 1> +struct MMASharedLoadStream { + /// The load iterator. + typedef Iterator_ Iterator; + /// The transformer. + typedef Transformer_ Transformer; + + /// Number of increments before iterator wraps - zero indicates no wrapping + static int const kStageCount = StageCount; + + /// The fragment that is copied from shared memory. + typedef typename Iterator::Fragment FetchedFragment; + /// The fragment that is obtained after the transformation by the transformer. + typedef typename Transformer::OutputFragment TransformedFragment; + /// Make sure the fragments match. + static_assert((platform::is_same::value), + ""); + /// The output fragment. + typedef TransformedFragment Fragment; + + /// Element type + typedef typename Iterator::Scalar Scalar; + + /// Reference type to a tensor + typedef TensorRef TensorRef; + + /// Parameters passed from host + struct Params {}; + + // + // Data members + // + + /// Iterator for loading fragments for warp-level matrix multiply-accumulate + Iterator iterator; + + /// Fetched fragment + FetchedFragment fetched[2]; + + /// The transformer. + Transformer transformer; + + /// Transformed fragment + TransformedFragment transformed[2]; + + /// Counts the number of stages + int stage_index; + + // + // Methods + // + + /// Ctor. + CUTLASS_DEVICE MMASharedLoadStream() : stage_index(0) {} + + /// Ctor. + CUTLASS_DEVICE MMASharedLoadStream( + Params const &_params, + TensorRef const &ref, + Coord<4> warp_offset = make_Coord(0, 0, 0, 0) + ): + iterator(ref.data(), warp_offset), stage_index(0) { + + } + + /// Load the data from shared memory to the fetch fragment. + CUTLASS_DEVICE void copy(int step) { + iterator.load( + fetched[step % 2], + make_Coord(step + stage_index * Iterator::VectorizedShape::kD, 0, 0, 0) + ); + } + + /// Commit the data. 
+ CUTLASS_DEVICE void commit(int step) { + transformer.transform(fetched[step % 2], transformed[step % 2]); + } + + /// + CUTLASS_DEVICE void clear() { + fetched[0].clear(); + fetched[1].clear(); + transformed[0].clear(); + transformed[1].clear(); + } + + /// Gets the transformed fragment + CUTLASS_DEVICE + TransformedFragment &fragment(int step) { return transformed[step % 2]; } + + /// Gets the transformed fragment + CUTLASS_DEVICE + TransformedFragment const &fragment(int step) const { return transformed[step % 2]; } + + /// Increment the stage. + CUTLASS_DEVICE void inc_stage() { + + ++stage_index; + if (kStageCount && stage_index == StageCount) { + stage_index = 0; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // gemm +} // namespace cutlass diff --git a/cutlass/gemm/scalar_or_pointer.h b/cutlass/gemm/scalar_or_pointer.h index 7c4b4b75d0..9e29295141 100644 --- a/cutlass/gemm/scalar_or_pointer.h +++ b/cutlass/gemm/scalar_or_pointer.h @@ -1,6 +1,5 @@ - /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/sgemm_traits.h b/cutlass/gemm/sgemm_traits.h index 8ce7f58e26..6c54756e30 100644 --- a/cutlass/gemm/sgemm_traits.h +++ b/cutlass/gemm/sgemm_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/thread_multiply_add.h b/cutlass/gemm/thread_multiply_add.h index b95dee58a0..784377ae5d 100644 --- a/cutlass/gemm/thread_multiply_add.h +++ b/cutlass/gemm/thread_multiply_add.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
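
MMASharedLoadStream above keeps a pair of fetched fragments and a pair of transformed fragments and indexes both with step % 2, so the shared-memory load for warp-GEMM step s+1 can be in flight while step s is being consumed; inc_stage() wraps the stage counter once a full threadblock tile has been read. A stripped-down version of that two-deep ring (hypothetical fragment size and names, not the CUTLASS types):

    // Two-deep fragment ring: copy(step) fills slot step % 2, fragment(step) reads it back.
    struct FragmentRingSketch {
      static constexpr int kElements = 8;   // hypothetical fragment size
      float fetched[2][kElements];
      float transformed[2][kElements];

      __device__ void copy(const float* smem, int step) {
        for (int i = 0; i < kElements; ++i) {
          fetched[step % 2][i] = smem[step * kElements + i];   // load fragment for this step
        }
      }
      __device__ void commit(int step) {
        for (int i = 0; i < kElements; ++i) {
          transformed[step % 2][i] = fetched[step % 2][i];     // identity "Copy" transformer
        }
      }
      __device__ const float* fragment(int step) const {
        return transformed[step % 2];
      }
    };
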
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -73,16 +73,27 @@ struct ThreadMultiplyAdd { FragmentB const& b, Accumulators const& c, Accumulators& d) { + if(kLayout_ == MatrixLayout::kColumnMajor) { + + CUTLASS_PRAGMA_UNROLL for (int j = 0; j < AccumulatorsPerThread::kH; ++j) { + + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < AccumulatorsPerThread::kW; ++i) { + d[j * AccumulatorsPerThread::kW + i] = a[i] * b[j] + c[j * AccumulatorsPerThread::kW + i]; } } } else { + + CUTLASS_PRAGMA_UNROLL for(int i = 0; i < AccumulatorsPerThread::kW; ++i) { + + CUTLASS_PRAGMA_UNROLL for(int j = 0; j < AccumulatorsPerThread::kH; ++j) { + d[i * AccumulatorsPerThread::kH + j] = a[i] * b[j] + c[i * AccumulatorsPerThread::kH + j]; } } diff --git a/cutlass/gemm/threadblock_swizzle.h b/cutlass/gemm/threadblock_swizzle.h index eab8595a68..737b89a9cf 100644 --- a/cutlass/gemm/threadblock_swizzle.h +++ b/cutlass/gemm/threadblock_swizzle.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/gemm/volta884_complex_gemm_epilogue_traits.h b/cutlass/gemm/volta884_complex_gemm_epilogue_traits.h new file mode 100644 index 0000000000..059dcaeeeb --- /dev/null +++ b/cutlass/gemm/volta884_complex_gemm_epilogue_traits.h @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory + with the computed matrix product. +*/ + +#pragma once + +// clang-format off + +#include "cutlass/zip_fragment.h" +#include "cutlass/zip_tile_iterator.h" +#include "cutlass/util/complex.h" +#include "cutlass/gemm/volta884_gemm_epilogue_traits.h" +#include "cutlass/gemm/split_complex_linear_scaling.h" +#include "cutlass/util/pair.h" + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enables treating the accumulators selection as one object +template +struct ZipSelectAccumulators { + + /// Underlying selection function + typedef First_ First; + typedef Second_ Second; + + /// Accumulators + typedef ZipFragment< + typename First::Accumulators, + typename Second::Accumulators> Accumulators; + + /// Fragment + typedef ZipFragment< + typename First::Fragment, + typename Second::Fragment> Fragment; + + // + // Data members + // + + /// Selects the accumulators for the first part + First first; + + /// Selects the accumulators for the second + Second second; + + // + // Methods + // + + /// Default ctor + CUTLASS_DEVICE + ZipSelectAccumulators() { } + + /// Basic constructor + CUTLASS_DEVICE + ZipSelectAccumulators(First const &_first, Second const &_second): first(_first), second(_second) { } + + /// Selects accumulators for a given iteration of the epilogue + CUTLASS_DEVICE + Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const { + return make_ZipFragment(first(accum.first, idx), second(accum.second, idx)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines epilogue traits for complex-valued mma.sync GEMM +template < + typename GemmConfig_, + typename EpilogueFunctor_ = SplitComplexLinearScaling, + typename Index_ = int> +struct Volta884ComplexGemmEpilogueTraits { + + /// GEMM configuration + typedef GemmConfig_ GemmConfig; + + /// Epilogue functor + typedef EpilogueFunctor_ Functor; + + /// Global memory mapping function + typedef MatrixLayout::ColumnMajor GlobalDataLayout; + + /// Index type + typedef Index_ Index; + + /// Long index used for offsets + typedef long long LongIndex; + + /// Defines epilogue traits for real-valued Volta884 GEMM epilogue + typedef typename Volta884GemmEpilogueTraitsHelper< + GemmConfig, + Functor, + typename GemmConfig::MultiplyAdd::RealMultiplyAdd, + Index>::EpilogueTraits RealEpilogueTraits; + + /// The output tile. + typedef typename RealEpilogueTraits::OutputTile OutputTile; + + /// The warp-level GEMM tile + typedef typename RealEpilogueTraits::WarpGemmTile WarpGemmTile; + + /// Tiling of warp accumulator elements + typedef typename RealEpilogueTraits::WarpGemmTile WarpDelta; + + /// Multiply-add operation + typedef typename GemmConfig::MultiplyAdd MultiplyAdd; + + /// The accumulators fragment type. + typedef typename MultiplyAdd::Accumulators Accumulators; + + /// Selects a subset of accumulators for a given epilogue iteration + typedef ZipSelectAccumulators< + typename RealEpilogueTraits::SelectAccumulators, + typename RealEpilogueTraits::SelectAccumulators> SelectAccumulators; + + /// The iterator to load source matrix from global memory. 
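// Illustrative sketch, not part of the patch above: the "zip" pattern used by
// ZipSelectAccumulators (and by ZipFragment / ZipTileIterator elsewhere in this file)
// pairs a real-plane object with an imaginary-plane object and forwards every
// operation to both halves, so split-complex data can reuse the real-valued epilogue
// machinery. The names below (SimpleZip, select_both) are hypothetical stand-ins.
template <typename First, typename Second>
struct SimpleZip {
  First first;    // acts on the real plane
  Second second;  // acts on the imaginary plane
};

// Applying a zipped selector means applying each half to its own plane with the same
// index and zipping the two results back together.
template <typename Selector>
SimpleZip<typename Selector::Fragment, typename Selector::Fragment> select_both(
    SimpleZip<Selector, Selector> const &sel,
    SimpleZip<typename Selector::Accumulators, typename Selector::Accumulators> const &acc,
    int idx) {
  typedef typename Selector::Fragment Frag;
  return SimpleZip<Frag, Frag>{ sel.first(acc.first, idx), sel.second(acc.second, idx) };
}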
+ typedef cutlass::PredicatedTileLoadStream< + ZipTileIterator< + typename RealEpilogueTraits::GlobalLoadStreamC::Iterator, + typename RealEpilogueTraits::GlobalLoadStreamC::Iterator + >, + typename RealEpilogueTraits::GlobalLoadStreamC::PredicateFunctor, + ZipConvert< + typename RealEpilogueTraits::GlobalLoadStreamC::Transformer, + typename RealEpilogueTraits::GlobalLoadStreamC::Transformer + > + > GlobalLoadStreamC; + + /// The iterator to store the final GEMM computation to global memory. + typedef cutlass::PredicatedTileStoreStream< + ZipTileIterator< + typename RealEpilogueTraits::GlobalStoreStreamD::Iterator, + typename RealEpilogueTraits::GlobalStoreStreamD::Iterator + >, + typename RealEpilogueTraits::GlobalStoreStreamD::PredicateFunctor, + ZipConvert< + typename RealEpilogueTraits::GlobalStoreStreamD::Transformer, + typename RealEpilogueTraits::GlobalStoreStreamD::Transformer + > + > GlobalStoreStreamD; + + /// The stream to store matrix product to shared memory + typedef cutlass::TileStoreStream< + ZipTileIterator< + typename RealEpilogueTraits::SharedStoreStreamD::Iterator, + typename RealEpilogueTraits::SharedStoreStreamD::Iterator + >, + ZipConvert< + typename RealEpilogueTraits::SharedStoreStreamD::Transformer, + typename RealEpilogueTraits::SharedStoreStreamD::Transformer + > + > SharedStoreStreamD; + + /// The stream to load the matrix product from shared memory + typedef cutlass::TileLoadStream< + ZipTileIterator< + typename RealEpilogueTraits::SharedLoadStreamD::Iterator, + typename RealEpilogueTraits::SharedLoadStreamD::Iterator + >, + ZipConvert< + typename RealEpilogueTraits::SharedLoadStreamD::Transformer, + typename RealEpilogueTraits::SharedLoadStreamD::Transformer + > + > SharedLoadStreamD; + + /// The scalar type of the source accumulator matrix. + typedef typename RealEpilogueTraits::ScalarC ScalarC; + + /// The scalar type of the destination accumulator matrix. + typedef typename RealEpilogueTraits::ScalarD ScalarD; + + // + // Dependent types + // + + /// Cover an entire warp-level tile + typedef typename RealEpilogueTraits::Iterations Iterations; + + /// Parameters structure initialized on the host + struct Params { + /// The params for the C iterator. + typename GlobalLoadStreamC::Params load_stream_c; + + /// The params for the D global iterator. + typename GlobalStoreStreamD::Params store_stream_d; + + /// Epilogue functor params + typename Functor::Params functor; + + /// The params for the D shared store iterator. + typename SharedStoreStreamD::Params shared_store_stream_d; + + /// The params for the D shared load stream. + typename SharedLoadStreamD::Params shared_load_stream_d; + + /// Stride for C + platform::Pair batch_stride_C; + + /// Stride for D + platform::Pair batch_stride_D; + + // + // Methods + // + + /// Default constructor + CUTLASS_HOST_DEVICE + Params() { + batch_stride_C.first = 0; + batch_stride_C.second = 0; + + batch_stride_D.first = 0; + batch_stride_D.second = 0; + } + + /// Setup the params. + CUTLASS_HOST_DEVICE int initialize( + platform::complex alpha, + platform::complex beta, + ScalarC const* real_C, + Index real_ldc, + ScalarC const* imag_C, + Index imag_ldc, + ScalarD* real_D, + Index real_ldd, + ScalarD* imag_D, + Index imag_ldd) { + + int result = functor.initialize(alpha, beta); + if (result) { + return result; + } + + // Setup the params for the global memory iterator for C. 
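// (Assumed memory layout, for clarity: the C operand is planar / split-complex, i.e.
//  two disjoint column-major allocations rather than one interleaved buffer,
//    real(i, j) = real_C[i + j * real_ldc]
//    imag(i, j) = imag_C[i + j * imag_ldc]
//  which is why the real and imaginary planes are configured below through
//  independent pointer / leading-dimension pairs.)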
+ result = load_stream_c.iterator.first.initialize( + real_C, real_ldc, real_ldc, 1); + + if (result) { + return result; + } + + result = load_stream_c.iterator.second.initialize( + imag_C, imag_ldc, imag_ldc, 1); + + if (result) { + return result; + } + + // Setup the params for the global memory iterator for D. + result = store_stream_d.iterator.first.initialize( + real_D, real_ldd, real_ldd, 1); + + if (result) { + return result; + } + + result = store_stream_d.iterator.second.initialize( + imag_D, imag_ldd, imag_ldd, 1); + + if (result) { + return result; + } + + return result; + } + + /// Setup the params. + CUTLASS_HOST_DEVICE int initialize( + platform::complex alpha, + platform::complex beta, + ScalarC const* real_C, + Index real_ldc, + LongIndex stride_C_real, + ScalarC const* imag_C, + Index imag_ldc, + LongIndex stride_C_imag, + ScalarD* real_D, + Index real_ldd, + LongIndex stride_D_real, + ScalarD* imag_D, + Index imag_ldd, + LongIndex stride_D_imag) { + + batch_stride_C.first = stride_C_real; + batch_stride_C.second = stride_C_imag; + + batch_stride_D.first = stride_D_real; + batch_stride_D.second = stride_D_imag; + + return initialize(alpha, beta, real_C, real_ldc, imag_C, imag_ldc, real_D, real_ldd, imag_D, imag_ldd); + } + }; + + /// Shared memory buffer used by epilogue + typedef ZipTileAllocation< + typename RealEpilogueTraits::SharedStorage, + typename RealEpilogueTraits::SharedStorage> SharedStorage; + + /// Functor computing the offset from the threadblock origin per iteration of + /// the epilogue. + typedef typename RealEpilogueTraits::GlobalOffset GlobalOffset; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm + +namespace platform { + +/// Here's a helpful arithmetic operator +CUTLASS_HOST_DEVICE +Pair operator*(int s, Pair _pair) { + return Pair(s * _pair.first, s * _pair.second); +} + +} + +} // namespace cutlass + +// clang-format on diff --git a/cutlass/gemm/volta884_complex_gemm_traits.h b/cutlass/gemm/volta884_complex_gemm_traits.h new file mode 100644 index 0000000000..593b28dd46 --- /dev/null +++ b/cutlass/gemm/volta884_complex_gemm_traits.h @@ -0,0 +1,558 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
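// (Sketch of how the Pair arithmetic operator defined at the end of
//  volta884_complex_gemm_epilogue_traits.h above is meant to be used: in a
//  strided-batched, split-complex problem each operand carries one batch stride per
//  plane, so advancing to batch b scales both strides at once, e.g.
//    offset = b * batch_stride_C;        // platform::Pair<LongIndex, LongIndex>
//    real_ptr += offset.first;           // real plane of batch b
//    imag_ptr += offset.second;          // imaginary plane of batch b
//  The variable names here are illustrative only.)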
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines structural properties for complex-valued GEMM targeting Volta's mma.sync + instruction. + + At present, it expects split complex representation in global memory in which the real part and + imaginary parts of a complex-valued matrices are disjoint (a structure of arrays). This is in + contrast with an interleaved complex representation which is an array of structures. +*/ + +#pragma once + +// clang-format off + +#include "cutlass/gemm/clear_accumulators.h" +#include "cutlass/gemm/gemm_config.h" +#include "cutlass/gemm/gemm_stream_pair.h" +#include "cutlass/gemm/threadblock_swizzle.h" +#include "cutlass/gemm/linear_scaling.h" +#include "cutlass/kernel_launch.h" +#include "cutlass/tensor_ref_collection.h" + +#include "cutlass/gemm/gemm_desc.h" + +#include "cutlass/gemm/volta884_multiplicand.h" +#include "cutlass/gemm/mma_shared_stream.h" +#include "cutlass/gemm/volta884_gemm_traits.h" + +#include "cutlass/gemm/volta884_complex_multiply_add.h" +#include "cutlass/gemm/volta884_complex_global_stream.h" +#include "cutlass/gemm/volta884_complex_shared_stream.h" +#include "cutlass/gemm/volta884_complex_gemm_epilogue_traits.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines configuration for Volta884 GEMM +template < + /// The layout for A. + MatrixLayout::Kind LayoutA, + /// Indicates matrix transform on multiplicand A + MatrixTransform::Kind TransformA, + /// The layout for B. + MatrixLayout::Kind LayoutB, + /// Indicates matrix transform on multiplicand B + MatrixTransform::Kind TransformB, + /// The tile size for the GEMM KxNxM. + typename OutputTile_, + /// Tile size for warp-level GEMM (K-by-N-by-M) + typename WarpGemmShape_, + /// The accumulator type. + typename Accumulator_, + /// The source matrix type type. + typename ScalarC_, + /// The destination matrix type + typename ScalarD_, + /// Number of stages in shared memory + int StageCount, + /// Enables or disables launch bounds + bool LaunchBounds> +struct Volta884ComplexGemmConfig : public GemmConfig< + /// The scalar type for A. + half, + /// The scalar type for B. + half, + /// The scalar type for C. + ScalarC_, + /// The scalar type for D. + ScalarD_, + /// The threadblock tile size + OutputTile_, + /// The functor to do the math in the main loop. + Volta884ComplexMultiplyAdd, + /// The number of scalars per LDG for A. + 8, + /// The number of scalars per STS for A. + 8, + /// The number of scalars per LDS for A. + 8, + /// The number of scalars per LDG for B. + 8, + /// The number of scalars per STS for B. + 8, + /// The number of scalars per LDS for B. + 8, + /// The number of scalars per LDG for C and STG for D. 
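// (The C/D accesses below are sized as 16-byte vectors, so the scalar count per
//  access is 16 / sizeof(ScalarD_): 4 scalars when ScalarD_ is float, 8 when it is
//  half.)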
+ 16 / int(sizeof(ScalarD_)), + /// The number of scalars per STS for D. + 16 / int(sizeof(ScalarD_)), + /// The number of scalars per LDS for D. + 16 / int(sizeof(ScalarD_)), + /// The number of stages in shared memory. + StageCount, + /// If true, separate mainloop is instantiated + true, + /// If true, compute residue in prolog + false, + /// Launch bounds not used + LaunchBounds> {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines components of Volta884 GEMM +template < + /// The layout for A. + MatrixLayout::Kind LayoutA, + /// Indicates matrix transform on multiplicand A + MatrixTransform::Kind TransformA, + /// The layout for B. + MatrixLayout::Kind LayoutB, + /// Indicates matrix transform on multiplicand B + MatrixTransform::Kind TransformB, + /// The tile size for the GEMM KxNxM. + typename OutputTile_, + /// Tile size for warp-level GEMM (K-by-N-by-M) + typename WarpGemmShape_, + /// The accumulator type. + typename Accumulator_, + /// The input matrix type type. + typename ScalarC_, + /// The output matrix type type. + typename ScalarD_, + /// Number of buffers in shared memory to use + int StageCount, + /// The functor to do the math in the epilogue. + typename EpilogueFunctor_ = SplitComplexLinearScaling, + /// Enables or disables launch bounds + bool LaunchBounds = false +> +struct Volta884ComplexGemmTraits { + + /// This is insane. + typedef Volta884ComplexGemmTraits< + LayoutA, + TransformA, + LayoutB, + TransformB, + OutputTile_, + WarpGemmShape_, + Accumulator_, + ScalarC_, + ScalarD_, + StageCount, + EpilogueFunctor_, + LaunchBounds> This; + + /// The actual device-side GEMM + typedef GemmMainloop KernelClass; + + /// Layout of multiplicand A matrix + static MatrixLayout::Kind const kLayoutA = LayoutA; + + /// If true, A operand is conjugated + static MatrixTransform::Kind const kTransformA = TransformA; + + /// Layout of multiplicand B matrix + static MatrixLayout::Kind const kLayoutB = LayoutB; + + /// If true, B operand is conjugated + static MatrixTransform::Kind const kTransformB = TransformB; + + /// Dimensions of threadblock tile (concept Shape) + typedef OutputTile_ OutputTile; + + /// Shape of warp-level accumulators + typedef WarpGemmShape_ WarpGemmShape; + + /// Multiplicand A scalar type + typedef half ScalarA; + + /// Multiplicand B scalar type + typedef half ScalarB; + + /// Data type of internal accumulator + typedef Accumulator_ Accumulator; + + /// Data type of input accumulator matrix operand + typedef ScalarC_ ScalarC; + + /// Data type of output accumulator matrix operand + typedef ScalarD_ ScalarD; + + /// Shape of individual mma.sync instruction + typedef Shape<4, 16, 16> InstructionShape; + + /// Tile size for an individual warp-level multiply-add + typedef Shape WarpTile; + + /// Defines properties about GEMM needed by host code + typedef Volta884ComplexGemmConfig< + kLayoutA, + kTransformA, + kLayoutB, + kTransformB, + OutputTile, + WarpGemmShape, + Accumulator, + ScalarC, + ScalarD, + StageCount, + LaunchBounds> + GemmConfig; + + // + // Derived types + // + + /// Index type + typedef int Index; + + /// Long index type + typedef long long LongIndex; + + /// Partitioning of threadblock into warps + typedef typename ShapeDiv::Shape WarpDelta; + + /// Number of warps per threadblock + static int const kWarpCount = ShapeCount::kCount; + + /// Defines iterators for A matrix + typedef Volta884Multiplicand + MultiplicandA; + + /// Defines iterators for B matrix + typedef 
Volta884Multiplicand + MultiplicandB; + + // + // GemmTraits mandatory type definitions + // + + /// Maps hardware threadblocks to logical partitions of the GEMM + typedef IdentityBlockSwizzle BlockSwizzle; + + /// Clears accumulators + typedef ClearAccumulators ClearAccumulators; + + /// Loads multiplicands from global memory + typedef GlobalLoadStreamPair< + Volta884ComplexGlobalLoadStream, + typename MultiplicandA::StoreIterator, + StageCount>, + Volta884ComplexGlobalLoadStream, + typename MultiplicandB::StoreIterator, + StageCount>, + GemmConfig::kResidueInProlog > + GlobalLoadStream; + + /// Memory needed to store the threadblock-scoped GEMM tile + typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage; + + /// Shared memory storage for mainloop phase + union MainLoopStorage { + + /// Stores the threadblock tile + ThreadblockTileStorage threadblock_tile; + + /// Storage for GEMM global stream + typename GlobalLoadStream::SharedStorage global_to_shared_stream; + }; + + /// Loads multiplicands from shared memory + typedef SharedStreamPair< + Volta884ComplexSharedLoadStream, + StageCount>, + Volta884ComplexSharedLoadStream, + StageCount> > + SharedStream; + + // Multiply-add object specialized for Volta mma.sync + typedef typename GemmConfig::MultiplyAdd MultiplyAdd; + + #if 0 + /// Naive epilogue for updating the output matrix + typedef Volta884ComplexNaiveEpilogue + Epilogue; + + #else + + /// Efficient epilogue + typedef MMAEpilogue< + Volta884ComplexGemmEpilogueTraits + > Epilogue; + + #endif + + /// Tensor reference to A multiplicand + typedef ZipTensorRef< + TensorRef, + TensorRef + > TensorRefA; + + /// Tensor reference to B multiplicand + typedef ZipTensorRef< + TensorRef, + TensorRef + > TensorRefB; + + /// Tensor reference to C multiplicand + typedef ZipTensorRef< + TensorRef, + TensorRef + > TensorRefC; + + /// Tensor reference to D multiplicand + typedef ZipTensorRef< + TensorRef, + TensorRef + > TensorRefD; + + /// gemm::ProblemDesc<> + typedef GemmDesc< + TensorRefA, + TensorRefB, + TensorRefC, + TensorRefD, + float + > GemmDesc; + + /// Parameters structure + struct Params : public KernelLaunchConfiguration { + /// The dimensions of the GEMM. + GemmCoord problem_size; + + /// PartitionK_range + int partitionK_range; + + /// The params for the global load stream + typename GlobalLoadStream::Params global_to_shared_stream; + + /// The params for the shared load stream + typename SharedStream::Params shared_stream; + + /// The params for the epilogue. 
+ typename Epilogue::Params epilogue; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() {} + + /// Initialize the Params struct + CUTLASS_HOST_DEVICE int initialize( + Index m, + Index n, + Index k, + platform::complex alpha, + ScalarA const* real_A, + Index real_lda, + ScalarA const* imag_A, + Index imag_lda, + ScalarB const* real_B, + Index real_ldb, + ScalarB const* imag_B, + Index imag_ldb, + platform::complex beta, + ScalarC const* real_C, + Index real_ldc, + ScalarC const* imag_C, + Index imag_ldc, + ScalarD* real_D, + Index real_ldd, + ScalarD* imag_D, + Index imag_ldd) { + + problem_size = make_Coord(k, n, m, 1); + + partitionK_range = problem_size.k(); + + // Compute grid dimensions + BlockSwizzle block_swizzle; + this->block = dim3(GemmConfig::kThreads); + this->grid = block_swizzle.get_grid_layout( + problem_size, + make_Coord_from_shape()); + + // Initialize global load streams + global_to_shared_stream.stream_a.initialize( + make_ZipTensorRef( + TensorRefBatchStrided(TensorRef(real_A, real_lda), 0), + TensorRefBatchStrided(TensorRef(imag_A, imag_lda), 0) + ), + 0 + ); + + global_to_shared_stream.stream_b.initialize( + make_ZipTensorRef( + TensorRefBatchStrided(TensorRef(real_B, real_ldb), 0), + TensorRefBatchStrided(TensorRef(imag_B, imag_ldb), 0) + ), + 0 + ); + + return epilogue.initialize( + alpha, + beta, + real_C, + real_ldc, + imag_C, + imag_ldc, + real_D, + real_ldd, + imag_D, + imag_ldd + ); + } + + /// Initialize the Params struct + CUTLASS_HOST_DEVICE int initialize( + Index m, + Index n, + Index k, + platform::complex alpha, + ScalarA const* real_A, + Index real_lda, + LongIndex batch_stride_A_real, + ScalarA const* imag_A, + Index imag_lda, + LongIndex batch_stride_A_imag, + ScalarB const* real_B, + Index real_ldb, + LongIndex batch_stride_B_real, + ScalarB const* imag_B, + Index imag_ldb, + LongIndex batch_stride_B_imag, + platform::complex beta, + ScalarC const* real_C, + Index real_ldc, + LongIndex batch_stride_C_real, + ScalarC const* imag_C, + Index imag_ldc, + LongIndex batch_stride_C_imag, + ScalarD* real_D, + Index real_ldd, + LongIndex batch_stride_D_real, + ScalarD* imag_D, + Index imag_ldd, + LongIndex batch_stride_D_imag, + int batch_count) { + + problem_size = make_Coord(k, n, m, batch_count); + partitionK_range = problem_size.k(); + + // Compute grid dimensions + BlockSwizzle block_swizzle; + this->block = dim3(GemmConfig::kThreads); + this->grid = block_swizzle.get_grid_layout( + problem_size, + make_Coord_from_shape()); + + // Initialize global load streams + global_to_shared_stream.stream_a.initialize( + make_ZipTensorRef( + TensorRefBatchStrided(TensorRef(real_A, real_lda), batch_stride_A_real), + TensorRefBatchStrided(TensorRef(imag_A, imag_lda), batch_stride_A_imag) + ), + 0 + ); + + global_to_shared_stream.stream_b.initialize( + make_ZipTensorRef( + TensorRefBatchStrided(TensorRef(real_B, real_ldb), batch_stride_B_real), + TensorRefBatchStrided(TensorRef(imag_B, imag_ldb), batch_stride_B_imag) + ), + 0 + ); + + return epilogue.initialize( + alpha, + beta, + real_C, + real_ldc, + batch_stride_C_real, + imag_C, + imag_ldc, + batch_stride_C_imag, + real_D, + real_ldd, + batch_stride_D_real, + imag_D, + imag_ldd, + batch_stride_D_imag + ); + } + }; + + /// Shared memory storage + union SharedStorage { + /// Storage required during mainloop phase + MainLoopStorage main_loop; + + /// Shared storage needed for epilogue + typename Epilogue::SharedStorage epilogue; + }; + + /// The memory fence for shared loads. 
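// (When StageCount >= 2 the shared-memory tile is multi-buffered, so one stage can be
//  read while the next is being written and no barrier is needed on the load side;
//  only a single-stage configuration must synchronize here, which is what the
//  StageCount < 2 test below expresses. Stores always synchronize; see
//  shared_store_fence below.)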
+ static CUTLASS_DEVICE void shared_load_fence(bool in_loop) { + if (StageCount < 2) { + __syncthreads(); + } + } + + /// The memory fence for shared stores. + static CUTLASS_DEVICE void shared_store_fence(bool in_loop) { __syncthreads(); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass + +// clang-format on diff --git a/cutlass/gemm/volta884_complex_global_stream.h b/cutlass/gemm/volta884_complex_global_stream.h new file mode 100644 index 0000000000..7e3a92cb2f --- /dev/null +++ b/cutlass/gemm/volta884_complex_global_stream.h @@ -0,0 +1,315 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements efficient loading of the thread block-level tile from global memory and + storing + to shared memory. +*/ + +#pragma once + +// clang-format off + +#include "cutlass/convert.h" +#include "cutlass/zip_tile_iterator.h" +#include "cutlass/zip_tensor_ref.h" +#include "cutlass/gemm/gemm_operand.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/util/pair.h" + +#include "cutlass/gemm/mma_global_stream.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +///! 
Stream adapter for loading threadblock-scoped GEMM tiles and storing to shared memory +template < + /// Identifies multiplicand + GemmOperand::Kind Operand, + /// Layout of source matrix in global memory + MatrixLayout::Kind Layout, + /// Iterator for loading threadblock-scoped tiles + typename LoadIterator_, + /// Transformation functor for transforming fragments + typename Transformer_, + /// Iterator for storing threadblock-scoped tiles to shared memory + typename StoreIterator_, + /// Number of stores before iterator wraps - zero indicates no wrapping + int StageCount> +struct Volta884ComplexGlobalLoadStream { + + // + // Type definitions + // + + /// Identifies the operand + static GemmOperand::Kind const kOperand = Operand; + + /// The layout. + static MatrixLayout::Kind const kLayout = Layout; + + /// Load-store stream for real-valued matrices + typedef MMAGlobalLoadStream RealLoadStoreStream; + + /// Loads a pair of real-valued fragments + typedef ZipTileIterator LoadIterator; + + /// Zips a pair of transformers + typedef ZipConvert Transformer; + + /// Stores a pair of real-valued ragments + typedef ZipTileIterator StoreIterator; + + /// Number of stages + static int const kStageCount = StageCount; + + /// Predicate vector + typedef typename RealLoadStoreStream::PredicateVector PredicateVector; + + /// The fragment that is copied from shared memory. + typedef typename LoadIterator::Fragment FetchedFragment; + /// The fragment that is obtained after the transformation by the transformer. + typedef typename Transformer::OutputFragment TransformedFragment; + /// Make sure the fragments match. + static_assert((platform::is_same::value), + ""); + /// The output fragment. + typedef TransformedFragment Fragment; + /// Make sure the transformed fragment is the same as the store fragment. + static_assert((platform::is_same::value), + ""); + + /// Index type + typedef typename RealLoadStoreStream::Index Index; + + /// Long index type + typedef typename RealLoadStoreStream::LongIndex LongIndex; + + /// The params. + struct Params { + + // + // Type definitions + // + + /// Matrix reference + typedef ZipTensorRef< + TensorRefBatchStrided, + TensorRefBatchStrided > SourceTensorRef; + + /// Helper + static int const kElementsPerLdg = LoadIterator::First::Tile::kC; + + // + // Data members + // + + /// Source tensor reference + platform::Pair batch_stride; + + // The load iterator. + typename LoadIterator::Params load_iterator; + + // Offset to residue. 
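// (Presumably the distance, along the K dimension, by which the load iterator must be
//  advanced to reach the residue tile when K is not a multiple of the tile size; it is
//  recorded here and consumed by the mainloop together with residue() below.)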
+ Index offset_to_residue; + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() {} + + /// + CUTLASS_HOST_DEVICE + Params(SourceTensorRef const &ref, Index _offset_to_residue) { + initialize(ref, _offset_to_residue); + } + + CUTLASS_HOST_DEVICE + int initialize(SourceTensorRef const &ref, Index _offset_to_residue) { + + batch_stride.first = ref.first.tensor_stride; + batch_stride.second = ref.second.tensor_stride; + + offset_to_residue = _offset_to_residue; + load_iterator.first.initialize( + TensorRef( + ref.first.at().data(), + make_Coord(ref.first.at().stride(0) * kElementsPerLdg, ref.first.at().stride(0), kElementsPerLdg) + ) + ); + load_iterator.second.initialize( + TensorRef( + ref.second.at().data(), + make_Coord(ref.second.at().stride(0) * kElementsPerLdg, ref.second.at().stride(0), kElementsPerLdg) + ) + ); + return 0; + } + }; + + /// Empty shared storage + struct SharedStorage {}; + + /// Shared memory allocation for the tile + typedef TileAllocation< + typename RealLoadStoreStream::StoreIterator::Scalar, + typename ShapeMul< + typename RealLoadStoreStream::StoreIterator::OperandShape, + Shape + >::Shape + > RealThreadblockTileStorage; + + /// Threadblock tile allocation + typedef ZipTileAllocation< + RealThreadblockTileStorage, + RealThreadblockTileStorage + > ThreadblockTileStorage; + + /// Reference to ThreadblockTileStorage + typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef; + + // + // Data members + // + + ///! The parameters + Params params; + + ///! Dimensions of global memory tile + Coord<3> threadblock_offset; + + ///! Multiplicand bounds + Coord<3> multiplicand_bounds; + + ///! Iterator to load threadblock tiles from global memory + LoadIterator load_iterator; + + ///! Predicate vector + PredicateVector predicates; + + ///! The fragment to fetch from shared memory. + FetchedFragment fetched_fragment; + + ///! Functor to transform fragments after they have been loaded + Transformer transformer; + + ///! The fragment to convert the data after it has been fetched from shared memory. + TransformedFragment transformed_fragment; + + ///! Iterator to store threadblock tiles to shared memory + StoreIterator store_iterator; + + ///! 
Counter + int stage_index; + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE Volta884ComplexGlobalLoadStream(Params const &_params, + SharedStorage &shared_storage, + ThreadblockTileRef const &threadblock_tile_ref, + Coord<3> const bounds, + Coord<3> const &block) + : params(_params), + threadblock_offset(RealLoadStoreStream::project_coordinate(block)), + multiplicand_bounds(RealLoadStoreStream::project_coordinate(bounds, 1)), + load_iterator(params.load_iterator, threadblock_offset), + transformer(), + store_iterator(threadblock_tile_ref), + stage_index(0) { + + // initialize predicates used to guard loads + load_iterator.initialize_predicates( + predicates.begin(), multiplicand_bounds, threadblock_offset); + } + + /// Loads the data from global memory + CUTLASS_DEVICE void copy() { + load_iterator.load_post_increment(fetched_fragment, predicates.begin()); + } + + /// Transform and commit the data to shared memory + CUTLASS_DEVICE void commit() { + + transformer.transform(fetched_fragment, transformed_fragment); + store_iterator.store_post_increment(transformed_fragment); + + ++stage_index; + if (kStageCount && stage_index == kStageCount) { + store_iterator -= kStageCount; + stage_index = 0; + } + } + + /// Computes a predicate mask for loads during final threadblock tile load iteration + CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) { + // That's the residue! + Coord<3> _block_offset = threadblock_offset; + if (kOperand == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor) { + // K-strided + _block_offset = + make_Coord(threadblock_offset[0], multiplicand_bounds[1] - k, threadblock_offset[2]); + } else { + // K-contiguous + _block_offset = make_Coord(threadblock_offset[0], + threadblock_offset[1], + multiplicand_bounds[2] - k / LoadIterator::First::Tile::kC); + } + + load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, _block_offset); + fetched_fragment.clear(); + } + + CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {} + + CUTLASS_DEVICE void rollback() {} + + /// Adds a Coord<3> to the underlying global load iterator + CUTLASS_DEVICE Volta884ComplexGlobalLoadStream &operator+=(Coord<3> const &offset) { + load_iterator += offset; + return *this; + } + + /// Adds an offset based on batch stride + CUTLASS_DEVICE Volta884ComplexGlobalLoadStream &add_batch_offset(int batch_id) { + load_iterator.first.add_pointer_offset(params.batch_stride.first * batch_id); + load_iterator.second.add_pointer_offset(params.batch_stride.second * batch_id); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass + +// clang-format on diff --git a/cutlass/gemm/volta884_complex_multiply_add.h b/cutlass/gemm/volta884_complex_multiply_add.h new file mode 100644 index 0000000000..5120c77b4a --- /dev/null +++ b/cutlass/gemm/volta884_complex_multiply_add.h @@ -0,0 +1,319 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements warp-level multiply-accumulate operations using Volta's mma.sync instruction + for complex-valued data types. +*/ + +#pragma once + +#include "cutlass/util/complex.h" +#include "cutlass/zip_fragment.h" +#include "cutlass/gemm/volta884_multiply_add.h" +#include "cutlass/zip_fragment.h" + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Shape of a warp-level GEMM (K-by-N-by-M) + typename WarpGemmShape_, + /// Layout of multiplicand A + MatrixLayout::Kind LayoutA, + /// Indicates matrix transform on multiplicand A + MatrixTransform::Kind TransformA, + /// Data type of multiplicand A + typename ScalarA_, + /// Layout of multiplicand B + MatrixLayout::Kind LayoutB, + /// Indicates matrix transform on multiplicand B + MatrixTransform::Kind TransformB, + /// Data type of multiplicand B + typename ScalarB_, + /// Data type of accumulators + typename ScalarC_, + /// If true, A operand is conjugated + bool ConjugateA = false, + /// If true, B operand is conjugated + bool ConjugateB = false, + /// If true, infinite results are saturated to +-MAX_FLOAT + bool SatFinite = false> +struct Volta884ComplexMultiplyAdd { + // + // Constant and type definitions + // + + /// Shape of a warp-level GEMM (K-by-N-by-M) + typedef WarpGemmShape_ WarpGemmShape; + + /// Shape of a warp-level GEMM (K-by-N-by-M) + typedef WarpGemmShape_ AccumulatorsPerWarp; + + /// Most of the Volta884 code assumes interleaved 32x32 tiles + typedef Shape<4, 32, 32> InterleavedTileShape; + + /// Shape of an individual warp-wide mma.sync instruction + typedef Shape<4, 16, 16> InstructionShape; + + /// Shape of a warp-level matrix multiply operation + typedef Shape WarpTile; + + /// Verify WarpTile is a multiple of fundamental 32x32 interleaved tile + static_assert(!(WarpTile::kH % InterleavedTileShape::kH) && + !(WarpTile::kW % InterleavedTileShape::kW) && WarpTile::kD == 4, + "WarpTile must be a multiple of InterleavedTileShape."); + + /// Layout of A multiplicand + static MatrixLayout::Kind const kLayoutA = LayoutA; + + /// Indicates matrix transform on multiplicand B + static MatrixTransform::Kind const kTransformA = 
TransformA; + + /// Layout of B multiplicand + static MatrixLayout::Kind const kLayoutB = LayoutB; + + /// Indicates matrix transform on multiplicand B + static MatrixTransform::Kind const kTransformB = TransformB; + + /// The type for A. + typedef ScalarA_ ScalarA; + /// The type for B. + typedef ScalarB_ ScalarB; + /// The type for C and D. + typedef ScalarC_ ScalarC; + + /// If true, infinite results are saturated to +-MAX_FLOAT + static bool const kSatFinite = SatFinite; + + /// Hard-coded comptue type supported on Volta + static arch::ComputeType::Kind const kComputeType = arch::ComputeType::kDefault; + + /// Underlying matrix multiply-add operator + typedef Volta884MultiplyAdd + RealMultiplyAdd; + + /// Fragment definition for A multiplicand + typedef ZipFragment + FragmentA; + + /// Fragment definition for B multiplicand + typedef ZipFragment + FragmentB; + + /// Fragment definition for accumulators + typedef ZipFragment + Accumulators; + + /// Number of mma.sync operations performed. See Volta884MultiplyAdd::Iterations for details. + typedef typename RealMultiplyAdd::Iterations Iterations; + + // + // Methods + // + + /// Ctor. + CUTLASS_DEVICE Volta884ComplexMultiplyAdd() {} + + /// Multiply : d = a*b. + CUTLASS_DEVICE void multiply_add(FragmentA const& A, + FragmentB const& B, + Accumulators const& C, + Accumulators& D) { + RealMultiplyAdd op; + + // complex-valued multiply-add + op.multiply_add(A.first, B.first, C.first, D.first); + op.multiply_add(A.first, B.second, C.second, D.second, kTransformB == MatrixTransform::kConjugate); + op.multiply_add(A.second, B.first, C.second, D.second, kTransformA == MatrixTransform::kConjugate); + op.multiply_add(A.second, B.second, C.first, D.first, + !((kTransformA == MatrixTransform::kConjugate) ^ (kTransformB == MatrixTransform::kConjugate))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Complex-valued epilogue +template +struct Volta884ComplexNaiveEpilogue { + /// Accumulator data type + typedef Accumulator ScalarC; + + /// Output accumulator type + typedef Accumulator ScalarD; + + /// BLAS Scalar type + typedef Accumulator Scalar; + + /// Real-valued epilogue + typedef Volta884NaiveEpilogue RealEpilogue; + + /// Params object + struct Params { + /// Parameters for the real-valued part + typename RealEpilogue::Params real; + + /// Parameters for the imaginary-valued part + typename RealEpilogue::Params imag; + + // + // Methods + // + + /// Default constructor + CUTLASS_HOST_DEVICE Params() {} + + /// Constructs from params object + CUTLASS_HOST_DEVICE Params(typename RealEpilogue::Params const& _real, + typename RealEpilogue::Params const& _imag) + : real(_real), imag(_imag) {} + + /// Construct from pointers + CUTLASS_HOST_DEVICE Params(ScalarC* _real, int _ldr, ScalarC* _imag, int _ldi) + : real(_real, _ldr), imag(_imag, _ldi) {} + + /// Construct from pointers + CUTLASS_HOST_DEVICE Params( + platform::complex const &alpha, + platform::complex const &beta, + ScalarC const *real_C, + int real_ldc, + ScalarC const *imag_C, + int imag_ldc, + ScalarD *real_D, + int real_ldd, + ScalarD *imag_D, + int imag_ldd + ): + real(real_D, real_ldd, alpha.real(), beta.real()), + imag(imag_D, imag_ldd, alpha.real(), beta.real()) { } + + /// Initializer method + CUTLASS_HOST_DEVICE + int initialize( + platform::complex const &alpha, + platform::complex const &beta, + ScalarC const *real_C, + int real_ldc, + ScalarC const *imag_C, + int imag_ldc, + ScalarD *real_D, + int real_ldd, 
+ ScalarD *imag_D, + int imag_ldd + ) { + + real = typename RealEpilogue::Params(real_D, real_ldd, alpha.real(), beta.real()); + imag = typename RealEpilogue::Params(imag_D, imag_ldd, alpha.real(), beta.real()); + + return 0; + } + }; + + /// Shared stoarge + struct SharedStorage {}; + + /// Accumulator fragment definition + typedef ZipFragment< + typename RealEpilogue::Accumulators, + typename RealEpilogue::Accumulators> Accumulators; + + // + // Data members + // + + /// Epilogue for real part + RealEpilogue real; + + /// Epilogue for imaginary part + RealEpilogue imag; + + // + // Methods + // + + /// Constructs a complex-valued epilogue + CUTLASS_DEVICE Volta884ComplexNaiveEpilogue( + Params const& _params, Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024)) + : real(_params.real, _problem_size), imag(_params.imag, _problem_size) {} + + /// Constructs a complex-valued epilogue + CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(ScalarC* _real, + int _ldr, + ScalarC* _imag, + int _ldi, + Coord<3> const& _problem_size = make_Coord(1024, + 1024, + 1024)) + : real(_real, _ldr, _problem_size), imag(_imag, _ldi, _problem_size) {} + + /// Constructs a complex-valued epilogue + CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(Params const& _params, + SharedStorage& shared_storage, + Coord<3> const& _problem_size = make_Coord(1024, + 1024, + 1024)) + : real(_params.real, _problem_size), imag(_params.imag, _problem_size) {} + + /// Sets accumulators to zero + CUTLASS_DEVICE void clear(Accumulators& C) { + C.first.clear(); + C.second.clear(); + } + + /// Naive load operation for debugging + CUTLASS_DEVICE void load(Accumulators& C, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) { + real.load(C.first, threadblock_offset); + imag.load(C.second, threadblock_offset); + } + + /// Naive store operation for debugging + CUTLASS_DEVICE void store(Accumulators const& C, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) { + real.store(C.first, threadblock_offset); + imag.store(C.second, threadblock_offset); + } + + /// CUTLASS Epilogue interface + CUTLASS_DEVICE void epilogue(Accumulators const& C, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0), + int batch_id = 0) { + real.store(C.first, threadblock_offset); + imag.store(C.second, threadblock_offset); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/volta884_complex_shared_stream.h b/cutlass/gemm/volta884_complex_shared_stream.h new file mode 100644 index 0000000000..17d2afc008 --- /dev/null +++ b/cutlass/gemm/volta884_complex_shared_stream.h @@ -0,0 +1,152 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
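// Illustrative sketch, not part of the patch: the four multiply_add calls in
// Volta884ComplexMultiplyAdd::multiply_add above build split-complex accumulation out
// of real-valued multiply-adds, negating the appropriate term when an operand is
// conjugated. A scalar model of the same arithmetic (the kernel applies the identical
// pattern to whole mma.sync fragments rather than scalars):
struct ScalarSplitComplex {
  float real;
  float imag;
};

// d += a * b, optionally conjugating either operand.
inline void complex_mad(ScalarSplitComplex a, bool conjugate_a,
                        ScalarSplitComplex b, bool conjugate_b,
                        ScalarSplitComplex &d) {
  float a_imag = conjugate_a ? -a.imag : a.imag;
  float b_imag = conjugate_b ? -b.imag : b.imag;
  d.real += a.real * b.real - a_imag * b_imag;  // Ar*Br term and (sign-adjusted) Ai*Bi term
  d.imag += a.real * b_imag + a_imag * b.real;  // Ar*Bi and Ai*Br cross terms
}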
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements efficient loading of the thread block-level tile from global memory and + storing to shared memory. +*/ + +#pragma once + +#include "cutlass/convert.h" +#include "cutlass/zip_fragment.h" +#include "cutlass/zip_tensor_ref.h" +#include "cutlass/zip_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Stream from shared memory to fragments for warp-level matrix multiply-accumulate +template < + /// The load iterator. + typename Iterator_, + /// The transformer to be applied after the data has been copied from shared memory. + typename Transformer_ = Copy, + /// Number of increments before iterator wraps - zero indicates no wrapping + int StageCount = 1> +struct Volta884ComplexSharedLoadStream { + /// The load iterator. + typedef Iterator_ RealIterator; + + /// Zips two real-valued iterators together + typedef ZipTileIterator Iterator; + + /// The transformer. + typedef Transformer_ RealTransformer; + + /// Zips two transfoerms + typedef ZipConvert Transformer; + + /// Number of increments before iterator wraps - zero indicates no wrapping + static int const kStageCount = StageCount; + + /// The fragment that is copied from shared memory. + typedef typename Iterator::Fragment FetchedFragment; + + /// The fragment that is obtained after the transformation by the transformer. + typedef typename Transformer::OutputFragment TransformedFragment; + + /// Make sure the fragments match. + static_assert((platform::is_same::value), + ""); + + /// The output fragment. + typedef TransformedFragment Fragment; + + /// Reference type + typedef ZipTensorRef< + TensorRef, + TensorRef + > TensorRef; + + /// Parameters passed from host + struct Params { }; + + // + // Data members + // + + /// Iterator for loading fragments for warp-level matrix multiply-accumulate + Iterator iterator; + + /// Fetched fragment + FetchedFragment fetched[2]; + + /// The transformer. + Transformer transformer; + + /// Transformed fragment + TransformedFragment transformed[2]; + + /// Counts the number of stages + int stage_index; + + // + // Methods + // + + /// Ctor. + CUTLASS_DEVICE Volta884ComplexSharedLoadStream() : stage_index(0) {} + + /// Ctor. 
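// (The fetched[2] / transformed[2] members above form a two-deep ping-pong pipeline:
//  a software-pipelined mainloop is expected to issue copy() for the next step while
//  the math consumes fragment() of the current step, with step % 2 selecting the
//  buffer; the exact phasing is determined by the mainloop, not by this stream.)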
+ CUTLASS_DEVICE Volta884ComplexSharedLoadStream(Params const &_params, + TensorRef const &ref) + : iterator(ref), stage_index(0) {} + + /// Load the data from shared memory to the fetch fragment. + CUTLASS_DEVICE void copy(int step) { + iterator.load(fetched[step % 2], + make_Coord(step + stage_index * Iterator::First::VectorizedShape::kD, 0, 0, 0)); + } + + /// Commit the data. + CUTLASS_DEVICE void commit(int step) { + transformer.transform(fetched[step % 2], transformed[step % 2]); + } + + /// Gets the transformed fragment + CUTLASS_DEVICE + TransformedFragment &fragment(int step) { return transformed[step % 2]; } + + /// Gets the transformed fragment + CUTLASS_DEVICE + TransformedFragment const &fragment(int step) const { return transformed[step % 2]; } + + /// Increment the stage. + CUTLASS_DEVICE void inc_stage() { + ++stage_index; + if (kStageCount && stage_index == StageCount) { + stage_index = 0; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/volta884_gemm_epilogue_traits.h b/cutlass/gemm/volta884_gemm_epilogue_traits.h new file mode 100644 index 0000000000..02b154204f --- /dev/null +++ b/cutlass/gemm/volta884_gemm_epilogue_traits.h @@ -0,0 +1,771 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory + with the computed matrix product. 
+*/ + +#pragma once + +// clang-format off + +#include "cutlass/tile_stream.h" +#include "cutlass/tile_allocation.h" + +#include "cutlass/gemm/mma_shared_stream.h" + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstraction to select accumulators from an accumulator tile for each iteration fo the epilogue +template +struct Volta884SelectAccumulators; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Selects accumulators from Volta mma.sync.F32 layout +template +struct Volta884SelectAccumulators { + /// Shape of the warp-level matrix multiply operation + typedef WarpGemmShape_ WarpGemmShape; + + /// Describes tiling of warp elements + typedef WarpDelta_ WarpDelta; + + /// Data type of scalar + typedef float Scalar; + + // + // Derived types and constants + // + + /// (Actual) number of accumulators held by each individual thread + static int const kAccumulatorsPerThread = (WarpGemmShape::kH * WarpGemmShape::kW) / kWarpSize; + + /// Accumulators fragment + typedef Fragment Accumulators; + + /// Number of warps + static int const kWarpCount = ShapeCount::kCount; + + /// Interleaved mma.sync shape + typedef Shape<4, 32, 32> MmaTileShape; + + /// Hard-coded for FP32 layouts + typedef Shape<1, WarpGemmShape::kW / MmaTileShape::kW, 4> Elements; + + /// Number of elements + static int const kElements = ShapeCount::kCount; + + /// Slice of accumulators + typedef Fragment Fragment; + + // + // Methods + // + + /// Selects accumulators for a given iteration of the epilogue + CUTLASS_DEVICE + Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const { + Fragment frag; + + static int const kAccumPerOp = 8; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Elements::kH; ++j) { + + // selects the 32x32 tile + Coord<2> tile_32x32 = make_Coord(idx[0] / 8, j); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Elements::kW; ++i) { + Coord<2> mma_op = make_Coord(((idx[0] >> 1) & 1), i / 2); + + int element = ((i & 1) << 1) | (idx[0] & 1) | (idx[0] & 4); + + int mma_op_idx = mma_op[1] + mma_op[0] * 2 + 4 * (tile_32x32[1] + 2 * tile_32x32[0]); + + frag[i + j * Elements::kW] = accum[element + kAccumPerOp * mma_op_idx]; + } + } + + return frag; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Selects accumulators from Volta mma.sync.F16 layout +template +struct Volta884SelectAccumulators { + /// Shape of the warp-level matrix multiply operation + typedef WarpGemmShape_ WarpGemmShape; + + /// Describes tiling of warp elements + typedef WarpDelta_ WarpDelta; + + /// Data type of accumulator elements + typedef half Scalar; + + // + // Derived types and constants + // + + /// (Actual) number of accumulators held by each individual thread + static int const kAccumulatorsPerThread = (WarpGemmShape::kH * WarpGemmShape::kW) / kWarpSize; + + /// Accumulators fragment + typedef Fragment Accumulators; + + /// Number of warps + static int const kWarpCount = ShapeCount::kCount; + + /// Interleaved mma.sync shape + typedef Shape<4, 32, 32> MmaTileShape; + + /// Hard-coded for FP16 layouts + typedef Shape<1, WarpGemmShape::kW / MmaTileShape::kW, 2> Elements; + + /// Number of elements + static int const kElements = ShapeCount::kCount; + + /// Slice of accumulators + typedef Fragment Fragment; + + // + // Methods + // + + /// Selects accumulators for a given iteration of the epilogue + 
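// (Plays the same role as the FP32 specialization above; the index arithmetic differs
//  because the half-precision accumulator layout exposes Elements::kW == 2 selected
//  values per iteration instead of 4, so the bit manipulation below extracts a
//  different slice of each eight-element mma.sync accumulator group.)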
CUTLASS_DEVICE + Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const { + Fragment frag; + + static int const kAccumPerOp = 8; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Elements::kH; ++j) { + + // selects the 32x32 tile + Coord<2> tile_32x32 = make_Coord(idx[0] / 16, j); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Elements::kW; ++i) { + + Coord<2> mma_op = make_Coord(((idx[0] >> 2) & 1), i & 1); + + int element = (idx[0] & 3) | ((idx[0] >> 1) & 4); + + int mma_op_idx = mma_op[1] + mma_op[0] * 2 + 4 * (tile_32x32[1] + 2 * tile_32x32[0]); + + frag[i + j * Elements::kW] = accum[element + kAccumPerOp * mma_op_idx]; + } + } + + return frag; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// The warp-level GEMM tile + typename WarpGemmTile_, + /// Tiling of warp accumulator elements + typename WarpDelta_, + /// Size of vector to load or store + int AccessSize, + /// The accumulators fragment type - implies accumulator layout + typename Accumulators_> +struct Volta884EpilogueGlobalTileTraits; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Global tile traits specialized for Volta mma.sync.F32 layout +template < + /// The warp-level GEMM tile + typename WarpGemmTile_, + /// Tiling of warp accumulator elements + typename WarpDelta_, + /// Size of vector to load or store + int AccessSize> +struct Volta884EpilogueGlobalTileTraits { + /// Shape of warp-scoped GEMM tile + typedef WarpGemmTile_ WarpGemmTile; + + /// Structure of MMA + typedef WarpDelta_ WarpDelta; + + /// Access size of input/output elements + static int const kAccessSize = AccessSize; + + /// Scalar type of accumulators - used to imply accumulator layout, not the data + typedef float Accumulators; + + /// Strides for immediate offset computation + typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides; + + //typedef Shape<2, 2, 1, 1> Iterations; + + /// Hard-coded pitch between Volta mma.sync Quad Pair tiles + static int const kMmaQuadPairWidth = 16; + + /// Hard-coded pitch between warp tiles + static int const kInterleavedTileWidth = 32; + + /// Number of actual threads + static int const kThreadCount = (WarpDelta::kH * WarpDelta::kW) * kWarpSize; + + /// Shape of the tile + typedef Shape<2 * WarpDelta::kH, 2, WarpGemmTile::kW * WarpDelta::kW, 1> Tile; + + /// Number of iterations + typedef Shape<2 * WarpDelta::kH, + (kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH), + (kThreadCount >= Tile::kW ? 
1 : Tile::kW / kThreadCount), + 1> Iterations; + + /// Delta between accesses + typedef Shape Delta; + + /// Number of warps in threadblock + static int const kWarpCount = ShapeCount::kCount; + + /// Custom thread-offset function + struct ThreadOffset { + CUTLASS_DEVICE + Coord<4> operator()() { + + int tid = threadIdx.x; + + int residual_w = (tid / (Tile::kW)); + int offset_w = (tid % (Tile::kW)); + + int offset_h = (residual_w % Tile::kH); + int offset_d = (residual_w / Tile::kH); + + Coord<4> offset = make_Coord(offset_d * Delta::kD, offset_h * Delta::kH, offset_w, 0); + + return offset; + } + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Global tile traits specialized for Volta mma.sync.F16 layout +template < + /// The warp-level GEMM tile + typename WarpGemmTile_, + /// Tiling of warp accumulator elements + typename WarpDelta_, + /// Size of vector to load or store + int AccessSize> +struct Volta884EpilogueGlobalTileTraits { + /// Shape of warp-scoped GEMM tile + typedef WarpGemmTile_ WarpGemmTile; + + /// Structure of MMA tiles + typedef WarpDelta_ WarpDelta; + + /// Access size of input/output elements + static int const kAccessSize = AccessSize; + + /// Scalar type of accumulators - used to imply accumulator layout, not the data + typedef half Accumulators; + + /// Hard-coded pitch between Volta mma.sync Quad Pair tiles + static int const kMmaQuadPairWidth = 16; + + /// Hard-coded pitch between warp tiles + static int const kInterleavedTileWidth = 32; + + /// Number of participating threads + static int const kThreadCount = kWarpSize * WarpDelta::kH * WarpDelta::kW; + + /// Shape of the tile + typedef Shape<1, 2 * WarpDelta::kH, WarpGemmTile::kW * WarpDelta::kW, 1> Tile; + + /// Strides for immediate offset computation + typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides; + + /// Number of iterations + typedef Shape< + 1, + (kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH), + (kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount), + 1> Iterations; + + + /// Delta between thread-level accesses + typedef typename platform::conditional< + kThreadCount >= Tile::kW, + Shape<1, kMmaQuadPairWidth * (kThreadCount / Tile::kW), 1, 1>, + Shape<1, kMmaQuadPairWidth, kThreadCount, 1> + >::type Delta; + + /// Number of warps in threadblock + static int const kWarpCount = ShapeCount::kCount; + + /// Custom thread-offset function + struct ThreadOffset { + CUTLASS_DEVICE + Coord<4> operator()() { + + int tid = threadIdx.x; + + int residual_w = (tid / (Tile::kW)); + int offset_w = (tid % (Tile::kW)); + + int offset_h = (residual_w % Tile::kH); + int offset_d = (residual_w / Tile::kH); + + Coord<4> offset = make_Coord(offset_d * Delta::kD, offset_h * kMmaQuadPairWidth, offset_w, 0); + + return offset; + } + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Global offset functor for Volta884 epilogues +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Volta884EpilogueGlobalOffset; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Functor computing the offset from the threadblock origin per iteration of +/// the epilogue. 
Specialized for Volta mma.sync.F32 +template +struct Volta884EpilogueGlobalOffset { + + /// mma.sync instructions are arranged as spatially overlapping 32x32 tiles + typedef Shape<4, 32, 32> MmaTileShape; + + CUTLASS_DEVICE + Coord<3> operator()(Coord<2> const &iteration) const { + + int h = iteration[0]; + + // C++ needs a better way to express bit swizzling + int h_offset = ((h & 1) | ((h & 2) << 1) | (((h & 4) >> 2) * 8) | + (((h & 8) >> 3) * WarpDelta::kH * MmaTileShape::kH)); + + return make_Coord(0, h_offset, iteration[1] * MmaTileShape::kW * WarpDelta::kW); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Functor computing the offset from the threadblock origin per iteration of +/// the epilogue. Specialized for Volta mma.sync.F16 +template +struct Volta884EpilogueGlobalOffset { + + /// mma.sync instructions are arranged as spatially overlapping 32x32 tiles + typedef Shape<4, 32, 32> MmaTileShape; + + CUTLASS_DEVICE + Coord<3> operator()(Coord<2> const &iteration) const { + + int h = iteration[0]; + + // C++ needs a better way to express bit swizzling + int h_offset = (h & 15) | (h & 16) * 2 * WarpDelta::kH; + + Coord<3> offset = make_Coord(0, h_offset, iteration[1] * MmaTileShape::kW * WarpDelta::kW); + return offset; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Epilogue traits for Volta884 epilogue +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Epilogue traits for Volta884 GEMMs +template < + /// The threadblock GEMM tile + typename OutputTile_, + /// The warp-level GEMM tile + typename WarpGemmTile_, + /// Tiling of warp accumulator elements + typename WarpDelta_, + /// The accumulators fragment type. + typename Accumulators_, + /// Selects a slice of accumulators + typename SelectAccumulators_, + /// The iterator to load source matrix from global memory. + typename GlobalLoadStreamC_, + /// The iterator to store the final GEMM computation to global memory. + typename GlobalStoreStreamD_, + /// The stream to store matrix product to shared memory + typename SharedStoreStreamD_, + /// The stream to load the matrix product from shared memory + typename SharedLoadStreamD_, + /// The functor computing an element-wise operation on the matrix product + typename Functor_, + /// Global memory mapping function + typename GlobalDataLayout_ = MatrixLayout::ColumnMajor, + /// The index. + typename Index_ = int> +struct Volta884EpilogueTraits { + /// The output tile. + typedef OutputTile_ OutputTile; + + /// The warp-level GEMM tile + typedef WarpGemmTile_ WarpGemmTile; + + /// Tiling of warp accumulator elements + typedef WarpDelta_ WarpDelta; + + /// The accumulators fragment type. + typedef Accumulators_ Accumulators; + + /// Selects a subset of accumulators for a given epilogue iteration + typedef SelectAccumulators_ SelectAccumulators; + + /// The iterator to load source matrix from global memory. + typedef GlobalLoadStreamC_ GlobalLoadStreamC; + + /// The iterator to store the final GEMM computation to global memory. 
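  // --------------------------------------------------------------------------------------------
  // Worked example for the Volta884EpilogueGlobalOffset f32 swizzle defined above (illustrative;
  // assumes WarpDelta::kH == 2 and WarpDelta::kW == 2, with MmaTileShape == Shape<4, 32, 32>):
  //
  //   h == 5:   h_offset = (5 & 1) | ((5 & 2) << 1) | (((5 & 4) >> 2) * 8) | 0        = 1 + 8 = 9
  //   h == 10:  h_offset = 0 | ((10 & 2) << 1) | 0 | (((10 & 8) >> 3) * 2 * 32)       = 4 + 64 = 68
  //
  //   iteration[1] == 1 yields a column offset of 1 * MmaTileShape::kW * WarpDelta::kW = 64.
  //
  // The low bits therefore interleave rows within a 32x32 accumulator tile, while bit 3 of the
  // iteration index jumps between warp tiles along the N dimension.
  // --------------------------------------------------------------------------------------------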
+ typedef GlobalStoreStreamD_ GlobalStoreStreamD; + + /// The stream to store matrix product to shared memory + typedef SharedStoreStreamD_ SharedStoreStreamD; + + /// The stream to load the matrix product from shared memory + typedef SharedLoadStreamD_ SharedLoadStreamD; + + /// The functor computing an element-wise operation on the matrix product + typedef Functor_ Functor; + + /// Global memory mapping function + typedef GlobalDataLayout_ GlobalDataLayout; + + /// The index. + typedef Index_ Index; + + /// The scalar type of the source accumulator matrix. + typedef typename GlobalLoadStreamC::Iterator::Scalar ScalarC; + + /// The scalar type of the destination accumulator matrix. + typedef typename GlobalStoreStreamD::Iterator::Scalar ScalarD; + + // + // Dependent types + // + + static bool const kFp32Arrangement = sizeof(typename SelectAccumulators::Scalar) == 4; + + /// Skew elements + static int const kSkew = 2; + + /// Number of columns of accumulators stored/loaded depends on the accumulator arrangement + static int const kColumnsPerWarp = (kFp32Arrangement ? 4 : 2); + + /// mma.sync instructions are arranged as spatially overlapping 32x32 tiles + typedef Shape<4, 32, 32> MmaTileShape; + + /// Cover an entire warp-level tile + typedef Shape<1, + WarpGemmTile::kH / kColumnsPerWarp, // iterates over 32x32 accumulator tiles along N dimension + 1, // iterates over 32x32 accumulator tiles along M dimension + 1> + Iterations; + + /// Skew is needed to reduce bank conflicts to SMEM - this shape depends on accumulator layout + typedef Shape<1, + WarpDelta::kH * kColumnsPerWarp, // multiple columns in the gemm N dimension + WarpDelta::kW * WarpGemmTile::kW + kSkew, // rows in the gemm M dimension + 1 + > EpilogueTileAllocation; + + /// Parameters structure initialized on the host + struct Params { + /// The params for the C iterator. + typename GlobalLoadStreamC::Params load_stream_c; + + /// The params for the D global iterator. + typename GlobalStoreStreamD::Params store_stream_d; + + /// Epilogue functor params + typename Functor::Params functor; + + /// The params for the D shared store iterator. + typename SharedStoreStreamD::Params shared_store_stream_d; + + /// The params for the D shared load stream. + typename SharedLoadStreamD::Params shared_load_stream_d; + + /// + long long int batch_stride_C; + + /// + long long int batch_stride_D; + + // + // Methods + // + + /// Default constructor + CUTLASS_HOST_DEVICE + Params() {} + + /// Helper constructor taking pointer, stride for source and destination matrices and functor + /// params + CUTLASS_HOST_DEVICE + Params(ScalarD *ptr_D, + int ldd, + ScalarC const *ptr_C, + int ldc, + typename Functor::Params _functor = Functor::Params()) + : load_stream_c(), store_stream_d(), functor(_functor) {} + + /// Setup the params. + template + CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) { + batch_stride_C = desc.batch_stride_C; + batch_stride_D = desc.batch_stride_D; + + // The parameters for the functor. + int error_code = functor.initialize(desc); + if (error_code) { + return error_code; + } + + // Setup the params for the global memory iterator for C. + error_code = load_stream_c.iterator.initialize( + desc.C.data(), desc.C.leading_dim(), desc.C.leading_dim(), 1 + ); + + if (error_code) { + return error_code; + } + + // Setup the params for the global memory iterator for D. 
+ return store_stream_d.iterator.initialize( + desc.D.data(), desc.D.leading_dim(), desc.D.leading_dim(), 1 + ); + } + }; + + /// Shared memory buffer used by epilogue + typedef TileAllocation< + typename SharedStoreStreamD::Iterator::Scalar, + EpilogueTileAllocation> SharedStorage; + + /// Functor computing the offset from the threadblock origin per iteration of + /// the epilogue. + typedef Volta884EpilogueGlobalOffset GlobalOffset; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Volta884 Epilogue helper +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Volta884EpiloguePredicateFunctor; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Functor specialized for the predicate arrangement in the Volta884 epilogue +template +struct Volta884EpiloguePredicateFunctor { + /// Dimensions of the bounding volume + Coord<3> bounds; + + /// Constructs a predicate functor given the bounds of a tensor + CUTLASS_HOST_DEVICE + Volta884EpiloguePredicateFunctor(Coord<3> _bounds) : bounds(_bounds) {} + + /// Computes the predicate given the logical position of an access + CUTLASS_HOST_DEVICE + bool operator()(Coord<3> const &iteration, Coord<3> const &offset) const { + return + (iteration[0] * TileTraits::Delta::kD + iteration[1] * TileTraits::Delta::kH + + offset[1] < bounds[1]) && + (iteration[2] * TileTraits::Delta::kW + offset[2] < bounds[2]); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Functor specialized for the predicate arrangement in the Volta884 epilogue +template +struct Volta884EpiloguePredicateFunctor { + /// Dimensions of the bounding volume + Coord<3> bounds; + + /// Constructs a predicate functor given the bounds of a tensor + CUTLASS_HOST_DEVICE + Volta884EpiloguePredicateFunctor(Coord<3> _bounds) : bounds(_bounds) {} + + /// Computes the predicate given the logical position of an access + CUTLASS_HOST_DEVICE + bool operator()(Coord<3> const &iteration, Coord<3> const &offset) const { + return iteration[1] * TileTraits::Delta::kH + offset[1] < bounds[1] && + iteration[2] * TileTraits::Delta::kW + offset[2] < bounds[2]; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Volta884 Epilogue helper +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to define the traits for a Volta884 Epilogue +template < + typename GemmConfig_, + typename EpilogueFunctor_, + typename MultiplyAdd_ = typename GemmConfig_::MultiplyAdd, + typename Index_ = int> +struct Volta884GemmEpilogueTraitsHelper { + + /// Configuration object defining GEMM properties + typedef GemmConfig_ GemmConfig; + + /// Warp-level tile + typedef typename GemmConfig::AccumulatorsPerWarp WarpGemmShape; + + /// Warp delta + typedef typename ShapeDiv< + typename GemmConfig::OutputTile, + WarpGemmShape>::Shape WarpDelta; + + /// Thread-block scoped tile + typedef typename cutlass::ShapeMul< + WarpGemmShape, + WarpDelta + >::Shape OutputTile; + + /// Multiply-add operation + typedef MultiplyAdd_ MultiplyAdd; + + /// Epilogue functor + typedef EpilogueFunctor_ Functor; + + /// Traits for global tile access + typedef cutlass::gemm::Volta884EpilogueGlobalTileTraits< + WarpGemmShape, + WarpDelta, + 1, + typename MultiplyAdd::ScalarC + > 
EpilogueGlobalTileTraits; + + /// Iterator to load a slice of the C matrix from global memory + typedef cutlass::TileLoadIterator< + EpilogueGlobalTileTraits, + typename GemmConfig::ScalarC, + cutlass::IteratorAdvance::kW, + cutlass::MemorySpace::kGlobal + > TileLoadIteratorC; + + /// Conversion from C data type to accumulator data type + typedef Convert< + typename TileLoadIteratorC::Fragment, + Fragment + > ConvertSourceFragment; + + /// Iterator to store a slice of the D matrix to global memory + typedef cutlass::TileStoreIterator< + EpilogueGlobalTileTraits, + typename GemmConfig::ScalarD, + cutlass::IteratorAdvance::kW, + cutlass::MemorySpace::kGlobal + > TileStoreIteratorD; + + /// Conversion from accumulator data type to D data type + typedef Convert< + Fragment, + typename TileStoreIteratorD::Fragment + > ConvertDestinationFragment; + + /// Defines traits for an epilogue of a Volta884 GEMM + typedef cutlass::gemm::Volta884EpilogueTraits< + OutputTile, + WarpGemmShape, + WarpDelta, + typename MultiplyAdd::Accumulators, + cutlass::gemm::Volta884SelectAccumulators< + WarpGemmShape, + WarpDelta, + typename MultiplyAdd::ScalarC + >, + cutlass::PredicatedTileLoadStream< + TileLoadIteratorC, + cutlass::gemm::Volta884EpiloguePredicateFunctor< + EpilogueGlobalTileTraits, + typename MultiplyAdd::ScalarC>, + ConvertSourceFragment + >, + cutlass::PredicatedTileStoreStream< + TileStoreIteratorD, + cutlass::gemm::Volta884EpiloguePredicateFunctor< + EpilogueGlobalTileTraits, + typename MultiplyAdd::ScalarC>, + ConvertDestinationFragment + >, + cutlass::TileStoreStream< + cutlass::gemm::Volta884EpilogueSharedStoreIterator< + WarpGemmShape, + WarpDelta, + typename MultiplyAdd::ScalarC, + typename MultiplyAdd::ScalarC + > + >, + cutlass::TileLoadStream< + cutlass::gemm::Volta884EpilogueSharedLoadIterator< + WarpGemmShape, + WarpDelta, + typename MultiplyAdd::ScalarC, + 1, + typename MultiplyAdd::ScalarC + > + >, + Functor + > EpilogueTraits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass + +// clang-format on diff --git a/cutlass/gemm/volta884_gemm_traits.h b/cutlass/gemm/volta884_gemm_traits.h new file mode 100644 index 0000000000..2ca8b00854 --- /dev/null +++ b/cutlass/gemm/volta884_gemm_traits.h @@ -0,0 +1,585 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction +*/ + +#pragma once + +// clang-format off + +#include "cutlass/gemm/clear_accumulators.h" +#include "cutlass/gemm/gemm_config.h" +#include "cutlass/gemm/gemm_global_stream.h" +#include "cutlass/gemm/gemm_stream_pair.h" +#include "cutlass/gemm/threadblock_swizzle.h" +#include "cutlass/gemm/linear_scaling.h" +#include "cutlass/kernel_launch.h" + +#include "cutlass/gemm/gemm_desc.h" +#include "cutlass/gemm/volta884_multiplicand.h" +#include "cutlass/gemm/volta884_multiply_add.h" +#include "cutlass/gemm/mma_global_stream.h" +#include "cutlass/gemm/mma_shared_stream.h" +#include "cutlass/gemm/volta884_gemm_epilogue_traits.h" +#include "cutlass/gemm/mma_epilogue.h" +#include "cutlass/gemm/gemm_mainloop.h" +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines configuration for Volta884 GEMM +template < + /// The layout for A. + MatrixLayout::Kind LayoutA, + /// The layout for B. + MatrixLayout::Kind LayoutB, + /// The tile size for the GEMM KxNxM. + typename OutputTile_, + /// Tile size for warp-level GEMM (K-by-N-by-M) + typename WarpGemmShape_, + /// The accumulator type. + typename Accumulator_, + /// The source matrix type type. + typename ScalarC_, + /// The destination matrix type + typename ScalarD_, + /// Number of stages in shared memory + int StageCount, + + /// If true, kernel is launched with CUDA launch bounds specified + bool kLaunchBounds = true, + /// If true, residue is computed in mainloop. If false, separate loops are instantiated. + bool kResidueSeparate = true, + /// Is residue performed in prologue? + bool kResidueInProlog = false> +struct Volta884GemmConfig : public GemmConfig< + /// The scalar type for A. + half, + /// The scalar type for B. + half, + /// The scalar type for C. + ScalarC_, + /// The scalar type for D. + ScalarD_, + /// The threadblock tile size + OutputTile_, + /// The functor to do the math in the main loop. + Volta884MultiplyAdd, + /// The number of scalars per LDG for A. + 8, + /// The number of scalars per STS for A. + 8, + /// The number of scalars per LDS for A. + 8, + /// The number of scalars per LDG for B. + 8, + /// The number of scalars per STS for B. + 8, + /// The number of scalars per LDS for B. + 8, + /// The number of scalars per LDG for C and STG for D. + 16 / int(sizeof(ScalarD_)), + /// The number of scalars per STS for D. + 16 / int(sizeof(ScalarD_)), + /// The number of scalars per LDS for D. + 16 / int(sizeof(ScalarD_)), + /// The number of stages in shared memory. 
+ StageCount, + /// If true, separate mainloop is instantiated + kResidueSeparate, + /// If true, compute residue in prolog + kResidueInProlog, + /// Launch bounds not used + kLaunchBounds> {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines components of Volta884 GEMM +template < + /// The layout for A. + MatrixLayout::Kind LayoutA, + /// The layout for B. + MatrixLayout::Kind LayoutB, + /// The tile size for the GEMM KxNxM. + typename OutputTile_, + /// Tile size for warp-level GEMM (K-by-N-by-M) + typename WarpGemmShape_, + /// The accumulator type. + typename Accumulator_, + /// The input matrix type type. + typename ScalarC_, + /// The output matrix type type. + typename ScalarD_, + /// Number of buffers in shared memory to use + int StageCount, + /// The functor to do the math in the epilogue. + typename EpilogueFunctor_ = LinearScaling, + /// The block swizzle to reorganize the grid. + typename BlockSwizzle_ = IdentityBlockSwizzle, + /// Selectively enables launch bounds + bool LaunchBounds = false +> +struct Volta884GemmTraits { + /// This traits + typedef Volta884GemmTraits< + LayoutA, + LayoutB, + OutputTile_, + WarpGemmShape_, + Accumulator_, + ScalarC_, + ScalarD_, + StageCount, + EpilogueFunctor_, + BlockSwizzle_, + LaunchBounds> This_; + /// The struct that consumes this Traits + typedef typename cutlass::gemm::GemmMainloop KernelClass; + + /// Layout of multiplicand A matrix + static MatrixLayout::Kind const kLayoutA = LayoutA; + + /// Layout of multiplicand B matrix + static MatrixLayout::Kind const kLayoutB = LayoutB; + + /// Dimensions of threadblock tile (concept Shape) + typedef OutputTile_ OutputTile; + + /// Shape of warp-level accumulators + typedef WarpGemmShape_ WarpGemmShape; + + /// Multiplicand A scalar type + typedef half ScalarA; + + /// Multiplicand B scalar type + typedef half ScalarB; + + /// Data type of internal accumulator + typedef Accumulator_ Accumulator; + + /// Data type of input accumulator matrix operand + typedef ScalarC_ ScalarC; + + /// Data type of output accumulator matrix operand + typedef ScalarD_ ScalarD; + + /// Shape of individual mma.sync instruction + typedef Shape<4, 16, 16> InstructionShape; + + /// Tile size for an individual warp-level multiply-add + typedef Shape WarpTile; + + /// Defines properties about GEMM needed by host code + typedef Volta884GemmConfig + GemmConfig; + + // + // Derived types + // + + /// Index type + typedef int Index; + + /// Partitioning of threadblock into warps + typedef typename ShapeDiv::Shape WarpDelta; + + /// Number of warps per threadblock + static int const kWarpCount = ShapeCount::kCount; + + /// Defines iterators for A matrix + typedef Volta884Multiplicand + MultiplicandA; + + /// Defines iterators for B matrix + typedef Volta884Multiplicand + MultiplicandB; + + // + // GemmTraits mandatory type definitions + // + + /// Maps hardware threadblocks to logical partitions of the GEMM + typedef BlockSwizzle_ BlockSwizzle; + + /// Clears accumulators + typedef ClearAccumulators ClearAccumulators; + + /// Loads multiplicands from global memory + typedef GlobalLoadStreamPair< + MMAGlobalLoadStream, + typename MultiplicandA::StoreIterator, + StageCount>, + MMAGlobalLoadStream, + typename MultiplicandB::StoreIterator, + StageCount>, + GemmConfig::kResidueInProlog > + GlobalLoadStream; + + /// Memory needed to store the threadblock-scoped GEMM tile + typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage; + 
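  // --------------------------------------------------------------------------------------------
  // Worked example of the warp partitioning above (illustrative; the tile sizes are hypothetical
  // values, using the Shape<K, N, M> convention noted in the template parameter comments):
  //
  //   OutputTile    = Shape<32, 128, 128>   // threadblock tile
  //   WarpGemmShape = Shape<32,  64,  64>   // warp-level tile
  //   WarpDelta     = ShapeDiv of the two   = Shape<1, 2, 2>
  //   kWarpCount    = 1 * 2 * 2             = 4 warps, i.e. 128 threads per threadblock
  //
  // MainLoopStorage below then overlays the threadblock tile allocation with the storage needed
  // by the global-to-shared streams, since the two phases do not need the memory simultaneously.
  // --------------------------------------------------------------------------------------------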
union MainLoopStorage { + + /// Stores the threadblock tile + ThreadblockTileStorage threadblock_tile; + + /// Storage for GEMM global stream + typename GlobalLoadStream::SharedStorage global_to_shared_stream; + }; + + /// Loads multiplicands from shared memory + typedef SharedStreamPair< + MMASharedLoadStream, + StageCount>, + MMASharedLoadStream, + StageCount> > + SharedStream; + + // Multiply-add object specialized for Volta mma.sync + typedef typename GemmConfig::MultiplyAdd MultiplyAdd; + +#if 0 + /// Naive epilogue for updating the output matrix + typedef cutlass::gemm::Volta884NaiveEpilogue + Epilogue; +#else + + /// Efficient epilogue + typedef cutlass::gemm::MMAEpilogue< + typename Volta884GemmEpilogueTraitsHelper< + GemmConfig, + EpilogueFunctor_ + >::EpilogueTraits + > Epilogue; + +#endif + + /// Parameters structure + struct Params : public KernelLaunchConfiguration { + /// The dimensions of the GEMM. + GemmCoord problem_size; + + /// The K range for every partition except the last one + int partitionK_range; + + /// The params for the global load stream + typename GlobalLoadStream::Params global_to_shared_stream; + + /// The params for the shared load stream + typename SharedStream::Params shared_stream; + + /// The params for the epilogue. + typename Epilogue::Params epilogue; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() {} + + /// Initialize the parameters. + template + CUTLASS_HOST_DEVICE Params(GemmDesc_ const& desc) { + initialize(desc); + } + + /// Initialize the Params struct + template + CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) { + + // Problem size + problem_size = desc.problem_size; + + // there is no partitionK in the default case + partitionK_range = problem_size[0]; + // Compute grid dimensions + BlockSwizzle block_swizzle; + this->block = dim3(GemmConfig::kThreads); + this->grid = block_swizzle.get_grid_layout( + problem_size, + make_Coord_from_shape()); + + // Compute offset to residue + Index gemm_k = problem_size[0]; + Index offset_to_residue = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0; + Index offset_to_residue_last_partition = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0; + // Initialize parameters objects for + global_to_shared_stream.stream_a.initialize( + desc.A, + desc.batch_stride_A, + offset_to_residue, + offset_to_residue_last_partition); + + global_to_shared_stream.stream_b.initialize( + desc.B, + desc.batch_stride_B, + offset_to_residue, + offset_to_residue_last_partition); + + // The epilogue. 
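  // --------------------------------------------------------------------------------------------
  // Worked example of the residue offset computed above (illustrative; assumes a threadblock
  // K-extent of OutputTile::kD == 32):
  //
  //   gemm_k == 1000:  1000 % 32 == 8,  so offset_to_residue = 1000 - 8 = 992
  //   gemm_k == 1024:  1024 % 32 == 0,  so offset_to_residue = 0
  //
  // In the first case the mainloop consumes full 32-wide K tiles up to element 992 and the
  // residue path handles the final 8 elements; in the second no residue handling is needed.
  // --------------------------------------------------------------------------------------------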
+ epilogue.initialize(desc); + return 0; + } + + /// Helper to construct a GEMM params using a BLAS-like API + CUTLASS_HOST_DEVICE int initialize(Index m, + Index n, + Index k, + typename Epilogue::Scalar alpha, + ScalarA const* d_a, + Index lda, + ScalarB const* d_b, + Index ldb, + typename Epilogue::Scalar beta, + ScalarC const* d_c, + Index ldc, + ScalarD* d_d, + Index ldd) { + + GemmDesc desc( + GemmCoord(k, n, m, 1), + alpha, + TensorRef(d_a, lda), + TensorRef(d_b, ldb), + beta, + TensorRef(d_c, ldc), + TensorRef(d_d, ldd) + ); + + return this->initialize(desc); + } + + /// Helper to construct a batched GEMM params + CUTLASS_HOST_DEVICE int initialize(Index m, + Index n, + Index k, + typename Epilogue::Scalar alpha, + ScalarA const* d_a, + Index lda, + long long int batch_stride_A, + ScalarB const* d_b, + Index ldb, + long long int batch_stride_B, + typename Epilogue::Scalar beta, + ScalarC const* d_c, + Index ldc, + long long int batch_stride_C, + ScalarD* d_d, + Index ldd, + long long int batch_stride_D, + Index batch_count) { + + GemmDesc desc( + make_Coord(k, n, m, batch_count), + alpha, + TensorRef(d_a, lda), + batch_stride_A, + TensorRef(d_b, ldb), + batch_stride_B, + beta, + TensorRef(d_c, ldc), + batch_stride_C, + TensorRef(d_d, ldd), + batch_stride_D + ); + + return this->initialize(desc); + } + + /// Helper to construct a partitionedK GEMM params + template + CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc, Index partitionK_count_, Index partitionK_multiple_ = 1) { + // partitionK GEMM is a specialized batched stried gemm with different K ranges per batch + // the problem_size of each batch is (lastK_size, n, m) + // add more comments here + // the k range for every batch excpet the last one + + partitionK_range = partitonK_desc.problem_size.k() / partitionK_count_; + partitionK_range = partitionK_range - (partitionK_range % partitionK_multiple_); + // the k range of the last batch + // int lastK_range = (partitonK_desc.problem_size.k() % partitionK_range) + partitionK_range; + int lastK_range = partitonK_desc.problem_size.k() - partitionK_range * (partitionK_count_ - 1); + + assert((partitionK_range % partitionK_multiple_) == 0); + assert(partitionK_range > 0); + assert((lastK_range % partitionK_multiple_) == 0); + assert(lastK_range > 0); + + int k_size = lastK_range; + int lda = partitonK_desc.A.stride(0); + int ldb = partitonK_desc.B.stride(0); + int ldc = partitonK_desc.C.stride(0); + int ldd = partitonK_desc.D.stride(0); + int n = partitonK_desc.problem_size.n(); + + long long int batch_stride_A = (kLayoutA == cutlass::MatrixLayout::kColumnMajor) ? lda * partitionK_range : partitionK_range; + long long int batch_stride_B = (kLayoutB == cutlass::MatrixLayout::kColumnMajor) ? partitionK_range : partitionK_range * ldb; + long long int batch_stride_C = ldc * n; + long long int batch_stride_D = ldd * n; + + GemmDesc desc( + //we pass lastK_size as per batch K. there is also a range that will match partitionK_size + GemmCoord(k_size, partitonK_desc.problem_size.n(), partitonK_desc.problem_size.m(), partitionK_count_), + partitonK_desc.alpha, + partitonK_desc.A, + batch_stride_A, + partitonK_desc.B, + batch_stride_B, + partitonK_desc.beta, + partitonK_desc.C, + batch_stride_C, + partitonK_desc.D, + batch_stride_D + ); + + // Set the problem size. 
+ problem_size = desc.problem_size; + + // Compute grid dimensions + BlockSwizzle block_swizzle; + this->block = dim3(GemmConfig::kThreads); + this->grid = block_swizzle.get_grid_layout( + problem_size, + make_Coord_from_shape()); + + // Compute offset to residue. + // partitionK_range <= problem_size[0] + Index gemm_k = problem_size[0]; + Index offset_to_residue_last_partition = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0; + Index offset_to_residue = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0; + + // Initialize parameters objects for + int error_code = global_to_shared_stream.stream_a.initialize( + desc.A, + desc.batch_stride_A, + offset_to_residue, + offset_to_residue_last_partition + ); + if (error_code) { + return error_code; + } + + error_code = global_to_shared_stream.stream_b.initialize( + desc.B, + desc.batch_stride_B, + offset_to_residue, + offset_to_residue_last_partition + ); + + if (error_code) { + return error_code; + } + + // The epilogue. + return epilogue.initialize(desc); + } + + /// Helper to construct a partitionedK GEMM params + CUTLASS_HOST_DEVICE int initialize(Index m, + Index n, + Index k, + typename Epilogue::Scalar alpha, + ScalarA const* d_a, + Index lda, + ScalarB const* d_b, + Index ldb, + typename Epilogue::Scalar beta, + ScalarC const* d_c, + Index ldc, + ScalarD* d_d, + Index ldd, + Index partitionK_count_, + Index partitionK_multiple_ = 1) { + + GemmDesc desc( + GemmCoord(k, n, m, 1), + alpha, + TensorRef(d_a, lda), + TensorRef(d_b, ldb), + beta, + TensorRef(d_c, ldc), + TensorRef(d_d, ldd) + ); + + + return this->initialize(desc, partitionK_count_, partitionK_multiple_); + } + }; + + /// Shared memory storage + union SharedStorage { + /// Storage required during mainloop phase + MainLoopStorage main_loop; + + /// Shared storage needed for epilogue + typename Epilogue::SharedStorage epilogue; + }; + + /// The memory fence for shared loads. + static CUTLASS_DEVICE void shared_load_fence(bool in_loop) { + if (StageCount < 2) { + __syncthreads(); + } + } + + /// The memory fence for shared stores. + static CUTLASS_DEVICE void shared_store_fence(bool in_loop) { + __syncthreads(); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass + +// clang-format on diff --git a/cutlass/gemm/volta884_multiplicand.h b/cutlass/gemm/volta884_multiplicand.h new file mode 100644 index 0000000000..8d4b1665fa --- /dev/null +++ b/cutlass/gemm/volta884_multiplicand.h @@ -0,0 +1,298 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/gemm/gemm_operand.h" +#include "cutlass/reshape_tile.h" +#include "cutlass/tile_iterator.h" +#include "cutlass/util/platform.h" + +#include "cutlass/gemm/mma_global_tile.h" +#include "cutlass/gemm/volta884_shared_tile.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines iterators for loading and storing multiplicands +template < + /// Identifies multiplicand of GEMM (A or B) + GemmOperand::Kind Operand, + /// Specifies layout of data in source memory + MatrixLayout::Kind Layout, + /// Specifies threadblock tile shape + typename Tile, + /// Specifies warp tile shape + typename WarpTile, + /// Specifies the number of participating warps + int WarpCount, + /// Specifies the delta between warp tiles + typename WarpDelta_> +struct Volta884Multiplicand; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines iterators for loading and storing multiplicands for A.column_major +template +struct Volta884Multiplicand { + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = GemmOperand::kA; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor; + + /// Thread-block tile shape + typedef Tile_ Tile; + + /// Warp-level matrix multiply-add shape + typedef WarpTile_ WarpTile; + + /// Total number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp tiles + typedef WarpDelta_ WarpDelta; + + // + // Thread-block load iterator + // + typedef + typename MMAThreadblockCongruousLoad::Iterator + LoadIterator; + + // + // Thread-block store iterator + // + typedef Volta884ThreadblockMultiplicandStoreIterator + StoreIterator; + + // + // Warp-level load iterator + // + typedef Volta884WarpMultiplicandLoadIterator + WarpLoadIterator; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines iterators for loading and storing multiplicands for B.row_major +template +struct Volta884Multiplicand { + /// Identifies 
multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = GemmOperand::kB; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor; + + /// Thread-block tile shape + typedef Tile_ Tile; + + /// Warp-level matrix multiply-add shape + typedef WarpTile_ WarpTile; + + /// Total number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp tiles + typedef WarpDelta_ WarpDelta; + + // + // Thread-block load iterator + // + typedef + typename MMAThreadblockCongruousLoad::Iterator + LoadIterator; + + // + // Thread-block store iterator + // + typedef Volta884ThreadblockMultiplicandStoreIterator + StoreIterator; + + // + // Warp-level load iterator + // + typedef Volta884WarpMultiplicandLoadIterator + WarpLoadIterator; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines iterators for loading and storing multiplicands for A.row_major +template +struct Volta884Multiplicand { + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = GemmOperand::kA; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor; + + /// Thread-block tile shape + typedef Tile_ Tile; + + /// Warp-level matrix multiply-add shape + typedef WarpTile_ WarpTile; + + /// Total number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp tiles + typedef WarpDelta_ WarpDelta; + + // + // Thread-block load iterator + // + typedef + typename MMAThreadblockCrosswiseLoad::Iterator + LoadIterator; + + // + // Thread-block store iterator + // + typedef Volta884ThreadblockMultiplicandStoreIterator + StoreIterator; + + // + // Warp-level load iterator + // + typedef Volta884WarpMultiplicandLoadIterator + WarpLoadIterator; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines iterators for loading and storing multiplicands for B.row_major +template +struct Volta884Multiplicand { + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = GemmOperand::kB; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor; + + /// Thread-block tile shape + typedef Tile_ Tile; + + /// Warp-level matrix multiply-add shape + typedef WarpTile_ WarpTile; + + /// Total number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp tiles + typedef WarpDelta_ WarpDelta; + + // + // Thread-block load iterator + // + typedef + typename MMAThreadblockCrosswiseLoad::Iterator + LoadIterator; + + // + // Thread-block store iterator + // + typedef Volta884ThreadblockMultiplicandStoreIterator + StoreIterator; + + // + // Warp-level load iterator + // + typedef Volta884WarpMultiplicandLoadIterator + WarpLoadIterator; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/volta884_multiply_add.h b/cutlass/gemm/volta884_multiply_add.h new file mode 100644 index 0000000000..0565c1cd0f --- /dev/null +++ 
b/cutlass/gemm/volta884_multiply_add.h @@ -0,0 +1,704 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implements warp-level multiply-accumulate operations using Volta's mma.sync instruction +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/fragment.h" + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Shape of a warp-level GEMM (K-by-N-by-M) + typename WarpGemmShape_, + /// Layout of A multiplicand + MatrixLayout::Kind LayoutA, + /// Data type of A multiplicand + typename ScalarA_, + /// Layout of B multiplicand + MatrixLayout::Kind LayoutB, + /// Data type of A multiplicand + typename ScalarB_, + /// Data type of accumulators + typename ScalarC_> +struct Volta884MultiplyAdd { + // + // Constant and type definitions + // + + /// Shape of a warp-level GEMM (K-by-N-by-M) + typedef WarpGemmShape_ WarpGemmShape; + + /// Shape of a warp-level GEMM (K-by-N-by-M) + typedef WarpGemmShape_ AccumulatorsPerWarp; + + /// Most of the Volta884 code assumes interleaved 32x32 tiles + typedef Shape<4, 32, 32> InterleavedTileShape; + + /// Shape of an individual warp-wide Volta mma.sync instruction + typedef Shape<4, 16, 16> InstructionShape; + + /// Shape of a warp-level matrix multiply operation + typedef Shape WarpTile; + + /// Verify WarpTile is a multiple of fundamental 32x32 interleaved tile + static_assert(!(WarpTile::kH % InterleavedTileShape::kH) && + !(WarpTile::kW % InterleavedTileShape::kW) && WarpTile::kD == 4, + "WarpTile must be a multiple of InterleavedTileShape."); + + /// Layout of A multiplicand + static MatrixLayout::Kind const kLayoutA = LayoutA; + /// Layout of B multiplicand + static MatrixLayout::Kind const kLayoutB = LayoutB; + + /// The type for A. + typedef ScalarA_ ScalarA; + /// The type for B. + typedef ScalarB_ ScalarB; + /// The type for C and D. + typedef ScalarC_ ScalarC; + + /// Hard-coded comptue type supported on Volta + static arch::ComputeType::Kind const kComputeType = arch::ComputeType::kDefault; + + /// Defines a warp-level matrix multiply-accumulate operation performed by a warp. + // + // The layout is as follows. The entire warp performs a 64x64x4 GEMM using Volta mma.sync macros + // arranged as a 2x2 tile of adjacent, 32x32x4 matrix products. These are implemented as a + // 2x2 arrangement of spatially interleaved Volta mma.sync macros. + // + // The Iterations shape maps to the following dimensions of the above warp-level GEMM: + // + // kC: number of rows of Volta mma.sync macros in 32x32x4 tile + // kW: number of columns of Volta mma.sync macros in 32x32x4 tile + // kH: number of rows of 32x32x4 macros in larger 64x64x4 tile + // kD: number of columns of 32x32x4 macros in larger 64x64x4 tile + // + // A column-major ordering would arrange C and H as the inner-most loops, with W and D as the + // outer-most. + // + typedef Shape + Iterations; + + /// Number of multiplicand elements per instruction + static int const kMultElementsPerInst = 4; + + /// Number of multiplicand elements per instruction + static int const kAccumElementsPerInst = 8; + + /// Fragment definition for A multiplicand + typedef Fragment FragmentA; + + /// Fragment definition for B multiplicand + typedef Fragment FragmentB; + + /// Fragment definition for accumulators + typedef Fragment::kCount * kAccumElementsPerInst> Accumulators; + + // + // Methods + // + + /// Ctor. + CUTLASS_DEVICE Volta884MultiplyAdd() {} + + /// Multiply : d = (-)a*b + c. 
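  // --------------------------------------------------------------------------------------------
  // Worked example of the iteration shape described above (illustrative; assumes the common
  // WarpGemmShape == Shape<4, 64, 64> configuration):
  //
  //   InterleavedTileShape = 32x32x4, InstructionShape = 16x16x4, so
  //     Iterations::kC = 32 / 16 = 2    rows of mma.sync macros per 32x32 tile
  //     Iterations::kW = 32 / 16 = 2    columns of mma.sync macros per 32x32 tile
  //     Iterations::kH = 64 / 32 = 2    rows of 32x32 tiles in the 64x64 warp tile
  //     Iterations::kD = 64 / 32 = 2    columns of 32x32 tiles in the 64x64 warp tile
  //
  //   giving 16 mma.sync issues per K=4 step and 16 * 8 = 128 accumulator elements per thread,
  //   which matches 64 * 64 / 32 accumulators per thread for a 32-thread warp.
  // --------------------------------------------------------------------------------------------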
+ CUTLASS_DEVICE void multiply_add(FragmentA const& A, + FragmentB const& B, + Accumulators const& C, + Accumulators& D, + bool negate = false) { +// Guard conditional needed for __hneg2 +#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA) + + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { // Outer column + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { // Inner column + CUTLASS_PRAGMA_UNROLL + for (int h_raw = 0; h_raw < Iterations::kH; ++h_raw) { // Outer row + CUTLASS_PRAGMA_UNROLL + for (int c_raw = 0; c_raw < Iterations::kC; ++c_raw) { // Inner row + + int op_col = (w + Iterations::kW * d); + + // Column-major serpentine sequence to maximize reuse of B operand. + int h = h_raw; + int c = c_raw; + + if (op_col & 1) { + h = Iterations::kH - h_raw - 1; + c = Iterations::kC - c_raw - 1; + } + + int op_row = (c + Iterations::kC * h); + int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d)); + + ScalarA operand_A[kMultElementsPerInst]; + + reinterpret_cast(operand_A[0]) = + reinterpret_cast(A[op_row * kMultElementsPerInst]); + + if (negate) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMultElementsPerInst; i += 2) { + reinterpret_cast(operand_A[i]) = + __hneg2(reinterpret_cast(A[op_row * kMultElementsPerInst + i])); + } + } + + // Issue a Volta mma.sync instruction + arch::mma( + + operand_A, //&A[op_row * kMultElementsPerInst], + &B[op_col * kMultElementsPerInst], + &C[op_idx * kAccumElementsPerInst], + &D[op_idx * kAccumElementsPerInst]); + } + } + } + } +#endif // if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <=750 && CUTLASS_ENABLE_TENSOR_CORE_MMA) + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Volta884NaiveEpilogue; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Naive epilogue specialized for f32 accumulators - may be considered authoritative mapping of +/// accumulators to mma.sync operations. 
+template +struct Volta884NaiveEpilogue { + /// Accumulator data type + typedef float ScalarC; + + /// Output accumulator type + typedef float ScalarD; + + /// BLAS Scalar type + typedef float Scalar; + + /// Delta among warp tiles + typedef WarpDelta_ WarpDelta; + + /// Number of Volta mma.sync operations + typedef Iterations_ Iterations; + + /// Most of the Volta884 code assumes interleaved 32x32 tiles + typedef Shape<4, 32, 32> InterleavedTileShape; + + /// Number of multiplicand elements per instruction + static int const kAccumElementsPerInst = 8; + + /// Fragment definition for accumulators + typedef Fragment::kCount * kAccumElementsPerInst> Accumulators; + + /// Params object + struct Params { + /// Pointer to output matrix + ScalarC* ptr; + + /// stride + int ldm; + + /// Scalar alpha + float alpha; + + /// Scalar beta + float beta; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : ptr(0), ldm(0), alpha(1), beta(0) {} + + CUTLASS_HOST_DEVICE + Params(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0) + : ptr(_ptr), ldm(_ldm), alpha(_alpha), beta(_beta) {} + + /// Initialize method + CUTLASS_HOST_DEVICE + int initialize(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0) { + ptr = _ptr; + ldm = _ldm; + alpha = _alpha; + beta = _beta; + return 0; + } + + template + CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) { + ptr = reinterpret_cast(desc.D.data()); + ldm = desc.D.leading_dim(); + alpha = desc.alpha; + beta = desc.beta; + return 0; + } + }; + + /// Shared stoarge + struct SharedStorage {}; + + /// Helper used to compute initial offset for each thread + struct InitialOffset { + int row_offset; + int col_offset; + + /// Constructor + CUTLASS_DEVICE + InitialOffset() { + int warp_id = (threadIdx.x >> 5); + int lane_id = (threadIdx.x & 0x1f); + int quad_id = (lane_id >> 2); + int quadpair_id = (quad_id & 0x3); + + int quadpair_row = (quadpair_id & 1); + int quadpair_col = (quadpair_id >> 1); + int quad_hilo = (quad_id >> 2) & 1; + + // compute initial offset + int warp_row_offset = (warp_id % WarpDelta::kW) * InterleavedTileShape::kW; + int warp_col_offset = (warp_id / WarpDelta::kW) * InterleavedTileShape::kH; + + int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 1); + int thread_col_offset = (quadpair_col * 2) * 8 + (lane_id & 2); + + row_offset = warp_row_offset + thread_row_offset; + col_offset = warp_col_offset + thread_col_offset; + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// Problem size + Coord<3> problem_size; + + // + // Methods + // + + /// Computes initial offset for each thread + CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params, + Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024)) + : params(_params), problem_size(_problem_size) {} + + /// Computes initial offset for each thread + CUTLASS_DEVICE Volta884NaiveEpilogue(ScalarC* _ptr, + int _ldm, + Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024)) + : params(_ptr, _ldm), problem_size(_problem_size) {} + + /// Computes initial offset for each thread + CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params, + SharedStorage& shared_storage, + Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024)) + : params(_params), problem_size(_problem_size) {} + + /// Sets accumulators to zero + CUTLASS_DEVICE void clear(Accumulators& C) { + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + 
CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Iterations::kC; ++c) { + int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d)); + + CUTLASS_PRAGMA_UNROLL + for (int reg = 0; reg < kAccumElementsPerInst; ++reg) { + C[op_idx * kAccumElementsPerInst + reg] = 0; + } + } + } + } + } + } + + /// Naive load operation for debugging + CUTLASS_DEVICE void load(Accumulators& C, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) { + InitialOffset initial; + + initial.row_offset += threadblock_offset[2]; + initial.col_offset += threadblock_offset[1]; + + ScalarC const* load_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset; + + // loads accumulators + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Iterations::kC; ++c) { + ScalarC const* op_ptr = load_ptr + h * WarpDelta::kW * InterleavedTileShape::kW + + d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm; + + int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d)); + + CUTLASS_PRAGMA_UNROLL + for (int reg = 0; reg < kAccumElementsPerInst; ++reg) { + int tr = (reg & 2) + c * 4; + int tc = (reg & 1) + (reg & 4) * 2 + w * 4; + + int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr; + int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc; + + if (row < problem_size[2] && column < problem_size[1]) { + C[op_idx * kAccumElementsPerInst + reg] = op_ptr[tr + tc * params.ldm]; + } + } + } + } + } + } + } + + /// Naive store operation for debugging + CUTLASS_DEVICE void store(Accumulators const& C, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) { + InitialOffset initial; + + initial.row_offset += threadblock_offset[2]; + initial.col_offset += threadblock_offset[1]; + + ScalarC* store_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset; + + // store out accumulators + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Iterations::kC; ++c) { + ScalarC* op_ptr = store_ptr + h * WarpDelta::kW * InterleavedTileShape::kW + + d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm; + + int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d)); + + CUTLASS_PRAGMA_UNROLL + for (int reg = 0; reg < kAccumElementsPerInst; ++reg) { + int tr = (reg & 2) + c * 4; + int tc = (reg & 1) + (reg & 4) * 2 + w * 4; + + int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr; + int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc; + + if (row < problem_size[2] && column < problem_size[1]) { + op_ptr[tr + tc * params.ldm] = + params.alpha * C[op_idx * kAccumElementsPerInst + reg] + + params.beta * op_ptr[tr + tc * params.ldm]; + } + } + } + } + } + } + } + + /// CUTLASS Epilogue interface + CUTLASS_DEVICE void epilogue(Accumulators const& C, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) { + store(C, threadblock_offset); + } + + CUTLASS_DEVICE void epilogue(Accumulators& C, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) { + store(C, 
threadblock_offset); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Naive epilogue specialized for f16 accumulators - may be considered authoritative mapping of +/// accumulators to mma.sync operations. +template +struct Volta884NaiveEpilogue { + /// Accumulator data type + typedef half ScalarC; + + /// Output accumulator type + typedef half ScalarD; + + /// BLAS Scalar type + typedef half Scalar; + + /// Delta among warp tiles + typedef WarpDelta_ WarpDelta; + + /// Number of Volta mma.sync operations + typedef Iterations_ Iterations; + + /// Most of the Volta884 code assumes interleaved 32x32 tiles + typedef Shape<4, 32, 32> InterleavedTileShape; + + /// Number of multiplicand elements per instruction + static int const kAccumElementsPerInst = 8; + + /// Fragment definition for accumulators + typedef Fragment::kCount * kAccumElementsPerInst> Accumulators; + + /// Params object + struct Params { + /// Pointer to output matrix + ScalarC* ptr; + + /// stride + int ldm; + + /// Scalar alpha + half alpha; + + /// Scalar beta + half beta; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : ptr(0), ldm(0), alpha(1), beta(0) {} + + CUTLASS_HOST_DEVICE + Params(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0) + : ptr(_ptr), ldm(_ldm), alpha(_alpha), beta(_beta) {} + + /// Initialize method + CUTLASS_HOST_DEVICE + int initialize(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0) { + ptr = _ptr; + ldm = _ldm; + alpha = _alpha; + beta = _beta; + return 0; + } + + template + CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) { + ptr = reinterpret_cast(desc.D.data()); + ldm = desc.D.leading_dim(); + alpha = desc.alpha; + beta = desc.beta; + return 0; + } + }; + + /// Shared stoarge + struct SharedStorage {}; + + /// Helper used to compute initial offset for each thread + struct InitialOffset { + int row_offset; + int col_offset; + + /// Constructor + CUTLASS_DEVICE + InitialOffset() { + int warp_id = (threadIdx.x >> 5); + int lane_id = (threadIdx.x & 0x1f); + int quad_id = (lane_id >> 2); + int quadpair_id = (quad_id & 0x3); + + int quadpair_row = (quadpair_id & 1); + int quadpair_col = (quadpair_id >> 1); + int quad_hilo = (quad_id >> 2) & 1; + + // compute initial offset + int warp_row_offset = (warp_id % WarpDelta::kW) * InterleavedTileShape::kW; + int warp_col_offset = (warp_id / WarpDelta::kW) * InterleavedTileShape::kH; + + int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 3); + int thread_col_offset = (quadpair_col * 2) * 8; + + row_offset = warp_row_offset + thread_row_offset; + col_offset = warp_col_offset + thread_col_offset; + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// Problem size + Coord<3> problem_size; + + // + // Methods + // + + /// Computes initial offset for each thread + CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params) + : params(_params), problem_size(make_Coord(1024, 1024, 1024)) {} + + /// Computes initial offset for each thread + CUTLASS_DEVICE Volta884NaiveEpilogue(ScalarC* _ptr, int _ldm) + : params(_ptr, _ldm), problem_size(make_Coord(1024, 1024, 1024)) {} + + /// Computes initial offset for each thread + CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params, + SharedStorage& shared_storage, + Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024)) + : params(_params), problem_size(_problem_size) {} + + /// Sets accumulators to zero + CUTLASS_DEVICE void clear(Accumulators& C) { 
C.clear(); } + + /// Naive load operation for debugging + CUTLASS_DEVICE void load(Accumulators& C, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) { + InitialOffset initial; + + initial.row_offset += threadblock_offset[2]; + initial.col_offset += threadblock_offset[1]; + + ScalarC const* load_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset; + + // loads accumulators + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Iterations::kC; ++c) { + ScalarC const* op_ptr = load_ptr + h * WarpDelta::kW * InterleavedTileShape::kW + + d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm; + + int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d)); + + CUTLASS_PRAGMA_UNROLL + for (int reg = 0; reg < kAccumElementsPerInst; ++reg) { + int tr = c * 4; + int tc = (reg & 3) + (reg & 4) * 2 + w * 4; + + int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr; + int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc; + + if (row < problem_size[2] && column < problem_size[1]) { + C[op_idx * kAccumElementsPerInst + reg] = op_ptr[tr + tc * params.ldm]; + } + } + } + } + } + } + } + + /// Naive store operation for debugging + CUTLASS_DEVICE void store(Accumulators const& C, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) { + InitialOffset initial; + + initial.row_offset += threadblock_offset[2]; + initial.col_offset += threadblock_offset[1]; + + ScalarC* store_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset; + + // store out accumulators + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Iterations::kC; ++c) { + ScalarC* op_ptr = store_ptr + h * WarpDelta::kW * InterleavedTileShape::kW + + d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm; + + int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d)); + + CUTLASS_PRAGMA_UNROLL + for (int reg = 0; reg < kAccumElementsPerInst; ++reg) { + int tr = c * 4; + int tc = (reg & 3) + (reg & 4) * 2 + w * 4; + + int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr; + int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc; + + if (row < problem_size[2] && column < problem_size[1]) { + op_ptr[tr + tc * params.ldm] = + params.alpha * C[op_idx * kAccumElementsPerInst + reg] + + params.beta * op_ptr[tr + tc * params.ldm]; + } + } + } + } + } + } + } + + /// CUTLASS Epilogue interface + CUTLASS_DEVICE void epilogue(Accumulators const& C, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) { + store(C, threadblock_offset); + } + + CUTLASS_DEVICE void epilogue(Accumulators& C, + Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) { + store(C, threadblock_offset); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/volta884_shared_tile.h b/cutlass/gemm/volta884_shared_tile.h new file mode 100644 index 0000000000..26165f624e --- /dev/null +++ b/cutlass/gemm/volta884_shared_tile.h @@ -0,0 +1,142 @@ 
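// ------------------------------------------------------------------------------------------------
// [Editorial sketch -- not part of this diff] The naive epilogues above are described in the code
// as the authoritative mapping from each thread's mma.sync accumulator registers to (row, column)
// positions in the output tile. The host-only program below replays that index arithmetic so the
// mapping can be printed and inspected. The assumptions are the editor's, chosen only to keep the
// example self-contained: a single warp (warp_id = 0), WarpDelta = 1x1, the first 32x32
// interleaved tile (d = h = w = c = 0), and the f16 accumulator mapping of Volta884NaiveEpilogue.
// ------------------------------------------------------------------------------------------------

#include <cstdio>

int main() {
  for (int lane_id = 0; lane_id < 32; ++lane_id) {
    // Decompose the lane index exactly as InitialOffset does above (warp offsets are zero here).
    int quad_id      = lane_id >> 2;
    int quadpair_id  = quad_id & 0x3;
    int quadpair_row = quadpair_id & 1;
    int quadpair_col = quadpair_id >> 1;
    int quad_hilo    = (quad_id >> 2) & 1;

    int row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 3);
    int col_offset = (quadpair_col * 2) * 8;

    // Each thread owns kAccumElementsPerInst = 8 accumulator elements per mma.sync operation.
    for (int reg = 0; reg < 8; ++reg) {
      int tr = 0;                          // f16 mapping: tr = c * 4, with c = 0
      int tc = (reg & 3) + (reg & 4) * 2;  // f16 mapping: tc = (reg & 3) + (reg & 4) * 2 + w * 4, with w = 0

      std::printf("lane %2d reg %d -> (row %2d, col %2d)\n",
                  lane_id, reg, row_offset + tr, col_offset + tc);
    }
  }
  return 0;
}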
+/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/gemm/gemm_operand.h" +#include "cutlass/reshape_tile.h" +#include "cutlass/tile_iterator.h" +#include "cutlass/util/platform.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Warp-scoped shared memory load iterators +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +///! 
Iterator to store a thread-block scoped fragment to shared memory +template < + /// Identifies multiplicand of GEMM (A or B) + GemmOperand::Kind Operand, + /// Specifies layout of data in source memory + MatrixLayout::Kind Layout, + /// Specifies threadblock tile shape + typename Tile, + /// Specifies the number of participating warps + int WarpCount, + /// Specifies the delta between warp accesses along the outer dimension + int WarpDelta> +struct Volta884ThreadblockMultiplicandStoreIterator; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterator to load a fragment for each warp-level tile +template < + /// Identifies multiplicand of GEMM (A or B) + GemmOperand::Kind Operand, + /// Specifies layout of data in source memory + MatrixLayout::Kind Layout, + /// Specifies threadblock tile shape + typename Tile, + /// Specifies the warp tile shape + typename WarpTile, + /// Specifies the number of participating warps + int WarpCount, + /// Specifies the delta between warp accesses along the outer dimension + typename WarpDelta> +struct Volta884WarpMultiplicandLoadIterator; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Fully-specialized implementations extracted in separate headers. +// + +#include "cutlass/gemm/volta884_shared_tile_contiguous.h" +#include "cutlass/gemm/volta884_shared_tile_crosswise.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Epilogue shared memory iterators +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Stores an accumulator fragment to shared memory +template < + /// Shape of warp-level GEMM + typename WarpGemmTile_, + /// Tiling of warp accumulator elements + typename WarpDelta_, + /// Data type of accumulator elements + typename Scalar_, + /// Data type of mma.sync accumulator - this is used to infer layout. + typename Accumulator_> +struct Volta884EpilogueSharedStoreIterator; + +/// Loads an accumulator fragment from shared memory +template < + /// Shape of warp-level GEMM + typename WarpGemmTile_, + /// Tiling of warp accumulator elements + typename WarpDelta_, + /// Data type of accumulator elements + typename Scalar_, + /// Number of scalar elements loaded + int AccessSize_, + /// Data type of mma.sync accumulator - this is used to infer layout. + typename Accumulator_> +struct Volta884EpilogueSharedLoadIterator; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass + +// +// Partially-specialized implementations extracted in separate header. 
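// ------------------------------------------------------------------------------------------------
// [Editorial sketch -- not part of this diff] The iterators declared above (and specialized in the
// headers included from this file) all compute shared memory addresses the same way: a logical
// (D, H, W, C) coordinate is combined with a per-iterator stride vector through an inner product
// (Coord<4>::dot). The stand-alone snippet below restates that arithmetic with a stand-in Coord4
// type; the stride and offset values are hypothetical placeholders for whatever ShapeStrides and a
// ThreadOffset functor would produce for a concrete tile shape.
// ------------------------------------------------------------------------------------------------

#include <cstdio>

struct Coord4 {
  int idx[4];

  // Element-wise sum, mirroring Coord<4>::operator+.
  Coord4 operator+(Coord4 const &other) const {
    return Coord4{{idx[0] + other.idx[0], idx[1] + other.idx[1],
                   idx[2] + other.idx[2], idx[3] + other.idx[3]}};
  }

  // Inner product with a stride vector, mirroring Coord<4>::dot. This is what turns a logical
  // (D, H, W, C) position into a linear element offset into the SMEM allocation.
  long long dot(Coord4 const &stride) const {
    long long sum = 0;
    for (int i = 0; i < 4; ++i) {
      sum += static_cast<long long>(idx[i]) * stride.idx[i];
    }
    return sum;
  }
};

int main() {
  Coord4 stride{{2048, 256, 8, 1}};    // hypothetical strides for one vectorized SMEM tile
  Coord4 thread_offset{{0, 1, 5, 0}};  // as a ThreadOffset functor might produce
  Coord4 sts_offset{{0, 4, 0, 0}};     // per-iteration offset already scaled by Delta

  // Matches the pattern used by the store iterators: pointer + (sts_offset + offset).dot(stride)
  std::printf("linear SMEM offset = %lld elements\n", (sts_offset + thread_offset).dot(stride));
  return 0;
}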
+// + +#include "cutlass/gemm/volta884_shared_tile_epilogue.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass/gemm/volta884_shared_tile_contiguous.h b/cutlass/gemm/volta884_shared_tile_contiguous.h new file mode 100644 index 0000000000..f79f0d262b --- /dev/null +++ b/cutlass/gemm/volta884_shared_tile_contiguous.h @@ -0,0 +1,974 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction +*/ + +#pragma once + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Congruous loading +// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Store iterator specialized for A.column_major +template < + /// Specifies threadblock tile shape + typename Tile_, + /// Specifies the number of participating warps + int WarpCount, + /// Specifies the delta between warp accesses along the outer dimension + int WarpDelta> +struct Volta884ThreadblockMultiplicandStoreIterator { + // + // Constant and type definitions + // + + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = GemmOperand::kA; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor; + + /// Shape of thread-block multiplicand + typedef Tile_ Tile; + + /// Number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp accumulator tiles along the outer dimension + static int const kWarpDelta = WarpDelta; + + /// This implementation is specialized for 128b loads + static int const kAccessSize = 8; + + /// Swizzled store iterator + struct ThreadOffset { + __device__ Coord<4> operator()() const { + int warp_id = (threadIdx.x >> 5); + int lane_id = (threadIdx.x & 0x1f); + + int k_idx = warp_id; + + // This is an 8-element vector within one 32x32 tile + int vec_idx = lane_id & 3; + int vec_col = (vec_idx / 2); + + int t4t3 = (lane_id >> 3); + int col_rotate = ((lane_id >> 1) & 2) | (lane_id & 1); + + int t_col = (vec_col << 2) | (col_rotate ^ t4t3); + + Coord<4> offset = make_Coord(k_idx, col_rotate, t_col, 0); + + return offset; + } + }; + + /// Projects the threadblock tile + typedef typename GemmMultiplicandTraits::Shape OperandShape; + + /// Stored tile has a structure designed for efficient MIO storing and loading + typedef Shape<(OperandShape::kH >> 2), // one 3D tile per four elements in the K dimension + (OperandShape::kW >> 4), // four rows of SMEM per 64xK tile + kAccessSize, // Eight banks of MIO + kAccessSize> + VectorizedShape; // 128b stores + + /// Offset between stores + typedef Shape Delta; + + /// Number of iterations + typedef Shape<(VectorizedShape::kD / WarpCount), (OperandShape::kW >> 6), 1, 1> Iterations; + + /// Source tile traits + typedef TileTraits Traits; + + /// Scalar type + typedef half Scalar; + + /// Index type + typedef int Index; + + /// Index type + typedef int LongIndex; + + // + // Derived types + // + + /// Tensor reference + typedef TensorRef TensorRef; + + /// Predicate vector + typedef PredicateVector::kCount> PredicateVector; + + /// Fragment definition + typedef Fragment::kCount * kAccessSize> Fragment; + + /// Elements loaded by one instruction + typedef typename Vectorize::Type AccessType; + + /// The fragment iterator. + typedef FragmentIterator FragmentIterator; + + /// The fragment const iterator. 
+ typedef FragmentConstIterator FragmentConstIterator; + + /// Strides into expected SMEM tile + typedef typename ShapeStrides::Shape Strides; + + /// Memory space access + static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric; + + /// Parameters object + struct Params { + // + // Data members + // + + /// Pointer to element type + Scalar *pointer; + + /// Strides + Coord<4> stride; + + // + // Methods + // + + /// Constructs a parameters object + CUTLASS_HOST_DEVICE + Params(Scalar *_pointer = 0) + : pointer(_pointer), + stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {} + + /// Constructs a params object from a TensorRef + CUTLASS_HOST_DEVICE + Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Constructs a store iterator + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator( + Params const &_params, + Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0), + ThreadOffset offset_func = ThreadOffset()) + : params(_params) { + // Compute initial thread offset + Coord<4> offset = offset_func(); + + params.pointer += (_block_offset + offset).template dot(params.stride); + } + + /// Stores a fragment + CUTLASS_DEVICE void store(Fragment const &fragment, + Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const { + FragmentConstIterator frag_iterator(fragment); + + // Iterate over each store + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + int idx = w + Iterations::kW * h; + + int row = idx * 4; + + Coord<4> sts_offset = + make_Coord(d, row, 0, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC); + + Store::store( + reinterpret_cast(frag_iterator.at(d, h, w, 0)), + params.pointer, + params.stride.template dot(sts_offset + offset)); + } + } + } + } + + /// Increments store iterator to next tile + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &increment(int count = 1) { + params.pointer += + make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot(params.stride); + return *this; + } + + /// Increments to next tile + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator++() { return increment(); } + + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator+=(int count) { + return increment(count); + } + + /// Increments store iterator to previous tile + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &decrement(int count = 1) { + params.pointer -= + make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot(params.stride); + return *this; + } + + /// Increments to subsequent tile + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator--() { return decrement(); } + + /// Decrements to previous tile + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator-=(int count) { + return decrement(count); + } + + /// Stores a fragment and increments in the K dimension + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &store_post_increment( + Fragment const &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) { + store(fragment, offset); + return increment(); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterator to load a fragment 
for each warp-level tile specialized for A.column_major +template < + /// Specifies threadblock tile shape + typename Tile_, + /// Specifies the warp tile shape + typename WarpTile_, + /// Specifies the number of participating warps + int WarpCount, + /// Specifies the delta between warp accesses along the outer dimension + typename WarpDelta_> +struct Volta884WarpMultiplicandLoadIterator { + // + // Constant and type definitions + // + + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = GemmOperand::kA; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor; + + /// Shape of thread-block multiplicand + typedef Tile_ Tile; + + /// Shape of warp-tile matrix operation + typedef WarpTile_ WarpTile; + + /// Hard-coded tile shape + typedef Shape<4, 32, 32> InterleavedTileShape; + + /// Number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp accumulator tiles along the outer dimension + typedef WarpDelta_ WarpDelta; + + /// Two SMEM read pointers are needed + static int const kPointerCount = (WarpDelta::kW == 1 ? 2 : 1); + + /// This implementation is specialized for 128b loads + static int const kAccessSize = 8; + + /// Swizzled store iterator + struct ThreadOffset { + /// Compute thread offset coordinate for each pointer + CUTLASS_DEVICE Coord<4> operator()(int pointer_idx = 0) const { + // Determine the warp's reading location within the SMEM tile + int warp_id = ((threadIdx.x >> 5) % WarpDelta::kW); + + // This is an 8-element vector within one 32x32 tile + int lane_id = (threadIdx.x & 0x1f); + int vec_row = (lane_id >> 4); + int vec_col = ((lane_id & 4) >> 2); + + int tile_row = pointer_idx * 2 + vec_row; + + // Column rotation function + int t_col = (vec_col * 4); + if (pointer_idx == 1 || (WarpDelta::kW > 1 && (warp_id & 1))) { + vec_row |= 2; + } + + t_col = t_col | ((lane_id & 3) ^ vec_row); + + Coord<4> offset = make_Coord(0, warp_id * 2 + tile_row, t_col, 0); + + return offset; + } + }; + + /// Projects the threadblock tile + typedef typename GemmMultiplicandTraits::Shape OperandShape; + + /// Stored tile has a structure designed for efficient MIO storing and loading + typedef Shape<(OperandShape::kH >> 2), // one 3D tile per four elements in the K dimension + (OperandShape::kW >> 4), // four rows of SMEM per 64xK tile + kAccessSize, // Eight banks of MIO + kAccessSize> + VectorizedShape; // 128b stores + + /// Offset between acceses + typedef typename platform::conditional, + Shape<1, 2 * WarpDelta::kW, 0, 0> >::type Delta; + + /// Number of iterations + typedef Shape<1, WarpTile::kW / InterleavedTileShape::kW, 1, 1> Iterations; + + /// Source tile traits + typedef TileTraits Traits; + + /// Scalar type + typedef half Scalar; + + /// Index type + typedef int Index; + + /// Index type + typedef int LongIndex; + + // + // Derived types + // + + /// Tensor reference + typedef TensorRef TensorRef; + + /// Predicate vector + typedef PredicateVector::kCount> PredicateVector; + + /// Fragment definition + typedef Fragment::kCount * kAccessSize> Fragment; + + /// Elements loaded by one instruction + typedef typename Vectorize::Type AccessType; + + /// The fragment iterator. + typedef FragmentIterator FragmentIterator; + + /// The fragment const iterator. 
+ typedef FragmentConstIterator FragmentConstIterator; + + /// Strides into expected SMEM tile + typedef typename ShapeStrides::Shape Strides; + + /// Memory space access + static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric; + + /// Parameters object + struct Params { + // + // Data members + // + + /// Base pointer to SMEM allocation + Scalar const *pointer; + + /// SMEM strides + Coord<4> stride; + + // + // Methods + // + + /// Constructs a parameters object + CUTLASS_HOST_DEVICE + Params(Scalar const *_pointer = 0) + : pointer(_pointer), + stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {} + + /// Constructs a params object from a TensorRef + CUTLASS_HOST_DEVICE + Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { } + }; + + // + // Data members + // + + // A.column requires two SMEM pointers. + // Because Params only supplies a base pointer and strides, there is no usual params + // data member. Instead, it is used to initialize the following. + + /// Pointer to SMEM allocation. + Scalar const *pointer[kPointerCount]; + + /// SMEM strides + Coord<4> stride; + + // + // Methods + // + + /// Constructs a load iterator + CUTLASS_DEVICE Volta884WarpMultiplicandLoadIterator( + Params const &_params, + Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0), + ThreadOffset offset_func = ThreadOffset()) + : stride(_params.stride) { + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < kPointerCount; ++idx) { + Coord<4> offset = offset_func(idx); + + pointer[idx] = _params.pointer + (_block_offset + offset).template dot(stride); + } + } + + /// Loads a fragment + CUTLASS_DEVICE void load(Fragment &fragment, + Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const { + FragmentIterator frag_iterator(fragment); + + // Iterate over each load + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + // Pointers mapped to Iterations::kH dimension + Scalar const *_pointer = pointer[(kPointerCount == 2 ? 
h : 0)]; + + Coord<4> lds_offset = + make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC); + + Load::load( + reinterpret_cast(frag_iterator.at(d, h, w, 0)), + _pointer, + stride.template dot(lds_offset + offset)); + } + } + } + } + + /// Loads a fragment and increments to next K-index + CUTLASS_DEVICE void load_post_increment(Fragment &fragment, + Coord<4> const &offset = make_Coord(0, 0, 0, 0)) { + load(fragment, offset); + + for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) { + pointer[ptr_idx] += make_Coord(1, 0, 0, 0).template dot(stride); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Store iterator specialized for B.row_major +template < + /// Specifies threadblock tile shape + typename Tile_, + /// Specifies the number of participating warps + int WarpCount, + /// Specifies the delta between warp accesses along the outer dimension + int WarpDelta> +struct Volta884ThreadblockMultiplicandStoreIterator { + // + // Constant and type definitions + // + + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = GemmOperand::kB; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor; + + /// Shape of thread-block multiplicand + typedef Tile_ Tile; + + /// Number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp accumulator tiles along the outer dimension + static int const kWarpDelta = WarpDelta; + + /// This implementation is specialized for 128b loads + static int const kAccessSize = 8; + + /// Index type + typedef int Index; + + /// Index type + typedef int LongIndex; + + /// Swizzled store iterator + struct ThreadOffset { + CUTLASS_DEVICE Coord<4> operator()() const { + int warp_id = (threadIdx.x >> 5); + int lane_id = (threadIdx.x & 0x1f); + + int k_idx = warp_id; + + // This is an 8-element vector within one 32x32 tile + int vec_idx = lane_id & 3; + int vec_col = (vec_idx / 2); + + int t4t3 = (lane_id >> 3); + int col_rotate = ((lane_id >> 1) & 2) | (lane_id & 1); + + int t_col = (vec_col << 2) | (col_rotate ^ t4t3); + + Coord<4> offset = make_Coord(k_idx, col_rotate , t_col, 0); + + return offset; + } + }; + + /// Projects the threadblock tile + typedef typename GemmMultiplicandTraits::Shape OperandShape; + + /// Stored tile has a structure designed for efficient MIO storing and loading + typedef Shape<(OperandShape::kH >> 2), // one 3D tile per four elements in the K dimension + (OperandShape::kW >> 4), // four rows of SMEM per 64xK tile + kAccessSize, // Eight banks of MIO + kAccessSize> + VectorizedShape; // 128b stores + + /// Offset between stores + typedef Shape Delta; + + /// Number of iterations + typedef Shape<(VectorizedShape::kD / WarpCount), (OperandShape::kW >> 6), 1, 1> Iterations; + + /// Source tile traits + typedef TileTraits Traits; + + /// Scalar type + typedef half Scalar; + + // + // Derived types + // + + /// Tensor reference + typedef TensorRef TensorRef; + + /// Predicate vector + typedef PredicateVector::kCount> PredicateVector; + + /// Fragment definition + typedef Fragment::kCount * kAccessSize> Fragment; + + /// Elements loaded by one instruction + typedef typename Vectorize::Type AccessType; + + /// The fragment iterator. + typedef FragmentIterator FragmentIterator; + + /// The fragment const iterator. 
+ typedef FragmentConstIterator FragmentConstIterator; + + /// Strides into expected SMEM tile + typedef typename ShapeStrides::Shape Strides; + + /// Memory space access + static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric; + + /// Parameters object + struct Params { + // + // Data members + // + + /// Pointer to element type + Scalar *pointer; + + /// Strides + Coord<4> stride; + + // + // Methods + // + + /// Constructs a parameters object + CUTLASS_HOST_DEVICE + Params(Scalar *_pointer = 0) + : pointer(_pointer), + stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {} + + /// Constructs a params object from a TensorRef + CUTLASS_HOST_DEVICE + Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Constructs a store iterator + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator( + Params const &_params, + Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0), + ThreadOffset offset_func = ThreadOffset()) + : params(_params) { + // Compute initial offset for each thread + Coord<4> offset = offset_func(); + + params.pointer += (_block_offset + offset).template dot(params.stride); + } + + /// Stores a fragment + CUTLASS_DEVICE void store(Fragment const &fragment, + Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const { + FragmentConstIterator frag_iterator(fragment); + + // Iterate over each store + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + int idx = w + Iterations::kW * h; + int row = idx * 4; + + Coord<4> sts_offset = + make_Coord(d, row, 0, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC); + + Index _offset = params.stride.template dot(sts_offset + offset); + + Store::store( + reinterpret_cast(frag_iterator.at(d, h, w, 0)), + params.pointer, + _offset); + } + } + } + } + + /// Increments store iterator to next tile + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &increment(int count = 1) { + params.pointer += + make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot(params.stride); + return *this; + } + + /// Increments to next tile + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator++() { return increment(); } + + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator+=(int count) { + return increment(count); + } + + /// Increments store iterator to previous tile + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &decrement(int count = 1) { + params.pointer -= + make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot(params.stride); + return *this; + } + + /// Increments to subsequent tile + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator--() { return decrement(); } + + /// Decrements to previous tile + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator-=(int count) { + return decrement(count); + } + + /// Stores a fragment and increments in the K dimension + CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &store_post_increment( + Fragment const &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) { + store(fragment, offset); + return increment(); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
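// [Editorial sketch -- not part of this diff] Both congruous store iterators above (A.column_major
// and B.row_major) compute the same swizzled ThreadOffset so that a warp's 128b stores are spread
// across distinct shared memory banks. The helper below restates that arithmetic as a plain
// function of (warp_id, lane_id); the function name and the use of out-parameters instead of
// Coord<4> are the editor's, kept free of other dependencies so the sketch stands on its own.
inline void volta884_congruous_store_thread_offset_sketch(int warp_id, int lane_id,
                                                          int &k_idx, int &row, int &column) {
  k_idx = warp_id;                                // each warp stores a different K-slice (D coord)

  int vec_idx    = lane_id & 3;                   // which 128b vector within one 32x32 tile
  int vec_col    = vec_idx / 2;
  int t4t3       = lane_id >> 3;                  // bits 4:3 of the lane index
  int col_rotate = ((lane_id >> 1) & 2) | (lane_id & 1);

  row    = col_rotate;                            // SMEM row within the interleaved tile
  column = (vec_col << 2) | (col_rotate ^ t4t3);  // XOR rotation spreads lanes across banks
}

////////////////////////////////////////////////////////////////////////////////////////////////////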
+/// Iterator to load a fragment for each warp-level tile specialized for B.row_major +template < + /// Specifies threadblock tile shape + typename Tile_, + /// Specifies the warp tile shape + typename WarpTile_, + /// Specifies the number of participating warps + int WarpCount, + /// Specifies the delta between warp accesses along the outer dimension + typename WarpDelta_> +struct Volta884WarpMultiplicandLoadIterator { + // + // Constant and type definitions + // + + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = GemmOperand::kB; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor; + + /// Shape of thread-block multiplicand + typedef Tile_ Tile; + + /// Shape of warp-tile matrix operation + typedef WarpTile_ WarpTile; + + /// Hard-coded tile shape + typedef Shape<4, 32, 32> InterleavedTileShape; + + /// Number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp accumulator tiles along the outer dimension + typedef WarpDelta_ WarpDelta; + + /// This implementation is specialized for 128b loads + static int const kAccessSize = 8; + + /// Swizzled store iterator + struct ThreadOffset { + /// Computes the initial offset + CUTLASS_DEVICE Coord<4> operator()(int pointer_idx) const { + // Determine the warp's reading location within the SMEM tile + int warp_id = ((threadIdx.x >> 5) / WarpDelta::kW); + + // This is an 8-element vector within one 32x32 tile + int lane_id = (threadIdx.x & 0x1f); + int vec_row = (lane_id >> 4); + int vec_col = ((lane_id & 8) >> 3); + + int tile_row = pointer_idx * 2 + vec_row; + + // Column rotation function + int t_col = (vec_col * 4); + if (pointer_idx == 1 || (WarpDelta::kH > 1 && (warp_id & 1))) { + vec_row |= 2; + } + + t_col = t_col | ((lane_id & 3) ^ vec_row); + Coord<4> offset = make_Coord(0, warp_id * 2 + tile_row, t_col, 0); + + return offset; + } + }; + + /// Projects the threadblock tile + typedef typename GemmMultiplicandTraits::Shape OperandShape; + + /// Stored tile has a structure designed for efficient MIO storing and loading + typedef Shape<(OperandShape::kH >> 2), // one 3D tile per four elements in the K dimension + (OperandShape::kW >> 4), // four rows of SMEM per 64xK tile + kAccessSize, // Eight banks of MIO + kAccessSize> + VectorizedShape; // 128b stores + + /// Delta between accesses + typedef typename platform::conditional, + Shape<1, 2 * WarpDelta::kH, 0, 0> >::type Delta; + + /// Number of iterations + typedef Shape<1, WarpTile::kH / InterleavedTileShape::kH, 1, 1> Iterations; + + /// Source tile traits + typedef TileTraits Traits; + + /// Scalar type + typedef half Scalar; + + /// Index type + typedef int Index; + + /// Index type + typedef int LongIndex; + + // + // Derived types + // + + /// Tensor reference + typedef TensorRef TensorRef; + + /// Predicate vector + typedef PredicateVector::kCount> PredicateVector; + + /// Fragment definition + typedef Fragment::kCount * kAccessSize> Fragment; + + /// Elements loaded by one instruction + typedef typename Vectorize::Type AccessType; + + /// The fragment iterator. + typedef FragmentIterator FragmentIterator; + + /// The fragment const iterator. 
+ typedef FragmentConstIterator FragmentConstIterator; + + /// Strides into expected SMEM tile + typedef typename ShapeStrides::Shape Strides; + + /// Memory space access + static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric; + + /// Number of SMEM read pointers needed + static int const kPointerCount = (WarpDelta::kH == 1 ? 2 : 1); + + /// Parameters object + struct Params { + // + // Data members + // + + /// Pointer to element type + Scalar const *pointer; + + /// Strides + Coord<4> stride; + + // + // Methods + // + + /// Constructs a parameters object + CUTLASS_HOST_DEVICE + Params(Scalar const *_pointer = 0) + : pointer(_pointer), + stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {} + + /// Constructs a params object from a TensorRef + CUTLASS_HOST_DEVICE + Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { } + }; + + // + // Data members + // + + /// Pointer to element type + Scalar const *pointer[kPointerCount]; + + /// Strides + Coord<4> stride; + + // + // Methods + // + + /// Constructs a load iterator + CUTLASS_DEVICE Volta884WarpMultiplicandLoadIterator( + Params const &_params, + Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0), + ThreadOffset offset_func = ThreadOffset()) + : stride(_params.stride) { + CUTLASS_PRAGMA_UNROLL + for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) { + Coord<4> offset = offset_func(ptr_idx); + + pointer[ptr_idx] = _params.pointer + (_block_offset + offset).template dot(stride); + } + } + + /// Stores a fragment + CUTLASS_DEVICE void load(Fragment &fragment, + Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const { + FragmentIterator frag_iterator(fragment); + + // Iterate over each load + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + // Pointers mapped to Iterations::kH dimension + Scalar const *_pointer = pointer[(kPointerCount == 2 ? h : 0)]; + + Coord<4> lds_offset = + make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC); + + Load::load( + reinterpret_cast(frag_iterator.at(d, h, w, 0)), + _pointer, + stride.template dot(lds_offset + offset)); + } + } + } + } + + /// Loads a fragment and increments to next K-index + CUTLASS_DEVICE void load_post_increment(Fragment &fragment, + Coord<4> const &offset = make_Coord(0, 0, 0, 0)) { + load(fragment, offset); + + CUTLASS_PRAGMA_UNROLL + for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) { + pointer[ptr_idx] += make_Coord(1, 0, 0, 0).template dot(stride); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass + diff --git a/cutlass/gemm/volta884_shared_tile_crosswise.h b/cutlass/gemm/volta884_shared_tile_crosswise.h new file mode 100644 index 0000000000..361c791a1e --- /dev/null +++ b/cutlass/gemm/volta884_shared_tile_crosswise.h @@ -0,0 +1,1063 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction +*/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Crosswise loading +// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Store iterator specialized for A.row_major +template < + /// Specifies threadblock tile shape + typename Tile_, + /// Specifies the number of participating warps + int WarpCount, + /// Specifies the delta between warp accesses along the outer dimension + int WarpDelta> +struct Volta884ThreadblockMultiplicandStoreIterator { + // + // Assertions + // + + // Crosswise loaders may only span 32 x 128b along the K dimension + static_assert(!(Tile_::kW % 8) && (Tile_::kW <= 256), + "Tile dimensions must be divisible by 8 elements, and the K dimension may not span " + "more than what a single warp can load"); + + // + // Constant and type definitions + // + + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = GemmOperand::kA; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor; + + /// Shape of thread-block multiplicand + typedef Tile_ Tile; + + /// Number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp accumulator tiles along the outer dimension + static int const kWarpDelta = WarpDelta; + + /// LDG.128 loads + static int const kLdgAccessSize = 8; + + /// This implementation is specialized 
for 64b loads + static int const kAccessSize = 4; + + /// Projects the threadblock tile + typedef typename GemmMultiplicandTraits::Shape OperandShape; + + /// Stored tile has a structure designed for efficient MIO storing and loading + typedef Shape<(OperandShape::kW >> 2), // one 3D tile per four elements in the K dimension + (OperandShape::kH >> 4), // four rows of SMEM per 64xK tile + 16, // Sixteen banks of MIO + kAccessSize> + VectorizedShape; + + /// Offset between stores + typedef Shape Delta; + + /// Shape of tile + typedef Shape<1, 8, 4> WarpStoreCoverage; + + /// Number of iterations + typedef Shape< + // # of LDG.128s along the strided (outer) dimension + OperandShape::kH / (WarpStoreCoverage::kH * kWarpCount), + // # of LDG.128s along the contiguous (K) dimension + OperandShape::kW / (WarpStoreCoverage::kW * kLdgAccessSize), + // # STSs per LDG + (kLdgAccessSize / kAccessSize), + 1> + Iterations; + + /// Swizzled store iterator + struct ThreadOffset { + __device__ Coord<4> operator()(int ptr_idx) const { + int warp_id = (threadIdx.x >> 5); + int lane_id = (threadIdx.x & 0x1f); + + // Assumes a contiguous/blocked warp loading strategy + int load_tile_idx = warp_id * Iterations::kD; + + // Compute swizzled destination address + int lane_w = (lane_id % WarpStoreCoverage::kW); + int store_k_idx = lane_w * 2; + + int dest_tile_idx = load_tile_idx / 4; + + int dest_row = ((load_tile_idx >> 1) & 1); + int dest_bank = (lane_id & 0x0f) ^ ((lane_id >> 4) & 1) ^ (ptr_idx << 1); + + Coord<4> offset = make_Coord(store_k_idx, dest_tile_idx * 2 + dest_row, dest_bank, 0); + + return offset; + } + }; + + /// Source tile traits + typedef TileTraits Traits; + + /// Scalar type + typedef half Scalar; + + /// Index type + typedef int Index; + + /// Index type + typedef int LongIndex; + + // + // Derived types + // + + /// Tensor reference + typedef TensorRef TensorRef; + + /// Predicate vector + typedef PredicateVector::kCount> PredicateVector; + + /// Fragment definition + typedef Fragment::kCount * kAccessSize> Fragment; + + /// Elements loaded by one instruction + typedef typename Vectorize::Type AccessType; + + /// The fragment iterator. + typedef FragmentIterator FragmentIterator; + + /// The fragment const iterator. 
+ typedef FragmentConstIterator FragmentConstIterator; + + /// Strides into expected SMEM tile + typedef typename ShapeStrides::Shape Strides; + + /// Memory space access + static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric; + + /// Store iterators require two pointers + static int const kPointerCount = 2; + + /// Parameters object + struct Params { + // + // Data members + // + + /// Pointer to element type + Scalar *pointer; + + /// Strides + Coord<4> stride; + + // + // Methods + // + + /// Constructs a parameters object + CUTLASS_HOST_DEVICE + Params(Scalar *_pointer = 0) + : pointer(_pointer), + stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {} + + /// Constructs a params object from a TensorRef + CUTLASS_HOST_DEVICE + Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { } + }; + + // + // Data members + // + + /// Pointer to element type + Scalar *pointer[kPointerCount]; + + /// Strides + Coord<4> stride; + + // + // Methods + // + + /// Constructs a store iterator + __device__ Volta884ThreadblockMultiplicandStoreIterator( + Params const &_params, + Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0), + ThreadOffset offset_func = ThreadOffset()) + : stride(_params.stride) { + // Initialize each pointer + CUTLASS_PRAGMA_UNROLL + for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) { + Coord<4> offset = offset_func(ptr_idx); + pointer[ptr_idx] = _params.pointer + (_block_offset + offset).template dot(stride); + } + } + + /// Stores a fragment + __device__ void store(Fragment const &fragment, + Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const { + FragmentConstIterator frag_iterator(fragment); + + // Iterate over each store + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { // strided LDG.128s + + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { // contiguous LDG.128s + + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { // 2x STS operations per LDG + + int warp_id = (threadIdx.x >> 5); + + int ldg_idx = d + warp_id * Iterations::kD; + int k_idx = w + h * 8; + int smem_row = (d >> 1); + + // Two store pointers + int ptr_idx = ((ldg_idx & 1) ^ ((ldg_idx >> 1) & 1)); + + Scalar *_pointer = pointer[ptr_idx]; + Coord<4> sts_offset = make_Coord(k_idx, smem_row, 0, 0); + + Store::store( + reinterpret_cast(frag_iterator.at(d, h, w, 0)), + _pointer, + stride.template dot(sts_offset + offset)); + } + } + } + } + + /// Increments store iterator to next tile + __device__ Volta884ThreadblockMultiplicandStoreIterator &increment(int count = 1) { + for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) { + pointer[ptr_idx] += + make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot(stride); + } + return *this; + } + + /// Increments to next tile + __device__ Volta884ThreadblockMultiplicandStoreIterator &operator++() { return increment(1); } + + __device__ Volta884ThreadblockMultiplicandStoreIterator &operator+=(int count) { + return increment(count); + } + + /// Increments store iterator to previous tile + __device__ Volta884ThreadblockMultiplicandStoreIterator &decrement(int count = 1) { + for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) { + pointer[ptr_idx] -= + make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot(stride); + } + return *this; + } + + /// Increments to subsequent tile + __device__ Volta884ThreadblockMultiplicandStoreIterator &operator--() { return decrement(1); } + + /// Decrements 
to previous tile + __device__ Volta884ThreadblockMultiplicandStoreIterator &operator-=(int count) { + return decrement(count); + } + + /// Stores a fragment and increments in the K dimension + __device__ Volta884ThreadblockMultiplicandStoreIterator &store_post_increment( + Fragment const &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) { + store(fragment, offset); + return increment(); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterator to load a fragment for each warp-level tile specialized for A.row_major +template < + /// Specifies threadblock tile shape + typename Tile_, + /// Specifies the shape of thewarp tile + typename WarpTile_, + /// Specifies the number of participating warps + int WarpCount, + /// Specifies the delta between warp accesses along the outer dimension + typename WarpDelta_> +struct Volta884WarpMultiplicandLoadIterator { + // + // Constant and type definitions + // + + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = GemmOperand::kA; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor; + + /// Shape of thread-block multiplicand + typedef Tile_ Tile; + + /// Shape of warp-tile matrix operation + typedef WarpTile_ WarpTile; + + /// Hard-coded tile shape + typedef Shape<4, 32, 32> InterleavedTileShape; + + /// Number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp accumulator tiles along the outer dimension + typedef WarpDelta_ WarpDelta; + + /// This implementation is specialized for 128b loads + static int const kAccessSize = 8; + + /// Swizzled store iterator + struct ThreadOffset { + /// Compute thread offset coordinate for each pointer + __device__ Coord<4> operator()(int ptr_idx) const { + int warp_id = ((threadIdx.x >> 5) % WarpDelta::kW); + int lane_id = (threadIdx.x & 0x1f); + + int lane_in_quad = (lane_id & 0x3); + int quad_id = ((lane_id >> 2) & 0x7); + + int oct_row_id = ((quad_id >> 1) & 2) | (quad_id & 1); + int oct_row = (oct_row_id & 1); + int oct_left_right = (oct_row_id & 1) ^ ((oct_row_id >> 1) & 1) ^ ptr_idx; + + Coord<4> offset = make_Coord(0, warp_id * 2 + oct_row, lane_in_quad * 2 + oct_left_right, 0); + + return offset; + } + }; + + /// Projects the threadblock tile + typedef typename GemmMultiplicandTraits::Shape OperandShape; + + /// Loaded tile has a structure designed for efficient MIO storing and loading + typedef Shape<(OperandShape::kW >> 2), // one 3D tile per four elements in the K dimension + (OperandShape::kH >> 4), // four rows of SMEM per 64xK tile + 8, // Eight banks of MIO + kAccessSize> + VectorizedShape; + + /// Offset between acceses + typedef Shape<1, 2 * WarpDelta::kW, 1, 1> Delta; + + /// Number of iterations + typedef Shape<1, WarpTile::kW / InterleavedTileShape::kW, 1, 1> Iterations; + + /// Source tile traits + typedef TileTraits Traits; + + /// Scalar type + typedef half Scalar; + + /// Index type + typedef int Index; + + /// Index type + typedef int LongIndex; + + // + // Derived types + // + + /// Tensor reference + typedef TensorRef TensorRef; + + /// Predicate vector + typedef PredicateVector::kCount> PredicateVector; + + /// Fragment definition + typedef Fragment::kCount * kAccessSize> Fragment; + + /// Elements loaded by one instruction + typedef typename Vectorize::Type AccessType; + + /// The fragment iterator. 
+ typedef FragmentIterator FragmentIterator; + + /// The fragment const iterator. + typedef FragmentConstIterator FragmentConstIterator; + + /// Strides into expected SMEM tile + typedef typename ShapeStrides::Shape Strides; + + /// Memory space access + static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric; + + /// Pointer count is always two + static int const kPointerCount = 2; + + /// Parameters object + struct Params { + // + // Data members + // + + /// Base pointer to SMEM allocation + Scalar const *pointer; + + /// SMEM strides + Coord<4> stride; + + // + // Methods + // + + /// Constructs a parameters object + CUTLASS_HOST_DEVICE + Params(Scalar const *_pointer = 0) + : pointer(_pointer), + stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {} + + /// Constructs a params object from a TensorRef + CUTLASS_HOST_DEVICE + Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { } + }; + + // + // Data members + // + + /// Shared memory load pointer + Scalar const *pointer[kPointerCount]; + + /// SMEM strides + Coord<4> stride; + + /// Index in D dimension - needed to permute loads + int k_index; + + // + // Methods + // + + /// Constructs a load iterator + __device__ Volta884WarpMultiplicandLoadIterator( + Params const &_params, + Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0), + ThreadOffset offset_func = ThreadOffset()) + : stride(_params.stride), k_index(0) { + CUTLASS_PRAGMA_UNROLL + for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) { + Coord<4> offset = offset_func(ptr_idx); + + pointer[ptr_idx] = _params.pointer + (_block_offset + offset).template dot(stride); + } + } + + /// Stores a fragment + __device__ void load(Fragment &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const { + FragmentIterator frag_iterator(fragment); + + // Iterate over each load + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + Coord<4> lds_offset = + make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC); + + int ptr_idx = ((offset[0] >> 2) & 1); + Scalar const *_pointer = pointer[ptr_idx]; + + Load::load( + reinterpret_cast(frag_iterator.at(d, h, w, 0)), + _pointer, + stride.template dot(lds_offset + offset)); + + if (offset[0] & 2) { + // peculiar swap for crosswise loads + int lds128_idx = w + Iterations::kW * (h + Iterations::kH * d); + uint64_t *left = reinterpret_cast(&fragment) + lds128_idx * 2; + uint64_t *right = reinterpret_cast(&fragment) + lds128_idx * 2 + 1; + uint64_t tmp = *left; + *left = *right; + *right = tmp; + } + } + } + } + } + + /// Loads a fragment and increments to next K-index + __device__ void load_post_increment(Fragment &fragment, + Coord<4> const &offset = make_Coord(0, 0, 0, 0)) { + load(fragment, offset + make_Coord(k_index, 0, 0, 0)); + ++k_index; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Store iterator specialized for B.column_major +template < + /// Specifies threadblock tile shape + typename Tile_, + /// Specifies the number of participating warps + int WarpCount, + /// Specifies the delta between warp accesses along the outer dimension + int WarpDelta> +struct 
Volta884ThreadblockMultiplicandStoreIterator { + // + // Assertions + // + + // Crosswise loaders may only span 32 x 128b along the K dimension + static_assert(!(Tile_::kW % 8) && (Tile_::kW <= 256), + "Tile dimensions must be divisible by 8 elements, and the K dimension may not span " + "more than what a single warp can load"); + + // + // Constant and type definitions + // + + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = GemmOperand::kB; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor; + + /// Shape of thread-block multiplicand + typedef Tile_ Tile; + + /// Number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp accumulator tiles along the outer dimension + static int const kWarpDelta = WarpDelta; + + /// LDG.128 loads + static int const kLdgAccessSize = 8; + + /// This implementation is specialized for 64b loads + static int const kAccessSize = 4; + + /// Projects the threadblock tile + typedef typename GemmMultiplicandTraits::Shape OperandShape; + + /// Stored tile has a structure designed for efficient MIO storing and loading + typedef Shape<(OperandShape::kW >> 2), // one 3D tile per four elements in the K dimension + (OperandShape::kH >> 4), // four rows of SMEM per 64xK tile + 16, // Sixteen banks of MIO + kAccessSize> + VectorizedShape; + + /// Offset between stores + typedef Shape Delta; + + /// Shape of tile + typedef Shape<1, 8, 4> WarpStoreCoverage; + + /// Number of iterations + typedef Shape< + // # of LDG.128s along the strided (outer) dimension + OperandShape::kH / (WarpStoreCoverage::kH * kWarpCount), + // # of LDG.128s along the contiguous (K) dimension + OperandShape::kW / (WarpStoreCoverage::kW * kLdgAccessSize), + // # STSs per LDG + (kLdgAccessSize / kAccessSize), + 1> + Iterations; + + /// Swizzled store iterator + struct ThreadOffset { + __device__ Coord<4> operator()(int ptr_idx) const { + int warp_id = (threadIdx.x >> 5); + int lane_id = (threadIdx.x & 0x1f); + + // Assumes a contiguous/blocked warp loading strategy + int load_tile_idx = warp_id * Iterations::kD; + + // if Iterations::kD < 4, then we need to permute pointers + if (Iterations::kD == 2) { + ptr_idx ^= (warp_id & 1); + } + + // Compute swizzled destination address + int lane_w = (lane_id % WarpStoreCoverage::kW); + int store_k_idx = lane_w * 2; + + int dest_tile_idx = load_tile_idx / 4; + + int dest_row = ((load_tile_idx >> 1) & 1); + int dest_bank = (lane_id & 0x0f) ^ ((lane_id >> 4) & 1) ^ (ptr_idx << 1); + + Coord<4> offset = make_Coord(store_k_idx, dest_tile_idx * 2 + dest_row, dest_bank, 0); + + return offset; + } + }; + + /// Source tile traits + typedef TileTraits Traits; + + /// Scalar type + typedef half Scalar; + + /// Index type + typedef int Index; + + /// Index type + typedef int LongIndex; + + // + // Derived types + // + + /// Tensor reference + typedef TensorRef TensorRef; + + /// Predicate vector + typedef PredicateVector::kCount> PredicateVector; + + /// Fragment definition + typedef Fragment::kCount * kAccessSize> Fragment; + + /// Elements loaded by one instruction + typedef typename Vectorize::Type AccessType; + + /// The fragment iterator. + typedef FragmentIterator FragmentIterator; + + /// The fragment const iterator. 
+ typedef FragmentConstIterator FragmentConstIterator; + + /// Strides into expected SMEM tile + typedef typename ShapeStrides::Shape Strides; + + /// Memory space access + static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric; + + /// Store iterators require two pointers + static int const kPointerCount = 2; + + /// Parameters object + struct Params { + // + // Data members + // + + /// Pointer to element type + Scalar *pointer; + + /// Strides + Coord<4> stride; + + // + // Methods + // + + /// Constructs a parameters object + CUTLASS_HOST_DEVICE + Params(Scalar *_pointer = 0) + : pointer(_pointer), + stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {} + + /// Constructs a params object from a TensorRef + CUTLASS_HOST_DEVICE + Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { } + }; + + // + // Data members + // + + /// Pointer to element type + Scalar *pointer[kPointerCount]; + + /// Strides + Coord<4> stride; + + // + // Methods + // + + /// Constructs a store iterator + __device__ Volta884ThreadblockMultiplicandStoreIterator( + Params const &_params, + Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0), + ThreadOffset offset_func = ThreadOffset()) + : stride(_params.stride) { + for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) { + Coord<4> offset = offset_func(ptr_idx); + pointer[ptr_idx] = _params.pointer + (_block_offset + offset).template dot(stride); + } + } + + /// Stores a fragment + CUTLASS_DEVICE + void store(Fragment const &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const { + FragmentConstIterator frag_iterator(fragment); + + // Iterate over each store + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { // strided LDG.128s + + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { // contiguous LDG.128s + + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { // 2x STS operations per LDG + + int load_tile_idx = d; + int k_idx = w + h * 8; + int smem_row = (d >> 1); + + // Two store pointers + int ptr_idx = ((load_tile_idx & 1) ^ ((load_tile_idx >> 1) & 1)); + + Coord<4> sts_offset = make_Coord(k_idx, smem_row, 0, 0); + + if (true || (d == 0 && (threadIdx.x / 32) == 1)) { + Store::store( + reinterpret_cast(frag_iterator.at(d, h, w, 0)), + pointer[ptr_idx], + stride.template dot(sts_offset + offset)); + } + } + } + } + } + + /// Increments store iterator to next tile + __device__ Volta884ThreadblockMultiplicandStoreIterator &increment(int count = 1) { + for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) { + pointer[ptr_idx] += + make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot(stride); + } + return *this; + } + + /// Increments to next tile + __device__ Volta884ThreadblockMultiplicandStoreIterator &operator++() { return increment(); } + + __device__ Volta884ThreadblockMultiplicandStoreIterator &operator+=(int count) { + return increment(count); + } + + /// Increments store iterator to previous tile + __device__ Volta884ThreadblockMultiplicandStoreIterator &decrement(int count = 1) { + for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) { + pointer[ptr_idx] -= + make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot(stride); + } + return *this; + } + + /// Increments to subsequent tile + __device__ Volta884ThreadblockMultiplicandStoreIterator &operator--() { return decrement(); } + + /// Decrements to previous tile + __device__ 
Volta884ThreadblockMultiplicandStoreIterator &operator-=(int count) { + return decrement(count); + } + + /// Stores a fragment and increments in the K dimension + __device__ Volta884ThreadblockMultiplicandStoreIterator &store_post_increment( + Fragment const &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) { + store(fragment, offset); + return increment(); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterator to load a fragment for each warp-level tile specialized for B.column_major +template < + /// Specifies threadblock tile shape + typename Tile_, + /// Specifies the warp tile shape + typename WarpTile_, + /// Specifies the number of participating warps + int WarpCount, + /// Specifies the delta between warp accesses along the outer dimension + typename WarpDelta_> +struct Volta884WarpMultiplicandLoadIterator { + // + // Constant and type definitions + // + + /// Identifies multiplicand of GEMM (A or B) + static GemmOperand::Kind const kOperand = GemmOperand::kB; + + /// Specifies layout of data in source memory + static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor; + + /// Shape of thread-block multiplicand + typedef Tile_ Tile; + + /// Shape of warp-tile matrix operation + typedef WarpTile_ WarpTile; + + /// Hard-coded tile shape + typedef Shape<4, 32, 32> InterleavedTileShape; + + /// Number of participating warps + static int const kWarpCount = WarpCount; + + /// Delta between warp accumulator tiles along the outer dimension + typedef WarpDelta_ WarpDelta; + + /// This implementation is specialized for 128b loads + static int const kAccessSize = 8; + + /// Swizzled store iterator + struct ThreadOffset { + /// Compute thread offset coordinate for each pointer + __device__ Coord<4> operator()(int ptr_idx) const { + int warp_id = (threadIdx.x >> 5) / WarpDelta::kW; + int lane_id = (threadIdx.x & 0x1f); + + int lane_in_quad = (lane_id & 0x3); + int quad_id = ((lane_id >> 2) & 0x7); + + int oct_col_id = (quad_id >> 1); + + int oct_col = (oct_col_id & 1); + int oct_left_right = ((oct_col_id >> 1) & 1) ^ (oct_col_id & 1) ^ ptr_idx; + + Coord<4> offset = + make_Coord(0, warp_id * 2 + oct_col, (lane_in_quad * 2) + oct_left_right, 0); + + return offset; + } + }; + + /// Projects the threadblock tile + typedef typename GemmMultiplicandTraits::Shape OperandShape; + + /// Loaded tile has a structure designed for efficient MIO storing and loading + typedef Shape<(OperandShape::kW >> 2), // one 3D tile per four elements in the K dimension + (OperandShape::kH >> 4), // four rows of SMEM per 64xK tile + 8, // Eight banks of MIO + kAccessSize> + VectorizedShape; + + /// Offset between acceses + typedef Shape<1, 2 * WarpDelta::kH, 1, 1> Delta; + + /// Number of iterations + typedef Shape<1, WarpTile::kH / InterleavedTileShape::kH, 1, 1> Iterations; + + /// Source tile traits + typedef TileTraits Traits; + + /// Scalar type + typedef half Scalar; + + /// Index type + typedef int Index; + + /// Index type + typedef int LongIndex; + + // + // Derived types + // + + /// Tensor reference + typedef TensorRef TensorRef; + + /// Predicate vector + typedef PredicateVector::kCount> PredicateVector; + + /// Fragment definition + typedef Fragment::kCount * kAccessSize> Fragment; + + /// Elements loaded by one instruction + typedef typename Vectorize::Type AccessType; + + /// The fragment iterator. + typedef FragmentIterator FragmentIterator; + + /// The fragment const iterator. 
+ typedef FragmentConstIterator FragmentConstIterator; + + /// Strides into expected SMEM tile + typedef typename ShapeStrides::Shape Strides; + + /// Memory space access + static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric; + + /// Pointer count is always two + static int const kPointerCount = 2; + + /// Parameters object + struct Params { + // + // Data members + // + + /// Base pointer to SMEM allocation + Scalar const *pointer; + + /// SMEM strides + Coord<4> stride; + + // + // Methods + // + + /// Constructs a parameters object + CUTLASS_HOST_DEVICE + Params(Scalar const *_pointer = 0) + : pointer(_pointer), + stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {} + + /// Constructs a params object from a TensorRef + CUTLASS_HOST_DEVICE + Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { } + }; + + // + // Data members + // + + /// Shared memory load pointer + Scalar const *pointer[kPointerCount]; + + /// SMEM strides + Coord<4> stride; + + /// Index in D dimension - needed to permute loads + int k_index; + + // + // Methods + // + + __device__ int column(uint16_t item) const { return ((item >> 8) & 0xff); } + + __device__ int column(half const *ptr) const { + return column(reinterpret_cast(*ptr)); + } + + /// Constructs a load iterator + __device__ Volta884WarpMultiplicandLoadIterator( + Params const &_params, + Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0), + ThreadOffset offset_func = ThreadOffset()) + : stride(_params.stride), k_index(0) { + CUTLASS_PRAGMA_UNROLL + for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) { + Coord<4> offset = offset_func(ptr_idx); + + pointer[ptr_idx] = _params.pointer + (_block_offset + offset).template dot(stride); + } + } + + /// Stores a fragment + __device__ void load(Fragment &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const { + FragmentIterator frag_iterator(fragment); + + // Iterate over each load + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + Coord<4> lds_offset = + make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC); + + int ptr_idx = ((offset[0] >> 2) & 1); + Scalar const *_pointer = pointer[ptr_idx]; + + Load::load( + reinterpret_cast(frag_iterator.at(d, h, w, 0)), + _pointer, + stride.template dot(lds_offset + offset)); + + if (offset[0] & 2) { + // peculiar swap for crosswise loads + int lds128_idx = w + Iterations::kW * (h + Iterations::kH * d); + uint64_t *left = reinterpret_cast(&fragment) + lds128_idx * 2; + uint64_t *right = reinterpret_cast(&fragment) + lds128_idx * 2 + 1; + uint64_t tmp = *left; + *left = *right; + *right = tmp; + } + } + } + } + } + + /// Loads a fragment and increments to next K-index + __device__ void load_post_increment(Fragment &fragment, + Coord<4> const &offset = make_Coord(0, 0, 0, 0)) { + load(fragment, offset + make_Coord(k_index, 0, 0, 0)); + ++k_index; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/volta884_shared_tile_epilogue.h b/cutlass/gemm/volta884_shared_tile_epilogue.h new file mode 100644 index 0000000000..19802adc0e --- /dev/null +++ 
b/cutlass/gemm/volta884_shared_tile_epilogue.h @@ -0,0 +1,629 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction + + DO NOT INCLUDE THIS FILE DIRECTLY. + + This file is intended to be included by and defines + partial specializations for templates specified therein. 
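+
+    At a glance (inferred from the definitions that follow): this header provides
+    Volta884EpilogueSharedStoreIterator and Volta884EpilogueSharedLoadIterator,
+    each partially specialized for float and for half accumulators. The store
+    iterators write mma.sync accumulator fragments into a shared-memory staging
+    tile whose rows are padded by a small skew (kSkew = 2 elements) to keep the
+    stores conflict free; the load iterators then strip-mine that tile across
+    the threadblock.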
+*/ + +#pragma once + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for FP32 accumulator layouts +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Epilogue shared memory store iterator specialized for Volta's mma.sync.FP32 layout +template < + /// Shape of warp-level GEMM + typename WarpGemmTile_, + /// Tiling of warp accumulator elements + typename WarpDelta_, + /// Data type of accumulator elements + typename Scalar_> +struct Volta884EpilogueSharedStoreIterator { + /// Warp-scoped GEMM tile size + typedef WarpGemmTile_ WarpGemmTile; + + /// Tiling of warp elements across threadblock + typedef WarpDelta_ WarpDelta; + + /// Scalar data type + typedef Scalar_ Scalar; + + /// Accumulator data type (and layout) + typedef float Accumulator; + + /// Index type + typedef int Index; + + /// Index type + typedef int LongIndex; + + // Host-side params + struct Params {}; + + /// Access size + static int const kAccessSize = 1; + + /// Skew elements to ensure conflict free stores + static int const kSkew = 2; + + /// Shape of one interleaved mma.sync tile + typedef Shape<4, 32, 32> MmaTileShape; + + /// Four element fragment + typedef Shape Iterations; + + /// Delta separated by two elements + typedef Shape Delta; + + // + // Dependent types + // + + /// Predicate vector + typedef PredicateVector::kCount> PredicateVector; + + /// Memory space access + static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric; + + /// Fragment definition + typedef Fragment::kCount * kAccessSize> Fragment; + + /// Elements loaded by one instruction + typedef typename Vectorize::Type AccessType; + + /// The fragment iterator. + typedef FragmentIterator FragmentIterator; + + /// The fragment const iterator. + typedef FragmentConstIterator FragmentConstIterator; + + /// Tensor reference type + typedef TensorRef TensorRef; + + // + // Data members + // + + /// Base pointer to SMEM allocation + Scalar *pointer; + + /// Stride in shared memory + Coord<4> strides; + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + Volta884EpilogueSharedStoreIterator(Params const &_params, TensorRef const &ref) + : pointer(ref.data()), strides(make_Coord(1, WarpDelta::kW * WarpGemmTile::kW + kSkew, 1, 1)) { + + int warp_id = (threadIdx.x / kWarpSize); + int lane_id = (threadIdx.x % kWarpSize); + + Coord<4> warp_idx = make_Coord(0, warp_id / WarpDelta::kW, warp_id % WarpDelta::kW, 0); + + Coord<4> warp_base = warp_idx * make_Coord(0, 4, MmaTileShape::kW, 0); + + Coord<4> thread_idx = make_Coord(0, + (((lane_id >> 1) & 4) | (lane_id & 2)) >> 1, + (lane_id & 1) | ((lane_id >> 1) & 8) | ((lane_id << 2) & 16), + 0); + + int offset = strides.template dot(warp_base + thread_idx); + + pointer += offset; + } + + /// Store to the epilogue tile. 
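+  ///
+  /// The loops below visit Iterations::kD x Iterations::kH x Iterations::kW
+  /// accesses; each (d, h, w) coordinate is scaled by Delta and dotted with
+  /// `strides` to form a scalar offset. As a purely illustrative example (the
+  /// concrete values are assumptions, not taken from this patch): with
+  /// WarpGemmTile::kW of 64, WarpDelta::kW of 1 and kSkew of 2, the row pitch
+  /// baked into `strides` is 1 * 64 + 2 = 66 scalars, so a unit step in the h
+  /// coordinate advances the offset by a multiple of 66 elements on top of the
+  /// swizzled thread offset computed in the constructor.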
+ CUTLASS_DEVICE + void store(Fragment const &fragment) const { + FragmentConstIterator frag_iterator(fragment); + + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + Coord<4> coord = + make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC); + + int _offset = coord.template dot(strides); + + Store::store( + reinterpret_cast(frag_iterator.at(d, h, w, 0)), pointer, + _offset); + } + } + } + } + + /// Stores to the epilogue tile - this iterator does not advance, so increment is null. + CUTLASS_DEVICE + void store_post_increment(Fragment const &fragment) { store(fragment); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Epilogue shared memory load iterator specialized for Volta's mma.sync.FP32 layout +template < + /// Shape of warp-level GEMM + typename WarpGemmTile_, + /// Tiling of warp accumulator elements + typename WarpDelta_, + /// Data type of accumulator elements + typename Scalar_, + /// Number of elements loaded per access + int AccessSize_> +struct Volta884EpilogueSharedLoadIterator { + /// Warp-scoped GEMM tile size + typedef WarpGemmTile_ WarpGemmTile; + + /// Tiling of warp elements across threadblock + typedef WarpDelta_ WarpDelta; + + /// Scalar data type + typedef Scalar_ Scalar; + + /// Accumulator data type (and layout) + typedef float Accumulator; + + /// Index type + typedef int Index; + + /// Index type + typedef int LongIndex; + + /// Number of elements accessed at once + static int const kAccessSize = AccessSize_; + + /// Shape of one interleaved mma.sync tile + typedef Shape<4, 32, 32> MmaTileShape; + + /// Total participating warps + static int const kWarpCount = ShapeCount::kCount; + + /// Total participating threads + static int const kThreadCount = kWarpCount * kWarpSize; + + /// Skew elements + static int const kSkew = 2; + + /// This tile is to be strip-mined with a swizzling function + typedef Shape<2 * WarpDelta::kH, 2, WarpGemmTile::kW * WarpDelta::kW, 1> Tile; + + /// Number of iterations + typedef Shape<2 * WarpDelta::kH, + (kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH), + (kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount), + 1> + Iterations; + + /// Delta between accesses + typedef Shape<2, 1, kThreadCount, 1> Delta; + + // + // Derived quantities + // + + /// Predicate vector + typedef PredicateVector::kCount> PredicateVector; + + /// Fragment of elements to load + typedef Fragment::kCount * kAccessSize> Fragment; + + /// Elements loaded by one instruction + typedef typename Vectorize::Type AccessType; + + /// The fragment iterator. + typedef FragmentIterator FragmentIterator; + + /// The fragment const iterator. 
+ typedef FragmentConstIterator FragmentConstIterator; + + static_assert(!(kSkew % kAccessSize), "Access size must have compatible alignment with skew"); + + /// Memory space access + static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric; + + /// Tensor reference type + typedef TensorRef TensorRef; + + /// Host-side params + struct Params {}; + + // + // Data members + // + + /// Pointer + Scalar const *pointer; + + /// Strides + Coord<4> strides; + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + Volta884EpilogueSharedLoadIterator(Params const &_params, TensorRef const &ref) + : pointer(ref.data()), + strides(make_Coord((WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize, + (WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize, + kAccessSize, + 1)) { + // strip-mine this tile + int tid = threadIdx.x; + + int residual_w = (tid / (Tile::kW)); + int offset_w = (tid % (Tile::kW)); + + int offset_h = (residual_w % Tile::kH); + int offset_d = (residual_w / Tile::kH); + + Coord<4> offset = make_Coord(offset_d * Delta::kW, offset_h * Delta::kH, offset_w, 0); + + pointer += strides.template dot(offset); + } + + /// Loads a fragment from the epilogue tile. + CUTLASS_DEVICE + void load(Fragment &fragment) const { + FragmentIterator frag_iterator(fragment); + + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + Coord<4> coord = + make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kW); + + int _offset = coord.template dot(strides); + + Load::load( + reinterpret_cast(frag_iterator.at(d, h, w, 0)), pointer, _offset); + } + } + } + } + + /// Loads a fragment - iterator does not actually advance, so increment operation is null. 
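+  /// (Since this iterator never advances, the post-increment variant below
+  /// simply forwards to load() without updating any pointer.)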
+ CUTLASS_DEVICE + void load_post_increment(Fragment &fragment) { load(fragment); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for FP16 accumulator layouts +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Epilogue shared memory store iterator specialized for Volta's mma.sync.FP16 layout +template < + /// Shape of warp-level GEMM + typename WarpGemmTile_, + /// Tiling of warp accumulator elements + typename WarpDelta_, + /// Data type of accumulator elements + typename Scalar_> +struct Volta884EpilogueSharedStoreIterator { + /// Warp-scoped GEMM tile size + typedef WarpGemmTile_ WarpGemmTile; + + /// Tiling of warp elements across threadblock + typedef WarpDelta_ WarpDelta; + + /// Scalar data type + typedef Scalar_ Scalar; + + /// Accumulator data type (and layout) + typedef half Accumulator; + + /// Index type + typedef int Index; + + /// Index type + typedef int LongIndex; + + /// Host-side params + struct Params {}; + + /// Dimensions of contiguous 32x32x4 Volta's mma.sync tile + typedef Shape<4, 32, 32> MmaTileShape; + + /// Accumulator fragment + typedef Shape Iterations; + + /// Delta separated by two elements + typedef Shape Delta; + + /// Access size + static int const kAccessSize = 1; + + /// Skew elements to ensure conflict free stores + static int const kSkew = 2; + + /// Tensor reference type + typedef TensorRef TensorRef; + + // + // Dependent types + // + + /// Predicate vector + typedef PredicateVector::kCount> PredicateVector; + + /// Memory space access + static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric; + + /// Fragment definition + typedef Fragment::kCount * kAccessSize> Fragment; + + /// Elements loaded by one instruction + typedef typename Vectorize::Type AccessType; + + /// The fragment iterator. + typedef FragmentIterator FragmentIterator; + + /// The fragment const iterator. + typedef FragmentConstIterator FragmentConstIterator; + + // + // Data members + // + + /// Base pointer to SMEM allocation + Scalar *pointer; + + /// Stride in shared memory + Coord<4> strides; + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + Volta884EpilogueSharedStoreIterator(Params const &_params, TensorRef const &ref) + : pointer(ref.data()), strides(make_Coord(1, WarpGemmTile::kW * WarpDelta::kW + kSkew, 1, 1)) { + + int warp_id = (threadIdx.x / kWarpSize); + int lane_id = (threadIdx.x % kWarpSize); + + int quad_id = (lane_id >> 2); + int quadpair_id = (quad_id & 0x3); + + int quadpair_row = (quadpair_id & 1); + int quadpair_col = (quadpair_id >> 1); + int quad_hilo = (quad_id >> 2) & 1; + + int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 3); + int thread_col_offset = quadpair_col; + + Coord<4> thread_idx = make_Coord(0, thread_col_offset, thread_row_offset, 0); + + Coord<4> warp_base = make_Coord(0, warp_id / WarpDelta::kW, warp_id % WarpDelta::kW, 0) * + make_Coord(0, 2, kWarpSize, 0); + Coord<4> offset = warp_base + thread_idx; + + pointer += strides.template dot(offset); + } + + /// Store to the epilogue tile. 
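+  ///
+  /// In this fp16 specialization each access stores a single half element
+  /// (kAccessSize == 1). The constructor above derives a per-thread row and
+  /// column offset from the lane's quad / quad-pair decomposition
+  /// (quadpair_row, quadpair_col, quad_hilo), reflecting where Volta's
+  /// mma.sync places half-precision accumulators; the loop below then scales
+  /// each (d, h, w) coordinate by Delta and dots it with `strides`, as in the
+  /// fp32 specialization.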
+ CUTLASS_DEVICE + void store(Fragment const &fragment) const { + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + Coord<4> coord = + make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC); + + int _offset = coord.template dot(strides); + + Store::store( + reinterpret_cast(fragment[w + Iterations::kW * d]), + pointer, + _offset); + } + } + } + } + + /// Stores to the epilogue tile - this iterator does not advance, so increment is null. + CUTLASS_DEVICE + void store_post_increment(Fragment const &fragment) { store(fragment); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Epilogue shared memory load iterator specialized for Volta's mma.sync.FP16 layout +template < + /// Shape of warp-level GEMM + typename WarpGemmTile_, + /// Tiling of warp accumulator elements + typename WarpDelta_, + /// Data type of accumulator elements + typename Scalar_, + /// Number of elements loaded per access + int AccessSize_> +struct Volta884EpilogueSharedLoadIterator { + /// Warp-scoped GEMM tile size + typedef WarpGemmTile_ WarpGemmTile; + + /// Tiling of warp elements across threadblock + typedef WarpDelta_ WarpDelta; + + /// Scalar data type + typedef Scalar_ Scalar; + + /// Accumulator data type (and layout) + typedef half Accumulator; + + /// Number of elements accessed at once + static int const kAccessSize = AccessSize_; + + /// Shape of one interleaved mma.sync tile + typedef Shape<4, 32, 32> MmaTileShape; + + /// This tile is to be strip-mined with a swizzling function + typedef Shape<1, 2 * WarpDelta::kH, WarpGemmTile::kW * WarpDelta::kW / kAccessSize, kAccessSize> + Tile; + + /// Index type + typedef int Index; + + /// Index type + typedef int LongIndex; + + /// Total participating warps + static int const kWarpCount = ShapeCount::kCount; + + /// Number of participating threads + static int const kThreadCount = kWarpSize * kWarpCount; + + /// Number of iterations + typedef Shape<1, + (kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH), + (kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount), + 1> + Iterations; + + /// Delta between thread-level accesses + typedef typename platform::conditional= Tile::kW, + Shape<1, (kThreadCount / Tile::kW), 1, 1>, + Shape<1, 1, kThreadCount, 1> >::type Delta; + + // + // Derived quantities + // + + /// Predicate vector + typedef PredicateVector::kCount> PredicateVector; + + /// Fragment of elements to load + typedef Fragment::kCount * kAccessSize> Fragment; + + /// Elements loaded by one instruction + typedef typename Vectorize::Type AccessType; + + /// The fragment iterator. + typedef FragmentIterator FragmentIterator; + + /// The fragment const iterator. 
+ typedef FragmentConstIterator FragmentConstIterator; + + /// Skew elements + static int const kSkew = 2; + + static_assert(!(kSkew % kAccessSize), "Access size must have compatible alignment with skew"); + + /// Memory space access + static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric; + + /// Tensor reference type + typedef TensorRef TensorRef; + + /// Host-side params + struct Params {}; + + // + // Data members + // + + /// Pointer + Scalar const *pointer; + + /// Strides + Coord<4> strides; + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + Volta884EpilogueSharedLoadIterator(Params const &_params, TensorRef const &ref) + : pointer(ref.data()), + strides(make_Coord(2 * (WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize, + (WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize, + kAccessSize, + 1)) { + // strip-mine this tile + Coord<4> offset = make_Coord(0, threadIdx.x / Tile::kW, threadIdx.x % Tile::kW, 0); + + pointer += strides.template dot(offset); + } + + /// Loads a fragment from the epilogue tile. + CUTLASS_DEVICE + void load(Fragment &fragment) const { + FragmentIterator frag_iterator(fragment); + + CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < Iterations::kW; ++w) { + Coord<4> coord = + make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kW); + + int _offset = coord.template dot(strides); + + Load::load( + reinterpret_cast(fragment[w + Iterations::kW * h]), pointer, _offset); + } + } + } + } + + /// Loads a fragment - iterator does not actually advance, so increment operation is null. + CUTLASS_DEVICE + void load_post_increment(Fragment &fragment) { load(fragment); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/wmma_gemm_epilogue_traits.h b/cutlass/gemm/wmma_gemm_epilogue_traits.h index 0eccab02b8..38a4ed0a38 100644 --- a/cutlass/gemm/wmma_gemm_epilogue_traits.h +++ b/cutlass/gemm/wmma_gemm_epilogue_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -106,7 +106,7 @@ struct WmmaGemmEpilogueTraitsHelper { // The number of scalars per LDS. GemmConfig_::kScalarsPerLdsD, // this parameter helps with swizzling when accum is fp32 and output is fp16 - sizeof(Accumulator_) / sizeof(typename GemmConfig_::ScalarD) + int(sizeof(Accumulator_)) / int(sizeof(typename GemmConfig_::ScalarD)) > SharedLoadTileTraits; diff --git a/cutlass/gemm/wmma_gemm_global_tile.h b/cutlass/gemm/wmma_gemm_global_tile.h index 2c197a8b47..1b235337bc 100644 --- a/cutlass/gemm/wmma_gemm_global_tile.h +++ b/cutlass/gemm/wmma_gemm_global_tile.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -103,7 +103,7 @@ struct WmmaGemmGlobalIteratorCd : public GemmGlobalIteratorCdpointer = pointer; + BaseParams::pointer = pointer; // Stride between GEMMs this->stride_d = batch_stride; // Setup the base stride. One "group of threads" per column. diff --git a/cutlass/gemm/wmma_gemm_multiply_add.h b/cutlass/gemm/wmma_gemm_multiply_add.h index 328e43adbd..796c0cfd41 100644 --- a/cutlass/gemm/wmma_gemm_multiply_add.h +++ b/cutlass/gemm/wmma_gemm_multiply_add.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -85,7 +85,10 @@ struct WmmaGemmMultiplyAdd { FragmentB const& b, Accumulators const& c, Accumulators& d) { + + CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Iterations::kH; ++j) { + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < Iterations::kW; ++i) { // The input elements. ElementA const& elt_a = a[i]; @@ -164,7 +167,10 @@ struct WmmaGemmMultiplyAdd CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment) { typename InputIterator::FragmentIterator frag_iterator(fragment); + CUTLASS_PRAGMA_UNROLL for (int d = 0; d < InputIterator::Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL for (int h = 0; h < InputIterator::Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL for (int w = 0; w < InputIterator::Iterations::kW; ++w) { + CUTLASS_PRAGMA_UNROLL for (int c = 0; c < InputIterator::Iterations::kC; ++c) { if (iterator.valid(d, h, w, c)) { iterator.load_element(reinterpret_cast( @@ -69,9 +73,13 @@ CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragme template CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment) { typename OutputIterator::FragmentIterator frag_iterator(fragment); + CUTLASS_PRAGMA_UNROLL for (int d = 0; d < OutputIterator::Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL for (int h = 0; h < OutputIterator::Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL for (int w = 0; w < OutputIterator::Iterations::kW; ++w) { + CUTLASS_PRAGMA_UNROLL for (int c = 0; c < OutputIterator::Iterations::kC; ++c) { if (iterator.valid(d, h, w, c)) { iterator.store_element(reinterpret_cast( diff --git a/cutlass/kernel_launch.h b/cutlass/kernel_launch.h index ee37b2fda9..b48fd7d0b0 100644 --- a/cutlass/kernel_launch.h +++ b/cutlass/kernel_launch.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/layout/thread/tensor_foreach.h b/cutlass/layout/thread/tensor_foreach.h new file mode 100644 index 0000000000..d1f6233458 --- /dev/null +++ b/cutlass/layout/thread/tensor_foreach.h @@ -0,0 +1,90 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +//#include +#include "cutlass/cutlass.h" +//#include "tools/util/reference/device/kernel/tensor_foreach.h" + +namespace cutlass { +namespace layout { +namespace thread { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines several helpers +namespace detail { + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + /// Index of the active rank + static int const kActiveRank = Rank - RankRemaining - 1; + + /// Constructor for general rank + CUTLASS_DEVICE TensorForEachHelper(Func &func, Coord const &size, Coord &coord) { + for (int i = 0; i < size.at(kActiveRank); ++i) { + coord[kActiveRank] = i; + TensorForEachHelper(func, size, coord); + } + } +}; + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + /// Index of the active rank + static int const kActiveRank = Rank - 1; + + /// Constructor for fastest chaning rank + CUTLASS_DEVICE TensorForEachHelper(Func &func, Coord const &size, Coord &coord) { + for (int i = 0; i < size.at(kActiveRank); ++i) { + coord[kActiveRank] = i; + func(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over the index space of a tensor +template +struct TensorForEach { + /// Constructor performs the operation. 
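+  ///
+  /// The constructor builds Func from Params and then walks the whole index
+  /// space via detail::TensorForEachHelper: each recursion level fixes one
+  /// rank by looping coord[kActiveRank] over its extent, and the terminal
+  /// specialization for the fastest-changing rank invokes func(coord) once per
+  /// coordinate. CopyFunc in cutlass/layout/thread/transform.h is a
+  /// representative functor.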
+ CUTLASS_DEVICE TensorForEach(Coord size, Params params = Params()) { + Func func(params); + Coord coord; + + detail::TensorForEachHelper(func, size, coord); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace layout +} // namespace cutlass diff --git a/cutlass/layout/thread/transform.h b/cutlass/layout/thread/transform.h new file mode 100644 index 0000000000..3abb22e141 --- /dev/null +++ b/cutlass/layout/thread/transform.h @@ -0,0 +1,300 @@ +/*************************************************************************************************** + * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Basic copy routines for tensor views +*/ + +#pragma once + +#include "cutlass/fragment.h" +#include "cutlass/layout/thread/tensor_foreach.h" +#include "cutlass/tensor_view.h" + +namespace cutlass { +namespace layout { +namespace thread { + +/// Define a functor that performs a copy operation on a tensor. 
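+/// CopyFunc is the element-wise functor handed to TensorForEach: it captures
+/// the destination and source TensorViews and, for each visited coordinate,
+/// assigns dst.at(coord) = src.at(coord). The Copy wrappers that follow either
+/// dispatch through TensorForEach or, when both views report isPacked(), fall
+/// back to a straight linear loop over src.capacity().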
+template +struct CopyFunc { + /// Coordinate of index space + typedef typename View_dst::TensorCoord TensorCoord; + + View_dst dst; + + View_src src; + + /// Constructor + CUTLASS_DEVICE + CopyFunc(View_dst dst, View_src src) : dst(dst), src(src) {} + + /// copy function + CUTLASS_DEVICE + void operator()(TensorCoord const& coord) { + dst.at(coord) = src.at(coord); // uses tensor view's map() + } +}; + +template +struct Copy { + CUTLASS_DEVICE void copy(cutlass::TensorView dst, + cutlass::TensorView src) { + // Define a functor called by TensorForEach<> + typedef CopyFunc, + cutlass::TensorView > + CopyFunc; + + // Instantiate on device with TensorViews + CopyFunc copy_func(dst, src); + + // Invoke device-side for-each computation on the tensor + cutlass::layout::thread::TensorForEach(src.size(), copy_func); + } +}; + +template +struct Copy { + CUTLASS_DEVICE void copy(cutlass::TensorView dst, + cutlass::TensorView src) { + bool isPacked = dst.isPacked() && src.isPacked(); + if (isPacked) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < src.capacity(); ++i) { + dst.at(i) = src.at(i); + } + } else { + typedef CopyFunc, + cutlass::TensorView > + CopyFunc; + + // Instantiate on device with TensorViews + CopyFunc copy_func(dst, src); + + // Invoke device-side for-each computation on the tensor + cutlass::layout::thread::TensorForEach(src.size(), copy_func); + } + } +}; + +/// hgemm swizzle +/// Transform a fragment. +template <> +struct Copy { + CUTLASS_DEVICE void copy( + cutlass::TensorView dst, + cutlass::TensorView src) { + // Expose src/dst as int arrays. + int const* src_int = reinterpret_cast(src.const_ref().data()); + int* dst_int = reinterpret_cast(dst.ref().data()); + + int kD = src.size(0); + int kDhw = src.size(0) * src.size(1); + + // Transpose the data. + // CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < kD; ++d) { + // The indices to read two consecutive "rows". + int const i0 = 2 * d + 0; + int const i1 = 2 * d + 1; + + int a0 = src_int[i0]; + int a1 = src_int[i1]; + + int b0, b1; + asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(a0), "r"(a1)); + asm volatile("prmt.b32 %0, %1, %2, 0x7632;" : "=r"(b1) : "r"(a0), "r"(a1)); + + // The indices to store with "strides". + int const j0 = 0 * (kDhw / 2) + d; + int const j1 = 1 * (kDhw / 2) + d; + + dst_int[j0] = b0; + dst_int[j1] = b1; + } + } +}; + +/// igemm swizzle +/// Transform a fragment. +template <> +struct Copy { + CUTLASS_DEVICE void copy( + cutlass::TensorView dst, + cutlass::TensorView src) { + // Expose src/dst as int arrays. + int const* src_int = reinterpret_cast(src.const_ref().data()); + int* dst_int = reinterpret_cast(dst.ref().data()); + + int kD = src.size(0); + int kH = src.size(1); + int kWc = src.stride(0); + int kHwc = kH * kWc; + + // Transpose the data. 
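+    // Each step below loads four 32-bit words (a0..a3) holding a 4x4 block of
+    // int8 values and regroups them with prmt.b32 byte permutes: the three
+    // permutes producing b_n collect byte n of each of a0..a3, so the net
+    // effect is a 4x4 byte transpose performed in registers before the words
+    // are written back to the same locations.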
+ CUTLASS_PRAGMA_UNROLL + for (int d = 0; d < kD; ++d) { + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < kH / 4; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < kWc / 4; ++w) { + int const i0 = d * (kHwc / 4) + (4 * h + 0) * (kWc / 4) + w; + int const i1 = d * (kHwc / 4) + (4 * h + 1) * (kWc / 4) + w; + int const i2 = d * (kHwc / 4) + (4 * h + 2) * (kWc / 4) + w; + int const i3 = d * (kHwc / 4) + (4 * h + 3) * (kWc / 4) + w; + + int a0 = src_int[i0]; + int a1 = src_int[i1]; + int a2 = src_int[i2]; + int a3 = src_int[i3]; + + int b0, b1, b2, b3, c0; + asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(b0) : "r"(a0), "r"(a1)); + asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(c0) : "r"(a2), "r"(a3)); + asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(b0), "r"(c0)); + + asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(b1) : "r"(a0), "r"(a1)); + asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(c0) : "r"(a2), "r"(a3)); + asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b1) : "r"(b1), "r"(c0)); + + asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(b2) : "r"(a0), "r"(a1)); + asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(c0) : "r"(a2), "r"(a3)); + asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b2) : "r"(b2), "r"(c0)); + + asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(b3) : "r"(a0), "r"(a1)); + asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(c0) : "r"(a2), "r"(a3)); + asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b3) : "r"(b3), "r"(c0)); + + dst_int[i0] = b0; + dst_int[i1] = b1; + dst_int[i2] = b2; + dst_int[i3] = b3; + } + } + } + } +}; + +template +struct Transform { + + typedef Fragment::kCount> DstFragment; + typedef Fragment::kCount> SrcFragment; + + /// The input fragment. + typedef SrcFragment InputFragment; + /// The output fragment. + typedef DstFragment OutputFragment; + + CUTLASS_DEVICE void transform(SrcFragment& src, DstFragment& dst) { + cutlass::TensorView dstView( + &dst[0], // pointer to base of matrix in device memory + cutlass::make_Coord(Shape::kD, 1), // stride vector + cutlass::make_Coord(Shape::kD, + Shape::kH * Shape::kW) // bounds of matrix + ); + cutlass::TensorView srcView( + &src[0], // pointer to base of matrix in device memory + cutlass::make_Coord(Shape::kD, 1), // stride vector + cutlass::make_Coord(Shape::kD, + Shape::kH * Shape::kW) // bounds of matrix + ); + cutlass::layout::thread::Copy Transformer; + Transformer.copy(dstView, srcView); + } +}; + +template +struct Transform { + typedef Fragment::kCount> DstFragment; + typedef Fragment::kCount> SrcFragment; + + /// The input fragment. + typedef SrcFragment InputFragment; + /// The output fragment. + typedef DstFragment OutputFragment; + + CUTLASS_DEVICE void transform(SrcFragment& src, DstFragment& dst) { + cutlass::TensorView dstView( + &dst[0], // pointer to base of matrix in device memory + cutlass::make_Coord(Shape::kD, 1), // stride vector + cutlass::make_Coord(Shape::kD, + Shape::kH * Shape::kW) // bounds of matrix + ); + cutlass::TensorView srcView( + &src[0], // pointer to base of matrix in device memory + cutlass::make_Coord(Shape::kD, 1), // stride vector + cutlass::make_Coord(Shape::kD, + Shape::kH * Shape::kW) // bounds of matrix + ); + cutlass::layout::thread::Copy Transformer; + Transformer.copy(dstView, srcView); + } +}; + +template +struct Transform { + typedef Fragment::kCount> DstFragment; + typedef Fragment::kCount> SrcFragment; + + /// The input fragment. + typedef SrcFragment InputFragment; + /// The output fragment. 
+ typedef DstFragment OutputFragment; + + CUTLASS_DEVICE void transform(SrcFragment& src, DstFragment& dst) { + cutlass::TensorView dstView( + &dst[0], // pointer to base of matrix in device memory + cutlass::make_Coord(Shape::kW * Shape::kC, 1), // stride vector + cutlass::make_Coord(Shape::kD, + Shape::kH) // bounds of matrix + ); + cutlass::TensorView srcView( + &src[0], // pointer to base of matrix in device memory + cutlass::make_Coord(Shape::kW * Shape::kC, 1), // stride vector + cutlass::make_Coord(Shape::kD, + Shape::kH) // bounds of matrix + ); + cutlass::layout::thread::Copy Transformer; + Transformer.copy(dstView, srcView); + } +}; + +} // namespace thread +} // namespace layout +} // namespace cutlass diff --git a/cutlass/load_store.h b/cutlass/load_store.h index db09dd0a48..ca6ad88315 100644 --- a/cutlass/load_store.h +++ b/cutlass/load_store.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -66,6 +66,11 @@ struct Load { dst = *reinterpret_cast(pointer + offset); } + /// The clear function. + static CUTLASS_HOST_DEVICE void clear(AccessType& dst) { + dst = 0; + } + }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -80,6 +85,11 @@ struct Load(dst) = reinterpret_cast(&pointer[offset])[0]; } + + /// The clear function. + static CUTLASS_HOST_DEVICE void clear(AccessType& dst) { + dst = uint16_t(0); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -94,6 +104,10 @@ struct Load(&pointer[offset])[0]; } + /// The clear function. + static CUTLASS_HOST_DEVICE void clear(AccessType& dst) { + dst.registers[0] = uint32_t(0); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -109,6 +123,13 @@ struct Load { dst.registers[2] = tmp.x; dst.registers[3] = tmp.y; } + + /// The clear function. + static CUTLASS_HOST_DEVICE void clear(AccessType& dst) { + int2 zero = make_int2(0,0); + dst.registers[0] = zero.x; + dst.registers[1] = zero.y; + dst.registers[2] = zero.x; + dst.registers[3] = zero.y; + } }; #endif @@ -164,6 +201,15 @@ struct Load { }; //////////////////////////////////////////////////////////////////////////////////////////////////// - /// Defines data layouts of various matrix formats usable by TensorRef and other classes. // // The following define classes satisfying the TensorRefMapFunc concept. These must support the @@ -367,6 +366,12 @@ struct MatrixTransform { }; }; +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Tensor layout +namespace TensorLayout { + + enum Kind { kNHWC, kNCHW }; +}; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/cutlass/predicate_vector.h b/cutlass/predicate_vector.h index 4a37d017d7..457008fb86 100644 --- a/cutlass/predicate_vector.h +++ b/cutlass/predicate_vector.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. 
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -186,9 +186,9 @@ struct PredicateVector { CUTLASS_HOST_DEVICE ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {} - /// + /// Copy ctor CUTLASS_HOST_DEVICE - ConstIterator(PredicateVector const &_vec, int _start = 0) : vec_(_vec), bit_(_start) {} + ConstIterator(PredicateVector const &vec, int _start = 0) : vec_(vec), bit_(_start) {} /// Pre-increment CUTLASS_HOST_DEVICE @@ -197,6 +197,13 @@ struct PredicateVector { return *this; } + /// Increment + CUTLASS_HOST_DEVICE + ConstIterator &operator+=(int offset) { + bit_ += offset; + return *this; + } + /// Pre-decrement CUTLASS_HOST_DEVICE ConstIterator &operator--() { @@ -204,6 +211,13 @@ struct PredicateVector { return *this; } + /// Decrement + CUTLASS_HOST_DEVICE + ConstIterator &operator-=(int offset) { + bit_ -= offset; + return *this; + } + /// Post-increment CUTLASS_HOST_DEVICE ConstIterator operator++(int) { @@ -220,6 +234,22 @@ struct PredicateVector { return ret; } + /// Iterator advances by some amount + CUTLASS_HOST_DEVICE + ConstIterator operator+(int offset) { + ConstIterator ret(*this); + ret.bit_ += offset; + return ret; + } + + /// Iterator recedes by some amount + CUTLASS_HOST_DEVICE + ConstIterator operator-(int offset) { + ConstIterator ret(*this); + ret.bit_ -= offset; + return ret; + } + /// Returns true if iterators point to the same bit CUTLASS_HOST_DEVICE bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; } @@ -230,7 +260,15 @@ struct PredicateVector { /// Dereferences iterator CUTLASS_HOST_DEVICE - bool operator*() const { return vec_[bit_]; } + bool operator*() const { return vec_.at(bit_); } + + /// Gets the bit at the pointed to location + CUTLASS_HOST_DEVICE + bool get() const { return vec_.at(bit_); } + + /// Gets the bit at the pointed to location + CUTLASS_HOST_DEVICE + bool at() const { return vec_.at(bit_); } }; /** @@ -252,7 +290,7 @@ struct PredicateVector { /// Constructs an iterator from a PredicateVector CUTLASS_HOST_DEVICE - Iterator(PredicateVector &_vec, int _start = 0) : vec_(_vec), bit_(_start) {} + Iterator(PredicateVector &vec, int _start = 0) : vec_(vec), bit_(_start) {} /// Pre-increment CUTLASS_HOST_DEVICE @@ -261,6 +299,13 @@ struct PredicateVector { return *this; } + /// Increment + CUTLASS_HOST_DEVICE + Iterator &operator+=(int offset) { + bit_ += offset; + return *this; + } + /// Pre-decrement CUTLASS_HOST_DEVICE Iterator &operator--() { @@ -268,6 +313,13 @@ struct PredicateVector { return *this; } + /// Decrement + CUTLASS_HOST_DEVICE + Iterator &operator-=(int offset) { + bit_ -= offset; + return *this; + } + /// Post-increment CUTLASS_HOST_DEVICE Iterator operator++(int) { @@ -284,6 +336,22 @@ struct PredicateVector { return ret; } + /// Iterator advances by some amount + CUTLASS_HOST_DEVICE + Iterator operator+(int offset) { + Iterator ret(*this); + ret.bit_ += offset; + return ret; + } + + /// Iterator recedes by some amount + CUTLASS_HOST_DEVICE + Iterator operator-(int offset) { + ConstIterator ret(*this); + ret.bit_ -= offset; + return ret; + } + /// Returns true if iterators point to the same bit CUTLASS_HOST_DEVICE bool operator==(Iterator const &it) const { return bit_ == it.bit_; } @@ -294,11 +362,15 @@ struct PredicateVector { /// Gets the bit at the pointed to location CUTLASS_HOST_DEVICE - bool get() { return 
vec_[bit_]; } + bool get() { return vec_.at(bit_); } + + /// Gets the bit at the pointed to location + CUTLASS_HOST_DEVICE + bool at() const { return vec_.at(bit_); } /// Dereferences iterator CUTLASS_HOST_DEVICE - bool operator*() const { return vec_[bit_]; } + bool operator*() const { return at(); } /// Sets the bit at the pointed to location CUTLASS_HOST_DEVICE diff --git a/cutlass/reduction/batched_reduction.h b/cutlass/reduction/batched_reduction.h index 28e14c494b..83324ec012 100644 --- a/cutlass/reduction/batched_reduction.h +++ b/cutlass/reduction/batched_reduction.h @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -80,14 +80,16 @@ struct BatchedReduction { typename Traits::ScalarA inRegs[Traits::maxInReg]; typename Traits::ScalarAccum AccumRegs[Traits::maxOutReg]; - +#pragma unroll for (int subTile = 0; subTile < tileSize; subTile += subTileSize) { int tileOffset = subTileBase + subTileOffset; // Init AccumRegs +#pragma unroll for (int i = 0; i < Traits::ThreadShape::kW; i++) AccumRegs[i] = static_cast(0.0f); // Fetch c0 typename Traits::ScalarAccum c0[Traits::ThreadShape::kW]; +#pragma unroll for (int i = 0; i< Traits::ThreadShape::kW; i++) c0[i] = static_cast(params.d_c[tileOffset + i]); @@ -131,11 +133,13 @@ struct BatchedReduction { template CUTLASS_DEVICE void functor_caller(typename Traits::ScalarAccum const *accum, typename Traits::ScalarAccum const *old, typename Traits::ScalarAccum *output) { if (ThreadShapeMultiple2 == true) { +#pragma unroll for (int i = 0; i < Traits::ThreadShape::kW / 2; i++) { functor.template evaluate(&accum[2 * i], &old[2 * i], &output[2 * i]); } } else { +#pragma unroll for (int i = 0; i < Traits::ThreadShape::kW; i++) { functor.template evaluate(&accum[i], &old[i], &output[i]); } diff --git a/cutlass/reduction/batched_reduction_traits.h b/cutlass/reduction/batched_reduction_traits.h index bc0c1f2ac9..c44238e1e8 100644 --- a/cutlass/reduction/batched_reduction_traits.h +++ b/cutlass/reduction/batched_reduction_traits.h @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/reduction/threadblock_swizzle.h b/cutlass/reduction/threadblock_swizzle.h index 8be29eed12..6e42cadab4 100644 --- a/cutlass/reduction/threadblock_swizzle.h +++ b/cutlass/reduction/threadblock_swizzle.h @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/reshape_tile.h b/cutlass/reshape_tile.h index 2ae5122036..a0b482ce17 100644 --- a/cutlass/reshape_tile.h +++ b/cutlass/reshape_tile.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/shape.h b/cutlass/shape.h index b8c0c66f35..4c6f95ee33 100644 --- a/cutlass/shape.h +++ b/cutlass/shape.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,6 +72,23 @@ struct Shape { static int const kC = kC_; }; + +/** +* @brief A Shape implementing \ref layout_concept describing the dimensions of a cube. +* @concept{layout_concept} +*/ +template +struct Shape<1, kH_, kW_, 1> { + /// The depth of the cube. + static int const kD = 1; + /// The height of the cube. + static int const kH = kH_; + /// The width of the cube. + static int const kW = kW_; + /// The number of scalars per element. + static int const kC = 1; +}; + /** * @brief Compute derived counted of a \ref layout_concept based class */ diff --git a/cutlass/tensor_ref.h b/cutlass/tensor_ref.h index 09134190c0..d7be9b8a0d 100644 --- a/cutlass/tensor_ref.h +++ b/cutlass/tensor_ref.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -63,12 +63,7 @@ struct IdentityTensorMapFunc { and assumptions about vectorizing memory accesses throughout CUTLASS. It also matches various BLAS conventions in which only the "leading dimension" or most significant stride of a rank=2 matrix is provided. - - This does affect the ability of constructing arbitrary "sparse" 2-D matrices in memory where all - stride elements are > 1. This can be overcome by defining a custom mapping function and a - StorageRank of 3 or more. - - + Examples: (These examples use helpers for matrix layouts defined in cutlass/matrix_traits.h) @@ -85,7 +80,7 @@ struct IdentityTensorMapFunc { TensorRef > C; - 4. Defining a sparse matrix with arbitrary strides in each dimension + 4. 
Defining a matrix with arbitrary strides in each dimension struct ContiguousLayout { @@ -545,6 +540,10 @@ class TensorRef { CUTLASS_HOST_DEVICE Storage * data() const { return ptr_; } + /// Returns the pointer to referenced data at the given coordinate + CUTLASS_HOST_DEVICE + Storage * data(TensorCoord const& coord) const { return ptr_ + offset(coord); } + /// Returns the stride of the tensor CUTLASS_HOST_DEVICE StorageCoord stride() const { diff --git a/cutlass/tensor_ref_collection.h b/cutlass/tensor_ref_collection.h index 79c0d2683d..eb79bb6c2f 100644 --- a/cutlass/tensor_ref_collection.h +++ b/cutlass/tensor_ref_collection.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/tensor_view.h b/cutlass/tensor_view.h index 4ef99e027e..d770a193a3 100644 --- a/cutlass/tensor_view.h +++ b/cutlass/tensor_view.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -75,7 +75,7 @@ class TensorView : public TensorRef ConstTensorRef; /// Base tensor reference - typedef Base TensorRef; + typedef Base TensorRef_t; /// Storage type typedef typename Base::Storage Storage; @@ -84,14 +84,14 @@ class TensorView : public TensorRefstride(order[i]) < this->stride(order[i + 1])) { + int temp = order[i]; + order[i] = order[i + 1]; + order[i + 1] = temp; + } + } + } + // post-condition: this->stride(ord[i]) >= this->stride(ord[i+1]) for i from [0,Rank_-2] + } + + /// Determines if the values in the tensor are contiguous + CUTLASS_HOST_DEVICE + bool isPacked() const { + if (Rank_ <= 0) return true; + int ord[Rank_]; + getStrideOrder(ord); + // first check that the fastest-varying dimension (smallest stride) has a stride of 1 + if (this->stride(ord[Rank_ - 1]) != 1) return false; + // now check that there are no gaps between strides + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank_ - 1; i++) + if (this->stride(ord[i]) != this->stride(ord[i + 1]) * size_[ord[i + 1]]) return false; + return true; } /// Returns a TensorRef pointing to the first element of the tensor. CUTLASS_HOST_DEVICE + TensorRef_t ref() const { + return TensorRef_t(*this); + } + + /// Returns a TensorRef_t pointing to the first element of the tensor.
+ CUTLASS_HOST_DEVICE ConstTensorRef const_ref() const { return ConstTensorRef(*this); } @@ -238,22 +267,22 @@ class TensorView : public TensorRefadd_pointer_offset(this->offset(b)); return *this; } - /// Returns a TensorRef offset by a given amount + /// Returns a TensorRef_t offset by a given amount CUTLASS_HOST_DEVICE TensorView operator-(TensorCoord const& b) const { - TensorRef result(*this); + TensorRef_t result(*this); result.add_pointer_offset(-this->offset(b)); return result; } - /// Returns a TensorRef offset by a given amount + /// Returns a TensorRef_t offset by a given amount CUTLASS_HOST_DEVICE TensorView& operator-=(TensorCoord const& b) { this->add_pointer_offset(-this->offset(b)); diff --git a/cutlass/tile_allocation.h b/cutlass/tile_allocation.h index 873f67d022..24470815f4 100644 --- a/cutlass/tile_allocation.h +++ b/cutlass/tile_allocation.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -48,7 +48,7 @@ struct TileAllocation { typedef Scalar_ Scalar; /// The actual storage (may differ from the scalar type) - typedef typename StorageType::Type Storage; + typedef typename StorageType::Type Storage; /// Size of the allocation in units of scalars typedef Shape_ Shape; @@ -165,4 +165,62 @@ struct ZipTileAllocation { //////////////////////////////////////////////////////////////////////////////////////////////////// +/// Manages a triple of tile allocations as if they are one allocation +template <typename First_, typename Second_, typename Third_> +struct ZipTileAllocationTriple { + // + // Type definitions + // + + /// First tensor allocation + typedef First_ First; + + /// Second tensor allocation + typedef Second_ Second; + + /// Metadata tensor allocation + typedef Third_ Third; + + /// Defines the tensor reference for this allocation + typedef Zip3TensorRef TensorRef; + + /// Defines the constant tensor reference for this allocation + typedef Zip3TensorRef + ConstTensorRef; + + // + // Data members + // + + /// First tensor allocation + First first; + + /// Second tensor allocation + Second second; + + /// Metadata tensor allocation + Third third; + // + // Methods + // + + /// Returns a TensorRef object pointing to the data + CUTLASS_DEVICE + TensorRef reference() { + return TensorRef(first.reference(), second.reference(), third.reference()); + } + + /// Returns a ConstTensorRef object pointing to the data + CUTLASS_DEVICE + ConstTensorRef reference() const { + return ConstTensorRef(first.reference(), second.reference(), third.reference()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass diff --git a/cutlass/tile_coord.h b/cutlass/tile_coord.h index b3d809bc36..e9b8ddf6e4 100644 --- a/cutlass/tile_coord.h +++ b/cutlass/tile_coord.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/tile_iterator.h b/cutlass/tile_iterator.h index 71b2e55460..923d7e107e 100644 --- a/cutlass/tile_iterator.h +++ b/cutlass/tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -293,12 +293,10 @@ struct TileIteratorBase { stride_d = _stride_d; stride_h = _stride_h; stride_w = _stride_w; - inc_w = stride_w * Delta::kW; inc_h = stride_h * Delta::kH - stride_w * Delta::kW * (Iterations::kW - 1); inc_d = stride_h * Delta::kD - stride_h * Delta::kH * (Iterations::kH - 1) - stride_w * Delta::kW * (Iterations::kW - 1); - inc_advance = 0; if (kAdvance == IteratorAdvance::kH) { @@ -740,9 +738,13 @@ struct TileLoadIterator : public TileIteratorBase CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) { FragmentIterator frag_iterator(fragment); + CUTLASS_PRAGMA_UNROLL for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL for (int w = 0; w < Iterations::kW; ++w, ++pred_it) { + CUTLASS_PRAGMA_UNROLL for (int c = 0; c < Iterations::kC; ++c) { if (*pred_it) { load_element( @@ -789,8 +791,11 @@ struct TileLoadIterator : public TileIteratorBase CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) { FragmentIterator frag_iterator(fragment); + CUTLASS_PRAGMA_UNROLL for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL for (int w = 0; w < Iterations::kW; ++w) { + CUTLASS_PRAGMA_UNROLL for (int c = 0; c < Iterations::kC; ++c) { load_element(reinterpret_cast(frag_iterator.at(0, h, w, c)), d, h, w, c); } @@ -1076,7 +1081,6 @@ struct TileStoreIterator : public TileIteratorBase CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it) { FragmentConstIterator frag_iterator(fragment); - + CUTLASS_PRAGMA_UNROLL for (int d = 0; d < Iterations::kD; ++d) { + CUTLASS_PRAGMA_UNROLL for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL for (int w = 0; w < Iterations::kW; ++w, ++pred_it) { + CUTLASS_PRAGMA_UNROLL for (int c = 0; c < Iterations::kC; ++c) { if (*pred_it) { store_element( @@ -1213,9 +1220,13 @@ struct TileStoreIterator : public TileIteratorBase CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) { FragmentIterator frag_iterator(fragment); + CUTLASS_PRAGMA_UNROLL for (int h = 0; h < Iterations::kH; ++h) { + CUTLASS_PRAGMA_UNROLL for (int w = 0; w < Iterations::kW; ++w) { + CUTLASS_PRAGMA_UNROLL for (int c = 0; c < Iterations::kC; ++c) { load_element(reinterpret_cast(frag_iterator.at(0, h, w, c)), d, h, w, c); } diff --git a/cutlass/tile_stream.h b/cutlass/tile_stream.h index 7790605a05..00d1964d36 100644 --- a/cutlass/tile_stream.h +++ b/cutlass/tile_stream.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -272,6 +272,9 @@ struct PredicatedTileLoadStream : public TileLoadStream /// Parameters object used to construct generic load stream typedef typename Base::Params Params; + + /// + typedef typename Iterator::Scalar Scalar; // // Data members @@ -331,6 +334,9 @@ struct PredicatedTileStoreStream : public TileStoreStream +template struct TileTraitsWarpRake { /// Shape of tile typedef Tile_ Tile; @@ -163,10 +163,10 @@ struct TileTraitsWarpRake { typedef Shape<1, kWarpsStrided, kWarpsContiguous * kWarpSize> ThreadShape; /// The same warp rakes along the contiguous dimension - typedef Shape<1, kWarpsStrided, kWarpSize> Delta; + typedef Shape<1, kWarpsStrided, kWarpSize * AccessSize> Delta; /// Number of iterations - typedef Shape<1, Tile::kH / Delta::kH, Tile::kW / ThreadShape::kW> Iterations; + typedef Shape<1, Tile::kH / Delta::kH, (Tile::kW / AccessSize) / ThreadShape::kW> Iterations; /// Computes the thread offset in (H, W) based on thread ID struct ThreadOffset { @@ -182,7 +182,7 @@ struct TileTraitsWarpRake { int warp_w = (warp % kWarpsContiguous); int warp_h = (warp / kWarpsContiguous); - return make_Coord(0, warp_h, lane + kWarpSpanContiguous * warp_w, 0); + return make_Coord(0, warp_h, AccessSize * (lane + kWarpSpanContiguous * warp_w), 0); } }; }; diff --git a/cutlass/util/complex.h b/cutlass/util/complex.h index 260a3abd2c..3e794f6574 100644 --- a/cutlass/util/complex.h +++ b/cutlass/util/complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/util/cutlass_math.h b/cutlass/util/cutlass_math.h index e3b46ef35a..55a73c60ca 100644 --- a/cutlass/util/cutlass_math.h +++ b/cutlass/util/cutlass_math.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/util/debug.h b/cutlass/util/debug.h index 6055e3fcc6..9941b41a17 100644 --- a/cutlass/util/debug.h +++ b/cutlass/util/debug.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/util/numeric_types.h b/cutlass/util/numeric_types.h index d8094a2567..4861a5f6d2 100644 --- a/cutlass/util/numeric_types.h +++ b/cutlass/util/numeric_types.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/util/pair.h b/cutlass/util/pair.h index c6ba65a80b..3079ed0f8f 100644 --- a/cutlass/util/pair.h +++ b/cutlass/util/pair.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/util/platform.h b/cutlass/util/platform.h index 3fd7c897d9..1b173d6786 100644 --- a/cutlass/util/platform.h +++ b/cutlass/util/platform.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -84,7 +84,6 @@ * - \p aligned_storage * * (4) Functions and types that are STL-like (but aren't in the STL): - * - \p TODO: min and max functors? * * The idea is that, as we drop support for older compilers, we can simply #define * the \p __NV_STD_XYZ macros and \p platform namespace to alias their C++ diff --git a/cutlass/vector.h b/cutlass/vector.h index aeababb667..9b8a30ea09 100644 --- a/cutlass/vector.h +++ b/cutlass/vector.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -259,6 +259,40 @@ union Vector { //////////////////////////////////////////////////////////////////////////////////////////////////// +/// Vector definition for 8-bit signed integer datatype +template <int kLanes_> +union Vector<int8_t, kLanes_> { + /// The scalar type. + typedef int8_t Scalar; + + /// The number of elements in the vector. + enum { kLanes = kLanes_ }; + /// The size of the vector. + enum { kVectorSize = kLanes }; + /// The number of registers needed to store the vector. + enum { kRegisters = kVectorSize < 4 ? 1 : (kVectorSize+3) / 4 }; + +// static_assert((kLanes >= 2) && !(kLanes % 2), +// "May only construct vectors of int8_t that are multiples of 8 bits."); + +  /// The aligned storage to make sure we have good alignment. + AlignedStruct aligned_; + /// The data in registers. + uint32_t registers[kRegisters]; + + /// Default Constructor + CUTLASS_HOST_DEVICE + Vector() {} + /// Constructor to convert from uint32_t type + CUTLASS_HOST_DEVICE Vector(uint32_t value) { registers[0] = value; } + /// Accessor to the ith lane.
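+  /// Four 8-bit lanes are packed into each 32-bit register, so lane i is held in
+  /// registers[i / 4] at byte offset (i % 4); the accessor below shifts by (i % 4) * 8 and
+  /// masks with 0xff, returning the raw byte value without sign extension.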
+ CUTLASS_HOST_DEVICE int operator[](uint32_t i) const { + return (registers[i / 4] >> (i % 4 * 8) & 0xff); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template CUTLASS_HOST_DEVICE void make_zero(Scalar_& x) { x = Scalar_(0); diff --git a/cutlass/wmma_matrix.h b/cutlass/wmma_matrix.h index 61c4ed2724..647acd80f8 100644 --- a/cutlass/wmma_matrix.h +++ b/cutlass/wmma_matrix.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -30,6 +30,10 @@ #if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700) #define CUTLASS_USE_WMMA_API +#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 720) +#define CUTLASS_USE_INT_WMMA +#endif + #if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750) #define CUTLASS_USE_SUBBYTE_WMMA #endif diff --git a/cutlass/zip_fragment.h b/cutlass/zip_fragment.h index 37a788614a..e89ffb3cba 100644 --- a/cutlass/zip_fragment.h +++ b/cutlass/zip_fragment.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/cutlass/zip_tensor_ref.h b/cutlass/zip_tensor_ref.h index d2cff9e0c0..281253afe1 100644 --- a/cutlass/zip_tensor_ref.h +++ b/cutlass/zip_tensor_ref.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -72,6 +72,57 @@ ZipTensorRef make_ZipTensorRef(First const &first, Second const & return ZipTensorRef(first, second); } +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Any simple way to do so? 
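+/// Zip3TensorRef below carries three tensor references (first, second, and third) as one
+/// object, repeating the two-reference ZipTensorRef pattern above rather than generalizing it.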
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Zip3TensorRef { + /// First tensor ref + typedef First_ First; + + /// Second tensor ref + typedef Second_ Second; + + /// Third tensor ref + typedef Third_ Third; + + // + // Data members + // + + /// First TensorRef + First first; + + /// Second TensorRef + Second second; + + /// Third TensorRef + Third third; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Zip3TensorRef() {} + + CUTLASS_HOST_DEVICE + Zip3TensorRef(First const& _first, Second const& _second, Third const& _third) : + first(_first), second(_second), third(_third) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a ZipTensorRef +template +CUTLASS_HOST_DEVICE +Zip3TensorRef make_Zip3TensorRef(First const &first, + Second const &second, + Third const &third) { + return Zip3TensorRef(first, second, third); +} + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/cutlass/zip_tile_iterator.h b/cutlass/zip_tile_iterator.h index f95acc1aaf..747a22bd0e 100644 --- a/cutlass/zip_tile_iterator.h +++ b/cutlass/zip_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -47,6 +47,9 @@ class ZipTileIterator { /// Second iterator type typedef Second_ Second; + + /// + typedef typename First::Scalar Scalar; /// Params object struct Params { diff --git a/examples/00_basic_gemm/CMakeLists.txt b/examples/00_basic_gemm/CMakeLists.txt index 144263fff2..b9501940c5 100644 --- a/examples/00_basic_gemm/CMakeLists.txt +++ b/examples/00_basic_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/00_basic_gemm/basic_gemm.cu b/examples/00_basic_gemm/basic_gemm.cu index d6911c1f6a..853dd3f104 100644 --- a/examples/00_basic_gemm/basic_gemm.cu +++ b/examples/00_basic_gemm/basic_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -57,6 +57,8 @@ // Defines cutlass::gemm::SgemmTraits, the structural components for single-precision GEMM #include "cutlass/gemm/sgemm_traits.h" +#pragma warning( disable : 4503) + /////////////////////////////////////////////////////////////////////////////////////////////////// // // This function defines a CUTLASS GEMM kernel instantiation, constructs its parameters object, diff --git a/examples/01_tensor_view/CMakeLists.txt b/examples/01_tensor_view/CMakeLists.txt index 24ab8018ab..76e7cad46f 100644 --- a/examples/01_tensor_view/CMakeLists.txt +++ b/examples/01_tensor_view/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/01_tensor_view/tensor_view.cu b/examples/01_tensor_view/tensor_view.cu index e885e6eeeb..2f2965ae47 100644 --- a/examples/01_tensor_view/tensor_view.cu +++ b/examples/01_tensor_view/tensor_view.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/examples/02_cutlass_utilities/CMakeLists.txt b/examples/02_cutlass_utilities/CMakeLists.txt index f59281e057..c9e05bb0d8 100644 --- a/examples/02_cutlass_utilities/CMakeLists.txt +++ b/examples/02_cutlass_utilities/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/02_cutlass_utilities/cutlass_utilities.cu b/examples/02_cutlass_utilities/cutlass_utilities.cu index 7f04cc5747..2cc1a2dc91 100644 --- a/examples/02_cutlass_utilities/cutlass_utilities.cu +++ b/examples/02_cutlass_utilities/cutlass_utilities.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/examples/03_strided_batched_gemm/CMakeLists.txt b/examples/03_strided_batched_gemm/CMakeLists.txt index 564bc6310d..dc1cd1688b 100644 --- a/examples/03_strided_batched_gemm/CMakeLists.txt +++ b/examples/03_strided_batched_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
# # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/03_strided_batched_gemm/strided_batched_gemm.cu b/examples/03_strided_batched_gemm/strided_batched_gemm.cu index e7d387b6cb..35b3843933 100644 --- a/examples/03_strided_batched_gemm/strided_batched_gemm.cu +++ b/examples/03_strided_batched_gemm/strided_batched_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -29,6 +29,8 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/sgemm_traits.h" +#pragma warning( disable : 4503) + /* This example demonstrates how to use cutlass to compute a batched strided gemm. In this example, both A and B matrix are non-transpose and column major matrix diff --git a/examples/04_tile_iterator/CMakeLists.txt b/examples/04_tile_iterator/CMakeLists.txt index 0e74d12db6..ab87241af8 100644 --- a/examples/04_tile_iterator/CMakeLists.txt +++ b/examples/04_tile_iterator/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/04_tile_iterator/tile_iterator.cu b/examples/04_tile_iterator/tile_iterator.cu index 40d5e55198..656866bced 100644 --- a/examples/04_tile_iterator/tile_iterator.cu +++ b/examples/04_tile_iterator/tile_iterator.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/examples/05_wmma_gemm/CMakeLists.txt b/examples/05_wmma_gemm/CMakeLists.txt index ab048532c0..b7a6cd573d 100644 --- a/examples/05_wmma_gemm/CMakeLists.txt +++ b/examples/05_wmma_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/05_wmma_gemm/wmma_gemm.cu b/examples/05_wmma_gemm/wmma_gemm.cu index 2b1e3567f0..c9a5692d4a 100644 --- a/examples/05_wmma_gemm/wmma_gemm.cu +++ b/examples/05_wmma_gemm/wmma_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -47,8 +47,10 @@ // CUTLASS includes needed for WMMA GEMM kernel #include "cutlass/wmma_matrix.h" +#pragma warning( disable : 4503) + // This example works only when this MACRO is defined in "cutlass/wmma_matrix.h" -#ifdef CUTLASS_USE_SUBBYTE_WMMA +#ifdef CUTLASS_USE_INT_WMMA // Defines cutlass::gemm::Gemm, the generic Gemm computation template class. #include "cutlass/gemm/gemm.h" @@ -273,7 +275,7 @@ cudaError_t TestCutlassGemm(int M, int N, int K, int alpha, int beta) { // Passed error check return cudaSuccess; } -#endif // defined CUTLASS_USE_SUBBYTE_WMMA +#endif // defined CUTLASS_USE_INT_WMMA /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -285,7 +287,7 @@ cudaError_t TestCutlassGemm(int M, int N, int K, int alpha, int beta) { // int main(int argc, const char *arg[]) { -#ifdef CUTLASS_USE_SUBBYTE_WMMA +#ifdef CUTLASS_USE_INT_WMMA // Properties of CUDA device cudaDeviceProp device_properties; @@ -299,8 +301,8 @@ int main(int argc, const char *arg[]) { return -1; } - if ((device_properties.major * 10 + device_properties.minor) < 75) { - std::cerr << "This example needs to run on a Turing device." << std::endl; + if ((device_properties.major * 10 + device_properties.minor) < 72) { + std::cerr << "This example needs to run on a device which has at least 7.2 compute capability." << std::endl; return -1; } @@ -344,9 +346,9 @@ int main(int argc, const char *arg[]) { return result == cudaSuccess ? 0 : -1; #else - std::cerr << "CUTLASS WMMA GEMM targeting Turing Tensor Cores features requires CUDA 10." << std::endl; + std::cerr << "CUTLASS WMMA GEMM targeting Turing Tensor Cores features requires compute capability 7.2." << std::endl; return -1; -#endif // defined CUTLASS_USE_SUBBYTE_WMMA +#endif // defined CUTLASS_USE_INT_WMMA } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/06_splitK_gemm/CMakeLists.txt b/examples/06_splitK_gemm/CMakeLists.txt index 695a91b148..da1c99bed7 100644 --- a/examples/06_splitK_gemm/CMakeLists.txt +++ b/examples/06_splitK_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/examples/06_splitK_gemm/splitK_gemm.cu b/examples/06_splitK_gemm/splitK_gemm.cu index 20ea490ba1..820bc4fb49 100644 --- a/examples/06_splitK_gemm/splitK_gemm.cu +++ b/examples/06_splitK_gemm/splitK_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -84,11 +84,7 @@ cudaError_t cutlass_splitK_sgemm_nn(float const *A, typename deviceGemm::Params deviceGemmParams(m, n, k); // query if workspace is needed. 
the workspace size is sizeof(accumulateType) * M * N * splits_count - int workspace_size = deviceGemmParams.required_workspace_memory_in_byte(); - if (workspace_size <= 0) { - std::cerr << "splitK workspace_size is smaller than 0" << std::endl; - return cudaErrorInvalidValue; - } + size_t workspace_size = deviceGemmParams.required_workspace_memory_in_byte(); // allocate workspace memory float *workspace_ptr; diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index abc1e6ff2a..008c2a7723 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -27,3 +27,4 @@ add_subdirectory(03_strided_batched_gemm) add_subdirectory(04_tile_iterator) add_subdirectory(05_wmma_gemm) add_subdirectory(06_splitK_gemm) + diff --git a/media/images/cutlass-performance-plot.png b/media/images/cutlass-performance-plot.png index 0af79c5db1..041d28b3b9 100644 Binary files a/media/images/cutlass-performance-plot.png and b/media/images/cutlass-performance-plot.png differ diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index f14d9d42b9..a25a613007 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/tools/external/googletest b/tools/external/googletest index 9077ec7efe..2fe3bd994b 160000 --- a/tools/external/googletest +++ b/tools/external/googletest @@ -1 +1 @@ -Subproject commit 9077ec7efe5b652468ab051e93c67589d5cb8f85 +Subproject commit 2fe3bd994b3189899d93f1d5a881e725e046fdc2 diff --git a/tools/nvrtc/CMakeLists.txt b/tools/nvrtc/CMakeLists.txt index 12d8bf2fec..2eeb90d0c2 100644 --- a/tools/nvrtc/CMakeLists.txt +++ b/tools/nvrtc/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/tools/nvrtc/cutlass/nvrtc/environment.h b/tools/nvrtc/cutlass/nvrtc/environment.h index 310ac3b715..96bde4e8ba 100644 --- a/tools/nvrtc/cutlass/nvrtc/environment.h +++ b/tools/nvrtc/cutlass/nvrtc/environment.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/nvrtc/stdlib/stdint.h b/tools/nvrtc/stdlib/stdint.h index 8c0143987d..d066380e7e 100644 --- a/tools/nvrtc/stdlib/stdint.h +++ b/tools/nvrtc/stdlib/stdint.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. 
All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/CMakeLists.txt b/tools/test/CMakeLists.txt index f782b6e998..ba9e0dfa0e 100644 --- a/tools/test/CMakeLists.txt +++ b/tools/test/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: diff --git a/tools/test/perf/CMakeLists.txt b/tools/test/perf/CMakeLists.txt index b5b54b5cff..b2afd38854 100644 --- a/tools/test/perf/CMakeLists.txt +++ b/tools/test/perf/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -32,6 +32,8 @@ set(CUTLASS_PERF_TEST_HEADERS gemm/cutlass_dispatch_splitK_PI.h gemm/gemm_perf_testbed.h gemm/gemm_profiler.h + gemm/cutlass_volta884_dispatch.h + gemm/cutlass_volta884_dispatch_splitK_PI.h ) set(CUTLASS_PERF_TEST_SOURCES @@ -43,8 +45,14 @@ set(CUTLASS_PERF_TEST_SOURCES gemm/igemm.cu gemm/igemm_splitK.cu gemm/wmma_gemm.cu - gemm/wmma_binary_gemm.cu - gemm/wmma_integer_gemm.cu + gemm/volta884_gemm.cu + gemm/volta884_gemm_splitK.cu + gemm/volta884_gemm_cta_rasterization_tn.cu + gemm/volta884_gemm_cta_rasterization_tt.cu + gemm/volta884_gemm_cta_rasterization_nn.cu + gemm/volta884_gemm_cta_rasterization_nt.cu + gemm/wmma_binary_gemm.cu + gemm/wmma_integer_gemm.cu ) source_group("Source\ Files" FILES ${CUTLASS_PERF_TEST_SOURCES}) @@ -62,5 +70,7 @@ cutlass_add_executable( ${CUTLASS_PERF_TEST_HEADERS} ) -target_link_libraries(cutlass_perf_test ${CUBLAS_LIBRARY}) +if(CUTLASS_ENABLE_CUBLAS) + target_link_libraries(cutlass_perf_test ${CUBLAS_LIBRARY}) +endif() diff --git a/tools/test/perf/cutlass_perf_test.cu b/tools/test/perf/cutlass_perf_test.cu index dee4c5afcf..f47c0d6d9a 100644 --- a/tools/test/perf/cutlass_perf_test.cu +++ b/tools/test/perf/cutlass_perf_test.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/perf/cutlass_perf_test.h b/tools/test/perf/cutlass_perf_test.h index 70320740e2..ababf33e9c 100644 --- a/tools/test/perf/cutlass_perf_test.h +++ b/tools/test/perf/cutlass_perf_test.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -32,6 +32,14 @@ #include "tools/test/perf/testbench_output.h" #include "tools/test/perf/gemm/gemm_profiler.h" +#if !defined(CUTLASS_ENABLE_CUBLAS) +#define CUTLASS_ENABLE_CUBLAS 0 +#endif + +#if !defined(CUTLASS_ENABLE_CUDNN) +#define CUTLASS_ENABLE_CUDNN 0 +#endif + namespace perf { typedef int (GemmProfileFunc)( diff --git a/tools/test/perf/gemm/cublas_dispatch.h b/tools/test/perf/gemm/cublas_dispatch.h index 8bad045254..56baed2649 100644 --- a/tools/test/perf/gemm/cublas_dispatch.h +++ b/tools/test/perf/gemm/cublas_dispatch.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -67,6 +67,7 @@ struct CublasGemmDispatch { CDeviceType *C, int ldc, cublasGemmAlgo_t algorithm) { + #if CUTLASS_ENABLE_CUBLAS return cublasGemmEx(handle, convert(layout_a), convert(layout_b), @@ -86,6 +87,9 @@ struct CublasGemmDispatch { ldc, cutlass::TypeTraits::cublas_type, algorithm); + #else + return CUBLAS_STATUS_NOT_SUPPORTED; + #endif } }; @@ -131,7 +135,7 @@ struct CublasBatchedStridedGemmDispatch { long long int batch_stride_C, int batch_count, cublasGemmAlgo_t algorithm) { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 9010 + #if CUTLASS_ENABLE_CUBLAS && defined(CUDA_VERSION) && CUDA_VERSION >= 9010 return cublasGemmStridedBatchedEx(handle, convert(layout_a), convert(layout_b), @@ -155,9 +159,9 @@ struct CublasBatchedStridedGemmDispatch { batch_count, cutlass::TypeTraits::cublas_type, algorithm); -#else + #else return CUBLAS_STATUS_NOT_SUPPORTED; -#endif + #endif } }; diff --git a/tools/test/perf/gemm/cutlass_dispatch.h b/tools/test/perf/gemm/cutlass_dispatch.h index 464dab4a6f..1f1ca87816 100644 --- a/tools/test/perf/gemm/cutlass_dispatch.h +++ b/tools/test/perf/gemm/cutlass_dispatch.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -33,7 +33,12 @@ template + #if CUTLASS_ENABLE_CUBLAS + bool RunCuBLAS_ = true + #else + bool RunCuBLAS_ = false + #endif +> struct CutlassDispatch { typedef typename Gemm_::Params Params; typedef Gemm_ Gemm; @@ -131,8 +136,6 @@ struct CutlassDispatchBasic { typedef typename Traits::ScalarC ScalarC; /// The scalar for D. typedef typename Traits::ScalarD ScalarD; - - // TODO - support alternative accumulator and scalar types typedef ScalarD Compute; typedef Compute ScalarEpilogue; diff --git a/tools/test/perf/gemm/cutlass_dispatch_splitK_PI.h b/tools/test/perf/gemm/cutlass_dispatch_splitK_PI.h index 262d39eb38..8ef5296402 100644 --- a/tools/test/perf/gemm/cutlass_dispatch_splitK_PI.h +++ b/tools/test/perf/gemm/cutlass_dispatch_splitK_PI.h @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. 
+* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -40,8 +40,13 @@ template - struct CutlassDispatchSplitKPIGemm { + #if CUTLASS_ENABLE_CUBLAS + bool RunCuBLAS_ = true + #else + bool RunCuBLAS_ = false + #endif +> +struct CutlassDispatchSplitKPIGemm { typedef typename KernelClass_::Params Params; typedef KernelClass_ KernelClass; typedef Index_ Index; @@ -87,8 +92,21 @@ template available_device_memory_in_byte) { + std::cout << "reqested workspace memory size("<< workspace_size_in_byte << + ") is larger than available memory size("<< available_device_memory_in_byte << "). Abort." << std::endl; + throw std::runtime_error("reqested workspace memory size is larger than available memory size. Abort."); + } + cudaError_t workspace_err = cudaMalloc(&workspace_ptr, workspace_size_in_byte); if (workspace_err != cudaSuccess) { std::cout << "\nCUDA workspace malloc error: " << cudaGetErrorString(workspace_err) @@ -153,8 +171,6 @@ struct CutlassDispatchSplitKPIGemmBasic { typedef typename Traits::ScalarC ScalarC; /// The scalar for D. typedef typename Traits::ScalarD ScalarD; - - // TODO - support alternative accumulator and scalar types typedef ScalarD Compute; typedef Compute ScalarEpilogue; diff --git a/tools/test/perf/gemm/cutlass_volta884_dispatch.h b/tools/test/perf/gemm/cutlass_volta884_dispatch.h new file mode 100644 index 0000000000..910a47370f --- /dev/null +++ b/tools/test/perf/gemm/cutlass_volta884_dispatch.h @@ -0,0 +1,113 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma once + + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/volta884_gemm_traits.h" + +#include "tools/test/perf/cutlass_perf_test.h" +#include "tools/test/perf/gemm/gemm_profiler.h" +#include "tools/test/perf/gemm/cutlass_dispatch.h" +#include "tools/test/perf/gemm/gemm_perf_testbed.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Volta884GemmDispatch { + + typedef cutlass::gemm::Gemm Gemm; + + typedef typename Gemm::Params Params; + + typedef typename Traits::ScalarC ScalarC; + typedef typename Traits::ScalarD ScalarD; + typedef typename Traits::ScalarD ScalarEpilogue; + + /// Indicate warp-level GEMM + static bool const kThreadMultiplyAdd = false; + + #if CUTLASS_ENABLE_CUBLAS + static bool const kRunCuBLAS = true; + #else + static bool const kRunCuBLAS = false; + #endif + + static cutlass::MatrixLayout::Kind const kLayoutA = Traits::kLayoutA; + static cutlass::MatrixLayout::Kind const kLayoutB = Traits::kLayoutB; + + // + // Data members + // + + /// Params argument + Params params; + + // + // Methods + // + + Volta884GemmDispatch() {} + + /// Initializes params object + Volta884GemmDispatch(int m, int n, int k, ScalarEpilogue alpha, half const* d_a, int lda, + half const* d_b, int ldb, ScalarEpilogue beta, ScalarC const* d_c, int ldc, + ScalarD* d_d, int ldd) { + + params.initialize(m, n, k, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd); + } + + Volta884GemmDispatch(int m, + int n, + int k, + ScalarEpilogue alpha, + half const* d_a, + int lda, + long long int batch_stride_A, + half const* d_b, + int ldb, + long long int batch_stride_B, + ScalarEpilogue beta, + ScalarC const* d_c, + int ldc, + long long int batch_stride_C, + ScalarD* d_d, + int ldd, + long long int batch_stride_D, + int batch_count) { + assert(0);//not yet supported + } + + /// Initializes params object + Volta884GemmDispatch(Params const& _params) : params(_params) {} + + /// Launches kernel + cudaError_t operator()() { return Gemm::launch(params); } +}; diff --git a/tools/test/perf/gemm/cutlass_volta884_dispatch_splitK_PI.h b/tools/test/perf/gemm/cutlass_volta884_dispatch_splitK_PI.h new file mode 100644 index 0000000000..a7f1de8a8e --- /dev/null +++ b/tools/test/perf/gemm/cutlass_volta884_dispatch_splitK_PI.h @@ -0,0 +1,139 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma once + + +#include "cutlass/gemm/device_gemm.h" +#include "cutlass/gemm/volta884_gemm_traits.h" + +#include "tools/test/perf/cutlass_perf_test.h" +#include "tools/test/perf/gemm/gemm_profiler.h" +#include "tools/test/perf/gemm/cutlass_dispatch.h" +#include "tools/test/perf/gemm/gemm_perf_testbed.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Volta884GemmDispatchSplitKPI { + + typedef cutlass::gemm::DeviceGemm Gemm; + + typedef typename Gemm::Params Params; + + typedef typename Traits::ScalarC ScalarC; + typedef typename Traits::ScalarD ScalarD; + typedef typename Traits::Scalar ScalarEpilogue; + + /// Indicate warp-level GEMM + static bool const kThreadMultiplyAdd = false; + + #if CUTLASS_ENABLE_CUBLAS + static bool const kRunCuBLAS = true; + #else + static bool const kRunCuBLAS = false; + #endif + + static cutlass::MatrixLayout::Kind const kLayoutA = Traits::kLayoutA; + static cutlass::MatrixLayout::Kind const kLayoutB = Traits::kLayoutB; + + // + // Data members + // + + /// Params argument + Params params; + + /// splitK PI require workspace + typename cutlass::TypeTraits::device_type *workspace_ptr; + + // + // Methods + // + + Volta884GemmDispatchSplitKPI() {} + + /// Initializes params object + Volta884GemmDispatchSplitKPI(int m, int n, int k, ScalarEpilogue alpha, half const* d_a, int lda, + half const* d_b, int ldb, ScalarEpilogue beta, ScalarC const* d_c, int ldc, + ScalarD* d_d, int ldd) { + params.init_problem(m, n, k); + size_t workspace_size_in_byte = params.required_workspace_memory_in_byte(); + size_t available_device_memory_in_byte = 0; + size_t device_memory_in_byte = 0; + cudaError_t cudaMemGetInfo_err = cudaMemGetInfo(&available_device_memory_in_byte, &device_memory_in_byte); + if (cudaMemGetInfo_err != cudaSuccess) { + std::cout << "\ncudaMemGetInfo error: " << cudaGetErrorString(cudaMemGetInfo_err) + << "\n"; + } + + if (workspace_size_in_byte > available_device_memory_in_byte) { + std::cout << "reqested workspace memory size(" << workspace_size_in_byte << + ") is larger than available memory size(" << available_device_memory_in_byte << "). Abort." << std::endl; + throw std::runtime_error("reqested workspace memory size is larger than available memory size. 
Abort."); + } + + cudaError_t workspace_err = cudaMalloc(&workspace_ptr, workspace_size_in_byte); + if (workspace_err != cudaSuccess) { + std::cout << "\nCUDA workspace malloc error: " << cudaGetErrorString(workspace_err) + << "\n"; + } + + params.initialize(alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd, workspace_ptr, 8 /*volta884 requires leading dim to be mulitiple of 8*/); + } + + Volta884GemmDispatchSplitKPI(int m, + int n, + int k, + ScalarEpilogue alpha, + half const* d_a, + int lda, + long long int batch_stride_A, + half const* d_b, + int ldb, + long long int batch_stride_B, + ScalarEpilogue beta, + ScalarC const* d_c, + int ldc, + long long int batch_stride_C, + ScalarD* d_d, + int ldd, + long long int batch_stride_D, + int batch_count) { + assert(0);//not yet supported + } + + /// Initializes params object + Volta884GemmDispatchSplitKPI(Params const& _params) : params(_params) {} + + /// Launches kernel + cudaError_t operator()() { + return Gemm::launch(params); + } +}; diff --git a/tools/test/perf/gemm/dgemm.cu b/tools/test/perf/gemm/dgemm.cu index 3f4b63b851..d435397a05 100644 --- a/tools/test/perf/gemm/dgemm.cu +++ b/tools/test/perf/gemm/dgemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/perf/gemm/gemm_perf_testbed.h b/tools/test/perf/gemm/gemm_perf_testbed.h index 81ba51e1c9..0852b7053f 100644 --- a/tools/test/perf/gemm/gemm_perf_testbed.h +++ b/tools/test/perf/gemm/gemm_perf_testbed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -284,10 +284,14 @@ class GemmTestbed { /// Constructs a basic workspace GemmTestbed(InitialDistribution const &_dist = InitialDistribution()) : initial_distribution(_dist) { + #if CUTLASS_ENABLE_CUBLAS status = cublasCreate(&handle); if (status != CUBLAS_STATUS_SUCCESS) { throw cutlass::cuda_exception("Failed to create CUBLAS handle"); } + #else + status = CUBLAS_STATUS_NOT_INITIALIZED; + #endif } /// Constructs a workspace for verifying GEMM, assumes @@ -296,15 +300,26 @@ class GemmTestbed { cublasGemmAlgo_t algorithm_ = CUBLAS_GEMM_DEFAULT, InitialDistribution const &_dist = InitialDistribution()) : problem(_problem), initial_distribution(_dist) { + #if CUTLASS_ENABLE_CUBLAS status = cublasCreate(&handle); if (status != CUBLAS_STATUS_SUCCESS) { throw cutlass::cuda_exception("Failed to create CUBLAS handle"); } + #else + status = CUBLAS_STATUS_NOT_INITIALIZED; + #endif resize(problem); } - ~GemmTestbed() { status = cublasDestroy(handle); } + /// Destructs the GEMM testbed + ~GemmTestbed() { + #if CUTLASS_ENABLE_CUBLAS + if (status != CUBLAS_STATUS_NOT_INITIALIZED) { + status = cublasDestroy(handle); + } + #endif + } /// Returns true if the last CUBLAS call returned successfully bool good() const { return status == CUBLAS_STATUS_SUCCESS; } @@ -388,6 +403,7 @@ class GemmTestbed { /// Launches the cuBLAS GEMM - does not initialize output matrix cublasStatus_t launch_cublas(cublasGemmAlgo_t algo) { + #if CUTLASS_ENABLE_CUBLAS if (problem.batch_count == 1) { CublasDispatch dispatch; @@ -441,6 +457,9 @@ class GemmTestbed { return status; } + #else + return CUBLAS_STATUS_NOT_SUPPORTED; + #endif } /// Verifies the 'test' tensor with 'ref' diff --git a/tools/test/perf/gemm/gemm_profiler.h b/tools/test/perf/gemm/gemm_profiler.h index 82d4151439..e45024e37b 100644 --- a/tools/test/perf/gemm/gemm_profiler.h +++ b/tools/test/perf/gemm/gemm_profiler.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -146,22 +146,23 @@ class GemmProfiler { , kernel_name , problem ); - + + result.disposition = Disposition::NotVerified; + if (options.dry_run) { result.disposition = Disposition::NotRun; return result; } if (CutlassDispatch::kRunCuBLAS) { +#if CUTLASS_ENABLE_CUBLAS testbed.compute_reference(algorithm); if (cudaDeviceSynchronize() != cudaSuccess) { result.disposition = Disposition::NotVerified; return result; } - } - else { - result.disposition = Disposition::Passed; +#endif } CutlassDispatch *dispatch_ptr; @@ -214,11 +215,13 @@ class GemmProfiler { } if (CutlassDispatch::kRunCuBLAS) { +#if CUTLASS_ENABLE_CUBLAS if (testbed.verify_with_reference()) { result.disposition = Disposition::Passed; } else { result.disposition = Disposition::Incorrect; } +#endif } if (options.save_workspace(result.disposition == Disposition::Passed)) { @@ -270,11 +273,34 @@ class GemmProfiler { result.runtime = double(average_ms) / double(options.iterations); result.gflops = testbed.GFLOPs_per_sec(result.runtime); - if (result.disposition != Disposition::Passed) { - std::cout << "[\033[1;31mFAILED\033[0m]: " << kernel_name - << " failed with disposition: " << result.disposition << "\n"; + if (result.disposition == Disposition::Unknown) { + std::cout << "[\033[1;30mUnknown\033[0m]: " << kernel_name + << " with disposition: " << result.disposition << "\n"; + } + if (result.disposition == Disposition::NotRun) { + std::cout << "[\033[1;33mNotRun\033[0m]: " << kernel_name + << " with disposition: " << result.disposition << "\n"; + } + if (result.disposition == Disposition::Passed) { + std::cout << "[\033[1;32mPassed\033[0m]: " << kernel_name + << " with disposition: " << result.disposition << "\n"; + } + if (result.disposition == Disposition::Incorrect) { + std::cout << "[\033[1;31mIncorrect\033[0m]: " << kernel_name + << " with disposition: " << result.disposition << "\n"; + } + if (result.disposition == Disposition::Failed) { + std::cout << "[\033[1;31mFailed\033[0m]: " << kernel_name + << " with disposition: " << result.disposition << "\n"; + } + if (result.disposition == Disposition::NotVerified) { + std::cout << "[\033[1;34mNotVerified\033[0m]: " << kernel_name + << " with disposition: " << result.disposition << "\n"; + } + if (result.disposition == Disposition::Invalid) { + std::cout << "[\033[1;36mInvalid\033[0m]: " << kernel_name + << " with disposition: " << result.disposition << "\n"; } - delete dispatch_ptr; return result; } @@ -299,7 +325,7 @@ class GemmProfiler { std::vector > results; - results.push_back(execute_cutlass(problem, algorithm)); + results.push_back(execute_cutlass(problem, algorithm)); // cool-down period if (!options.dry_run) { pause(options.sleep_time); @@ -402,10 +428,10 @@ int profile_gemm(TestbenchOutput &output, GemmProfiler perf(output, kernel, cutlass_algo, options, config); if (options.peak_performance) { perf.template peak( - config.problem_range.M, config.problem_range.N, config.problem_range.K); + config.gemm_problem_range.M, config.gemm_problem_range.N, config.gemm_problem_range.K); } else { perf.template schmoo( - config.problem_range.M, config.problem_range.N, config.problem_range.K, config.problem_range.batch_count); + config.gemm_problem_range.M, config.gemm_problem_range.N, config.gemm_problem_range.K, config.gemm_problem_range.batch_count); } } diff --git a/tools/test/perf/gemm/hgemm.cu b/tools/test/perf/gemm/hgemm.cu index 
index 5b47e66dd4..da69c5762c 100644
--- a/tools/test/perf/gemm/hgemm.cu
+++ b/tools/test/perf/gemm/hgemm.cu
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/test/perf/gemm/igemm.cu b/tools/test/perf/gemm/igemm.cu
index 24d721a91a..a86d6778fd 100644
--- a/tools/test/perf/gemm/igemm.cu
+++ b/tools/test/perf/gemm/igemm.cu
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -23,6 +23,8 @@
  *
  **************************************************************************************************/
 
+#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610))
+
 #include "cutlass/gemm/gemm.h"
 #include "cutlass/gemm/igemm_traits.h"
 #include "tools/test/perf/cutlass_perf_test.h"
@@ -36,6 +38,7 @@ namespace perf {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template
 int profile_igemm(TestbenchOutput &output, TestbenchOptions const &options, Config const &config) {
 
   typedef perf::GemmProfiler GemmProfiler;
@@ -91,6 +94,21 @@ int profile_igemm(TestbenchOutput &output, TestbenchOptions const &
     results |= profile_gemm(output, "igemm_tt", options, config);
   }
 
+  return results;
+}
+
+template
+int profile_igemm_32x32x128(TestbenchOutput &output, TestbenchOptions const &options, Config const &config) {
+
+  typedef perf::GemmProfiler GemmProfiler;
+
+  // compute capability check
+  if (!options.compute_capability(6, 1)) {
+    return 0;
+  }
+
+  int results = 0;
+
   {
     typedef cutlass::gemm::IgemmTraits, int,
@@ -138,8 +156,18 @@ int profile_igemm(TestbenchOutput &output, TestbenchOptions const &
   return results;
 }
 
+
+
 struct IgemmRegistrar {
-  IgemmRegistrar() { RegisterGemmProfileFunc(profile_igemm); }
+  IgemmRegistrar()
+  {
+    RegisterGemmProfileFunc(profile_igemm);
+
+#ifdef EXHAUSTIVE_PROF
+    RegisterGemmProfileFunc(profile_igemm_32x32x128);
+#endif // defined EXHAUSTIVE_PROF
+
+  }
 };
 
 volatile IgemmRegistrar _IgemmRegistrar;
@@ -147,3 +175,5 @@ volatile IgemmRegistrar _IgemmRegistrar;
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace perf
+
+#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610))
diff --git a/tools/test/perf/gemm/igemm_splitK.cu b/tools/test/perf/gemm/igemm_splitK.cu
index abec11525b..507d8188e4 100644
--- a/tools/test/perf/gemm/igemm_splitK.cu
+++ b/tools/test/perf/gemm/igemm_splitK.cu
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -23,6 +23,8 @@
  *
  **************************************************************************************************/
 
+#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610))
+
 #include "cutlass/gemm/gemm.h"
 #include "cutlass/gemm/igemm_traits.h"
 #include "cutlass/reduction/batched_reduction_traits.h"
@@ -154,7 +156,6 @@ int profile_igemm_splitkpi_kernel(
     results |= profile_gemm(output, name + "_tt", options, config, algo + "_splitk_pi");
   }
 
-
   return results;
 }
@@ -200,3 +201,5 @@ volatile IgemmSplitKPIRegistrar _IgemmSplitKPIRegistrar;
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace perf
+
+#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610))
diff --git a/tools/test/perf/gemm/sgemm.cu b/tools/test/perf/gemm/sgemm.cu
index c83e874841..78a19748e2 100644
--- a/tools/test/perf/gemm/sgemm.cu
+++ b/tools/test/perf/gemm/sgemm.cu
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -38,8 +38,10 @@ namespace perf {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+
+/// Profile simple gemm kernels
 template
-int profile_sgemm_kernel(
+int profile_simple_sgemm_kernel(
     TestbenchOutput &output,
     TestbenchOptions const &options,
     Config const &config,
@@ -98,6 +100,24 @@
     results |= profile_gemm(output, name + "_tt", options, config, algo);
   }
 
+  return results;
+}
+
+
+
+/// Profile swizzle-raster gemm kernels
+template
+int profile_swizzle_sgemm_kernel(
+    TestbenchOutput &output,
+    TestbenchOptions const &options,
+    Config const &config,
+    std::string const &name,
+    std::string const &algo) {
+
+  typedef perf::GemmProfiler SGemmProfiler;
+
+  int results = 0;
+
   {
     typedef int index;
     typedef cutlass::gemm::SgemmConfig &output, TestbenchOptions const &options, Config const &config) {
   int results = 0;
-  results |= profile_sgemm_kernel >(output, options, config, "sgemm", "128x128");
+  results |= profile_simple_sgemm_kernel >(output, options, config, "sgemm", "128x128");
+
+#ifdef EXHAUSTIVE_PROF
+  results |= profile_swizzle_sgemm_kernel >(output, options, config, "sgemm", "128x128");
+#endif // defined EXHAUSTIVE_PROF
 
   return results;
 }
diff --git a/tools/test/perf/gemm/sgemm_splitK.cu b/tools/test/perf/gemm/sgemm_splitK.cu
index 936d519fad..f238e71897 100644
--- a/tools/test/perf/gemm/sgemm_splitK.cu
+++ b/tools/test/perf/gemm/sgemm_splitK.cu
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -38,7 +38,7 @@ namespace perf { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template int profile_sgemm_splitkpi_kernel( TestbenchOutput &output, TestbenchOptions const &options, @@ -53,7 +53,8 @@ int profile_sgemm_splitkpi_kernel( { /*batched sgemm traits*/ typedef cutlass::gemm::SgemmTraits + cutlass::MatrixLayout::kColumnMajor, OutputTile, + cutlass::gemm::LinearScaling, threadGemmShape> SgemmTraits; /*batched reduction traits*/ typedef cutlass::reduction::BatchedReductionTraits, cutlass::Shape<1, 1, 64>, - cutlass::Shape<1, 1, 2> > + threadReductionShape > BatchedReductionTraits; // create a device gemm @@ -77,7 +78,8 @@ int profile_sgemm_splitkpi_kernel( { /*batched sgemm traits*/ typedef cutlass::gemm::SgemmTraits + cutlass::MatrixLayout::kRowMajor, OutputTile, + cutlass::gemm::LinearScaling, threadGemmShape> SgemmTraits; /*batched reduction traits*/ typedef cutlass::reduction::BatchedReductionTraits, cutlass::Shape<1, 1, 64>, - cutlass::Shape<1, 1, 2> > + threadReductionShape > BatchedReductionTraits; // create a device gemm @@ -101,7 +103,8 @@ int profile_sgemm_splitkpi_kernel( { /*batched sgemm traits*/ typedef cutlass::gemm::SgemmTraits + cutlass::MatrixLayout::kColumnMajor, OutputTile, + cutlass::gemm::LinearScaling, threadGemmShape> SgemmTraits; /*batched reduction traits*/ typedef cutlass::reduction::BatchedReductionTraits, cutlass::Shape<1, 1, 64>, - cutlass::Shape<1, 1, 2> > + threadReductionShape > BatchedReductionTraits; // create a device gemm @@ -125,7 +128,8 @@ int profile_sgemm_splitkpi_kernel( { /*batched sgemm traits*/ typedef cutlass::gemm::SgemmTraits + cutlass::MatrixLayout::kRowMajor, OutputTile, + cutlass::gemm::LinearScaling, threadGemmShape> SgemmTraits; /*batched reduction traits*/ typedef cutlass::reduction::BatchedReductionTraits, cutlass::Shape<1, 1, 64>, - cutlass::Shape<1, 1, 2> > + threadReductionShape > BatchedReductionTraits; // create a device gemm @@ -153,25 +157,143 @@ int profile_sgemm_splitkpi_kernel( /// Profiles all SGEMM tile sizes int profile_sgemm_splitkpi(TestbenchOutput &output, TestbenchOptions const &options, Config const &config) { int results = 0; + /*128x128x8*/ + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 5 >(output, options, config, "sgemm_128x128x8_splitk_pi_split5", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 8 >(output, options, config, "sgemm_128x128x8_splitk_pi_split8", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 10 >(output, options, config, "sgemm_128x128x8_splitk_pi_split10", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 16 >(output, options, config, "sgemm_128x128x8_splitk_pi_split16", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 20 >(output, options, config, "sgemm_128x128x8_splitk_pi_split20", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 24 >(output, options, config, "sgemm_128x128x8_splitk_pi_split24", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 28 >(output, options, config, 
"sgemm_128x128x8_splitk_pi_split28", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 32 >(output, options, config, "sgemm_128x128x8_splitk_pi_split32", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 40 >(output, options, config, "sgemm_128x128x8_splitk_pi_split40", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 48 >(output, options, config, "sgemm_128x128x8_splitk_pi_split48", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 56 >(output, options, config, "sgemm_128x128x8_splitk_pi_split56", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 64 >(output, options, config, "sgemm_128x128x8_splitk_pi_split64", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 72 >(output, options, config, "sgemm_128x128x8_splitk_pi_split72", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 80 >(output, options, config, "sgemm_128x128x8_splitk_pi_split80", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 88 >(output, options, config, "sgemm_128x128x8_splitk_pi_split88", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 96 >(output, options, config, "sgemm_128x128x8_splitk_pi_split96", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 104 >(output, options, config, "sgemm_128x128x8_splitk_pi_split104", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 112 >(output, options, config, "sgemm_128x128x8_splitk_pi_split112", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 120 >(output, options, config, "sgemm_128x128x8_splitk_pi_split120", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 128 >(output, options, config, "sgemm_128x128x8_splitk_pi_split128", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 136 >(output, options, config, "sgemm_128x128x8_splitk_pi_split136", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 144 >(output, options, config, "sgemm_128x128x8_splitk_pi_split144", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 152 >(output, options, config, "sgemm_128x128x8_splitk_pi_split152", "128x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 160 >(output, options, config, "sgemm_128x128x8_splitk_pi_split160", "128x128"); - results |= profile_sgemm_splitkpi_kernel, 32 >(output, options, config, "sgemm_128x128x8_splitk_pi_split32", "128x128"); - +#ifdef EXHAUSTIVE_PROF /*128x64x8*/ - results |= profile_sgemm_splitkpi_kernel, 8 >(output, options, config, "sgemm_128x64x8_splitk_pi_split8", "128x64"); - results |= profile_sgemm_splitkpi_kernel, 16 >(output, options, config, "sgemm_128x64x8_splitk_pi_split16", "128x64"); - results |= profile_sgemm_splitkpi_kernel, 20 >(output, options, config, "sgemm_128x64x8_splitk_pi_split20", "128x64"); - results |= 
profile_sgemm_splitkpi_kernel, 24 >(output, options, config, "sgemm_128x64x8_splitk_pi_split24", "128x64"); - results |= profile_sgemm_splitkpi_kernel, 28 >(output, options, config, "sgemm_128x64x8_splitk_pi_split28", "128x64"); - results |= profile_sgemm_splitkpi_kernel, 32 >(output, options, config, "sgemm_128x64x8_splitk_pi_split32", "128x64"); - results |= profile_sgemm_splitkpi_kernel, 64 >(output, options, config, "sgemm_128x64x8_splitk_pi_split64", "128x64"); - /*128x32x8*/ - results |= profile_sgemm_splitkpi_kernel, 8 >(output, options, config, "sgemm_128x32x8_splitk_pi_split8", "128x32"); - results |= profile_sgemm_splitkpi_kernel, 16 >(output, options, config, "sgemm_128x32x8_splitk_pi_split16", "128x32"); - results |= profile_sgemm_splitkpi_kernel, 20 >(output, options, config, "sgemm_128x32x8_splitk_pi_split20", "128x32"); - results |= profile_sgemm_splitkpi_kernel, 24 >(output, options, config, "sgemm_128x32x8_splitk_pi_split24", "128x32"); - results |= profile_sgemm_splitkpi_kernel, 28 >(output, options, config, "sgemm_128x32x8_splitk_pi_split28", "128x32"); - results |= profile_sgemm_splitkpi_kernel, 32 >(output, options, config, "sgemm_128x32x8_splitk_pi_split32", "128x32"); - results |= profile_sgemm_splitkpi_kernel, 64 >(output, options, config, "sgemm_128x32x8_splitk_pi_split64", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 5 >(output, options, config, "sgemm_128x64x8_splitk_pi_split5", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 8 >(output, options, config, "sgemm_128x64x8_splitk_pi_split8", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 10 >(output, options, config, "sgemm_128x64x8_splitk_pi_split10", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 16 >(output, options, config, "sgemm_128x64x8_splitk_pi_split16", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 20 >(output, options, config, "sgemm_128x64x8_splitk_pi_split20", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 24 >(output, options, config, "sgemm_128x64x8_splitk_pi_split24", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 28 >(output, options, config, "sgemm_128x64x8_splitk_pi_split28", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 32 >(output, options, config, "sgemm_128x64x8_splitk_pi_split32", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 40 >(output, options, config, "sgemm_128x64x8_splitk_pi_split40", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 48 >(output, options, config, "sgemm_128x64x8_splitk_pi_split48", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 56 >(output, options, config, "sgemm_128x64x8_splitk_pi_split56", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 64 >(output, options, config, "sgemm_128x64x8_splitk_pi_split64", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 72 >(output, options, config, "sgemm_128x64x8_splitk_pi_split72", "128x64"); + 
results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 80 >(output, options, config, "sgemm_128x64x8_splitk_pi_split80", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 88 >(output, options, config, "sgemm_128x64x8_splitk_pi_split88", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 96 >(output, options, config, "sgemm_128x64x8_splitk_pi_split96", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 104 >(output, options, config, "sgemm_128x64x8_splitk_pi_split104", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 112 >(output, options, config, "sgemm_128x64x8_splitk_pi_split112", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 120 >(output, options, config, "sgemm_128x64x8_splitk_pi_split120", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 128 >(output, options, config, "sgemm_128x64x8_splitk_pi_split128", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 136 >(output, options, config, "sgemm_128x64x8_splitk_pi_split136", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 144 >(output, options, config, "sgemm_128x64x8_splitk_pi_split144", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 152 >(output, options, config, "sgemm_128x64x8_splitk_pi_split152", "128x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 160 >(output, options, config, "sgemm_128x64x8_splitk_pi_split160", "128x64"); + + /*128x32x8*/ + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 5 >(output, options, config, "sgemm_128x32x8_splitk_pi_split5", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 8 >(output, options, config, "sgemm_128x32x8_splitk_pi_split8", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 10 >(output, options, config, "sgemm_128x32x8_splitk_pi_split10", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 16 >(output, options, config, "sgemm_128x32x8_splitk_pi_split16", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 20 >(output, options, config, "sgemm_128x32x8_splitk_pi_split20", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 24 >(output, options, config, "sgemm_128x32x8_splitk_pi_split24", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 28 >(output, options, config, "sgemm_128x32x8_splitk_pi_split28", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 32 >(output, options, config, "sgemm_128x32x8_splitk_pi_split32", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 40 >(output, options, config, "sgemm_128x32x8_splitk_pi_split40", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 48 >(output, 
options, config, "sgemm_128x32x8_splitk_pi_split48", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 56 >(output, options, config, "sgemm_128x32x8_splitk_pi_split56", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 64 >(output, options, config, "sgemm_128x32x8_splitk_pi_split64", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 72 >(output, options, config, "sgemm_128x32x8_splitk_pi_split72", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 80 >(output, options, config, "sgemm_128x32x8_splitk_pi_split80", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 88 >(output, options, config, "sgemm_128x32x8_splitk_pi_split88", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 96 >(output, options, config, "sgemm_128x32x8_splitk_pi_split96", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 104 >(output, options, config, "sgemm_128x32x8_splitk_pi_split104", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 112 >(output, options, config, "sgemm_128x32x8_splitk_pi_split112", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 120 >(output, options, config, "sgemm_128x32x8_splitk_pi_split120", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 128 >(output, options, config, "sgemm_128x32x8_splitk_pi_split128", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 136 >(output, options, config, "sgemm_128x32x8_splitk_pi_split136", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 144 >(output, options, config, "sgemm_128x32x8_splitk_pi_split144", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 152 >(output, options, config, "sgemm_128x32x8_splitk_pi_split152", "128x32"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 160 >(output, options, config, "sgemm_128x32x8_splitk_pi_split160", "128x32"); + + /*64x128*/ + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 5 >(output, options, config, "sgemm_64x128x8_splitk_pi_split5", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 8 >(output, options, config, "sgemm_64x128x8_splitk_pi_split8", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 10 >(output, options, config, "sgemm_64x128x8_splitk_pi_split10", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 16 >(output, options, config, "sgemm_64x128x8_splitk_pi_split16", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 20 >(output, options, config, "sgemm_64x128x8_splitk_pi_split20", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 24 >(output, options, config, "sgemm_64x128x8_splitk_pi_split24", "64x128"); + results |= profile_sgemm_splitkpi_kernel, 
cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 28 >(output, options, config, "sgemm_64x128x8_splitk_pi_split28", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 32 >(output, options, config, "sgemm_64x128x8_splitk_pi_split32", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 40 >(output, options, config, "sgemm_64x128x8_splitk_pi_split40", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 48 >(output, options, config, "sgemm_64x128x8_splitk_pi_split48", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 56 >(output, options, config, "sgemm_64x128x8_splitk_pi_split56", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 64 >(output, options, config, "sgemm_64x128x8_splitk_pi_split64", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 72 >(output, options, config, "sgemm_64x128x8_splitk_pi_split72", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 80 >(output, options, config, "sgemm_64x128x8_splitk_pi_split80", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 88 >(output, options, config, "sgemm_64x128x8_splitk_pi_split88", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 96 >(output, options, config, "sgemm_64x128x8_splitk_pi_split96", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 104 >(output, options, config, "sgemm_64x128x8_splitk_pi_split104", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 112 >(output, options, config, "sgemm_64x128x8_splitk_pi_split112", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 120 >(output, options, config, "sgemm_64x128x8_splitk_pi_split120", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 128 >(output, options, config, "sgemm_64x128x8_splitk_pi_split128", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 136 >(output, options, config, "sgemm_64x128x8_splitk_pi_split136", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 144 >(output, options, config, "sgemm_64x128x8_splitk_pi_split144", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 152 >(output, options, config, "sgemm_64x128x8_splitk_pi_split152", "64x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 1>, 160 >(output, options, config, "sgemm_64x128x8_splitk_pi_split160", "64x128"); + + /*32x128*/ + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 5 >(output, options, config, "sgemm_32x128x8_splitk_pi_split5", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 8 >(output, options, config, "sgemm_32x128x8_splitk_pi_split8", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 10 >(output, options, config, 
"sgemm_32x128x8_splitk_pi_split10", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 16 >(output, options, config, "sgemm_32x128x8_splitk_pi_split16", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 20 >(output, options, config, "sgemm_32x128x8_splitk_pi_split20", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 24 >(output, options, config, "sgemm_32x128x8_splitk_pi_split24", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 28 >(output, options, config, "sgemm_32x128x8_splitk_pi_split28", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 32 >(output, options, config, "sgemm_32x128x8_splitk_pi_split32", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 40 >(output, options, config, "sgemm_32x128x8_splitk_pi_split40", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 48 >(output, options, config, "sgemm_32x128x8_splitk_pi_split48", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 56 >(output, options, config, "sgemm_32x128x8_splitk_pi_split56", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 64 >(output, options, config, "sgemm_32x128x8_splitk_pi_split64", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 72 >(output, options, config, "sgemm_32x128x8_splitk_pi_split72", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 2>, 80 >(output, options, config, "sgemm_32x128x8_splitk_pi_split80", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 1>, 88 >(output, options, config, "sgemm_32x128x8_splitk_pi_split88", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 1>, 96 >(output, options, config, "sgemm_32x128x8_splitk_pi_split96", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 1>, 104 >(output, options, config, "sgemm_32x128x8_splitk_pi_split104", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 1>, 112 >(output, options, config, "sgemm_32x128x8_splitk_pi_split112", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 1>, 120 >(output, options, config, "sgemm_32x128x8_splitk_pi_split120", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 1>, 128 >(output, options, config, "sgemm_32x128x8_splitk_pi_split128", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 1>, 136 >(output, options, config, "sgemm_32x128x8_splitk_pi_split136", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 1>, 144 >(output, options, config, "sgemm_32x128x8_splitk_pi_split144", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, cutlass::Shape<1, 1, 1>, 152 >(output, options, config, "sgemm_32x128x8_splitk_pi_split152", "32x128"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 4>, 
cutlass::Shape<1, 1, 1>, 160 >(output, options, config, "sgemm_32x128x8_splitk_pi_split160", "32x128"); + + /*64x64*/ + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 5 >(output, options, config, "sgemm_64x64x8_splitk_pi_split5", "64x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 8 >(output, options, config, "sgemm_64x64x8_splitk_pi_split8", "64x64"); + results |= profile_sgemm_splitkpi_kernel, cutlass::Shape<8, 8, 8>, cutlass::Shape<1, 1, 2>, 10 >(output, options, config, "sgemm_64x64x8_splitk_pi_split10", "64x64"); + +#endif //#ifdef EXHAUSTIVE_PROF return results; } diff --git a/tools/test/perf/gemm/volta884_gemm.cu b/tools/test/perf/gemm/volta884_gemm.cu new file mode 100644 index 0000000000..b00f5d2dfd --- /dev/null +++ b/tools/test/perf/gemm/volta884_gemm.cu @@ -0,0 +1,183 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "tools/test/perf/gemm/cutlass_volta884_dispatch.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace perf { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +int profile_volta884_gemm_kernel( + TestbenchOutput &output, + TestbenchOptions const &options, + Config const &config, + std::string const &name, + std::string const &algo) { + + int results = 0; + + // compute capability check + if (!options.compute_capability(7, 0)) { + return 0; + } + + typedef typename cutlass::TypeTraits::device_type AccumDevType; + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + typedef perf::GemmProfiler< + cutlass::half_t, + cutlass::half_t, + AccumHostType, + AccumHostType, + AccumHostType> GemmProfiler; + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2 + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_nn", options, config, algo); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2 + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_nt", options, config, algo); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2 + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_tn", options, config, algo); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + cutlass::gemm::IdentityBlockSwizzle, + true + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_tt", options, config, algo); + } + + #endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) + + return results; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int profile_volta884_gemm(TestbenchOutput &output, TestbenchOptions const &options, Config const &config) { + int results = 0; + + + + + results |= profile_volta884_gemm_kernel, float >(output, options, config, "s884gemm", "128x128"); + + results |= profile_volta884_gemm_kernel, cutlass::half_t >(output, options, config, "h884gemm", "128x128"); + +#ifdef EXHAUSTIVE_PROF + results |= profile_volta884_gemm_kernel, float >(output, options, config, "s884gemm_256x128", "256x128"); + + results |= profile_volta884_gemm_kernel, float >(output, options, config, "s884gemm_128x64", "128x64"); + + 
results |= profile_volta884_gemm_kernel, float >(output, options, config, "s884gemm_64x64", "64x64"); + + results |= profile_volta884_gemm_kernel, cutlass::half_t >(output, options, config, "h884gemm_256x128", "256x128"); + + results |= profile_volta884_gemm_kernel, cutlass::half_t >(output, options, config, "h884gemm_128x64", "128x64"); + + results |= profile_volta884_gemm_kernel, cutlass::half_t >(output, options, config, "h884gemm_64x64", "64x64"); +#endif // defined EXHAUSTIVE_PROF + + return results; +} + +struct Volta884GemmRegistrar { + Volta884GemmRegistrar() { RegisterGemmProfileFunc(profile_volta884_gemm); } +}; + +volatile Volta884GemmRegistrar _Volta884GemmRegistrar; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace perf + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/perf/gemm/volta884_gemm_cta_rasterization_nn.cu b/tools/test/perf/gemm/volta884_gemm_cta_rasterization_nn.cu new file mode 100644 index 0000000000..f3fb6b40b6 --- /dev/null +++ b/tools/test/perf/gemm/volta884_gemm_cta_rasterization_nn.cu @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "tools/test/perf/gemm/cutlass_volta884_dispatch.h" + +#ifdef EXHAUSTIVE_PROF + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace perf { +template +int profile_volta884_gemm_cta_rasterization_nn_kernel( + TestbenchOutput &output, + TestbenchOptions const &options, + Config const &config, + std::string const &name, + std::string const &algo) { + + int results = 0; + + // compute capability check + if (!options.compute_capability(7, 0)) { + return 0; + } + + typedef typename cutlass::TypeTraits::device_type AccumDevType; + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + typedef perf::GemmProfiler< + cutlass::half_t, + cutlass::half_t, + AccumHostType, + AccumHostType, + AccumHostType> GemmProfiler; + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_1_one_nn", options, config, algo + "_row_1_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_1_B_nn", options, config, algo + "_row_1_B"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_2_one_nn", options, config, algo + "_row_2_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_2_B_nn", options, config, algo + "_row_2_B"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + typedef 
Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_1_one_nn", options, config, algo + "_col_1_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_1_B_nn", options, config, algo + "_col_1_B"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_2_one_nn", options, config, algo + "_col_2_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_2_B_nn", options, config, algo + "_col_2_B"); + } +#endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) + + return results; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int profile_volta884_gemm_cta_rasterization_nn(TestbenchOutput &output, TestbenchOptions const &options, Config const &config) { + int results = 0; + + results |= profile_volta884_gemm_cta_rasterization_nn_kernel, float >(output, options, config, "s884gemm", "128x128"); + + results |= profile_volta884_gemm_cta_rasterization_nn_kernel, float >(output, options, config, "s884gemm_256x128", "256x128"); + + return results; +} + +struct Volta884GemmCTARasterizationNNRegistrar { + Volta884GemmCTARasterizationNNRegistrar() { RegisterGemmProfileFunc(profile_volta884_gemm_cta_rasterization_nn); } +}; + +volatile Volta884GemmCTARasterizationNNRegistrar _Volta884CTARasterizationNNGemmRegistrar; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace perf +#endif // if defined(EXHAUSTIVE_PROF) + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/perf/gemm/volta884_gemm_cta_rasterization_nt.cu b/tools/test/perf/gemm/volta884_gemm_cta_rasterization_nt.cu new file mode 100644 index 0000000000..502c9fcf64 --- /dev/null +++ b/tools/test/perf/gemm/volta884_gemm_cta_rasterization_nt.cu @@ -0,0 +1,247 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "tools/test/perf/gemm/cutlass_volta884_dispatch.h" + +#ifdef EXHAUSTIVE_PROF + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace perf { +template +int profile_volta884_gemm_cta_rasterization_nt_kernel( + TestbenchOutput &output, + TestbenchOptions const &options, + Config const &config, + std::string const &name, + std::string const &algo) { + + int results = 0; + + // compute capability check + if (!options.compute_capability(7, 0)) { + return 0; + } + + typedef typename cutlass::TypeTraits::device_type AccumDevType; + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + typedef perf::GemmProfiler< + cutlass::half_t, + cutlass::half_t, + AccumHostType, + AccumHostType, + AccumHostType> GemmProfiler; + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_1_one_nt", options, config, algo + "_row_1_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + 
"_row_1_B_nt", options, config, algo + "_row_1_B"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_2_one_nt", options, config, algo + "_row_2_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_2_B_nt", options, config, algo + "_row_2_B"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_1_one_nt", options, config, algo + "_col_1_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_1_B_nt", options, config, algo + "_col_1_B"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_2_one_nt", options, config, algo + "_col_2_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_2_B_nt", options, config, algo + "_col_2_B"); + } +#endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) + + return results; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int profile_volta884_gemm_cta_rasterization_nt(TestbenchOutput &output, TestbenchOptions const &options, Config const &config) { + int results = 0; + + results |= profile_volta884_gemm_cta_rasterization_nt_kernel, float 
>(output, options, config, "s884gemm", "128x128"); + + results |= profile_volta884_gemm_cta_rasterization_nt_kernel, float >(output, options, config, "s884gemm_256x128", "256x128"); + + results |= profile_volta884_gemm_cta_rasterization_nt_kernel, float >(output, options, config, "s884gemm_128x64", "128x64"); + + results |= profile_volta884_gemm_cta_rasterization_nt_kernel, float >(output, options, config, "s884gemm_64x64", "64x64"); + + return results; +} + +struct Volta884GemmCTARasterizationNTRegistrar { + Volta884GemmCTARasterizationNTRegistrar() { RegisterGemmProfileFunc(profile_volta884_gemm_cta_rasterization_nt); } +}; + +volatile Volta884GemmCTARasterizationNTRegistrar _Volta884CTARasterizationNTGemmRegistrar; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace perf + +#endif // if defined(EXHAUSTIVE_PROF) + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/perf/gemm/volta884_gemm_cta_rasterization_tn.cu b/tools/test/perf/gemm/volta884_gemm_cta_rasterization_tn.cu new file mode 100644 index 0000000000..7fe24c501a --- /dev/null +++ b/tools/test/perf/gemm/volta884_gemm_cta_rasterization_tn.cu @@ -0,0 +1,247 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "tools/test/perf/gemm/cutlass_volta884_dispatch.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace perf { +template +int profile_volta884_gemm_cta_rasterization_tn_kernel( + TestbenchOutput &output, + TestbenchOptions const &options, + Config const &config, + std::string const &name, + std::string const &algo) { + + int results = 0; + + // compute capability check + if (!options.compute_capability(7, 0)) { + return 0; + } + + typedef typename cutlass::TypeTraits::device_type AccumDevType; + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + typedef perf::GemmProfiler< + cutlass::half_t, + cutlass::half_t, + AccumHostType, + AccumHostType, + AccumHostType> GemmProfiler; + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_1_one_tn", options, config, algo + "_row_1_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_1_B_tn", options, config, algo + "_row_1_B"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_2_one_tn", options, config, algo + "_row_2_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_2_B_tn", options, config, algo + "_row_2_B"); + } + +#ifdef EXHAUSTIVE_PROF + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + 
results |= profile_gemm(output, name + "_col_1_one_tn", options, config, algo + "_col_1_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_1_B_tn", options, config, algo + "_col_1_B"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_2_one_tn", options, config, algo + "_col_2_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_2_B_tn", options, config, algo + "_col_2_B"); + } +#endif // if defined(EXHAUSTIVE_PROF) +#endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) + + return results; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int profile_volta884_gemm_cta_rasterization_tn(TestbenchOutput &output, TestbenchOptions const &options, Config const &config) { + int results = 0; + + results |= profile_volta884_gemm_cta_rasterization_tn_kernel, float >(output, options, config, "s884gemm", "128x128"); + +#ifdef EXHAUSTIVE_PROF + results |= profile_volta884_gemm_cta_rasterization_tn_kernel, float >(output, options, config, "s884gemm_256x128", "256x128"); + + results |= profile_volta884_gemm_cta_rasterization_tn_kernel, float >(output, options, config, "s884gemm_128x64", "128x64"); + + results |= profile_volta884_gemm_cta_rasterization_tn_kernel, float >(output, options, config, "s884gemm_64x64", "64x64"); +#endif // if defined( EXHAUSTIVE_PROF) + + return results; +} + +struct Volta884GemmCTARasterizationTNRegistrar { + Volta884GemmCTARasterizationTNRegistrar() { RegisterGemmProfileFunc(profile_volta884_gemm_cta_rasterization_tn); } +}; + +volatile Volta884GemmCTARasterizationTNRegistrar _Volta884CTARasterizationTNGemmRegistrar; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace perf + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/perf/gemm/volta884_gemm_cta_rasterization_tt.cu b/tools/test/perf/gemm/volta884_gemm_cta_rasterization_tt.cu new file mode 100644 index 0000000000..cf853e558d --- /dev/null +++ b/tools/test/perf/gemm/volta884_gemm_cta_rasterization_tt.cu @@ -0,0 +1,255 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. 
All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "tools/test/perf/gemm/cutlass_volta884_dispatch.h" + +#ifdef EXHAUSTIVE_PROF + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace perf { +template +int profile_volta884_gemm_cta_rasterization_tt_kernel( + TestbenchOutput &output, + TestbenchOptions const &options, + Config const &config, + std::string const &name, + std::string const &algo) { + + int results = 0; + + // compute capability check + if (!options.compute_capability(7, 0)) { + return 0; + } + + typedef typename cutlass::TypeTraits::device_type AccumDevType; + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + typedef perf::GemmProfiler< + cutlass::half_t, + cutlass::half_t, + AccumHostType, + AccumHostType, + AccumHostType> GemmProfiler; + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection>, + true + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_1_one_tt", options, config, algo + "_row_1_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon>, + true + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= 
profile_gemm(output, name + "_row_1_B_tt", options, config, algo + "_row_1_B"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection>, + true + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_2_one_tt", options, config, algo + "_row_2_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon>, + true + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_row_2_B_tt", options, config, algo + "_row_2_B"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection>, + true + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_1_one_tt", options, config, algo + "_col_1_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon>, + true + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_1_B_tt", options, config, algo + "_col_1_B"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection>, + true + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_2_one_tt", options, config, algo + "_col_2_one"); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon>, + true + > GemmTraits; + + typedef Volta884GemmDispatch Dispatch; + + results |= profile_gemm(output, name + "_col_2_B_tt", options, config, algo + "_col_2_B"); + } +#endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) + + return results; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int profile_volta884_gemm_cta_rasterization_tt(TestbenchOutput &output, TestbenchOptions const &options, Config const &config) { + int results = 0; + + results |= 
profile_volta884_gemm_cta_rasterization_tt_kernel, float >(output, options, config, "s884gemm", "128x128"); + + results |= profile_volta884_gemm_cta_rasterization_tt_kernel, float >(output, options, config, "s884gemm_256x128", "256x128"); + + results |= profile_volta884_gemm_cta_rasterization_tt_kernel, float >(output, options, config, "s884gemm_128x64", "128x64"); + + results |= profile_volta884_gemm_cta_rasterization_tt_kernel, float >(output, options, config, "s884gemm_64x64", "64x64"); + + return results; +} + +struct Volta884GemmCTARasterizationTTRegistrar { + Volta884GemmCTARasterizationTTRegistrar() { RegisterGemmProfileFunc(profile_volta884_gemm_cta_rasterization_tt); } +}; + +volatile Volta884GemmCTARasterizationTTRegistrar _Volta884CTARasterizationTTGemmRegistrar; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace perf + +#endif // if defined(EXHAUSTIVE_PROF) + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/perf/gemm/volta884_gemm_splitK.cu b/tools/test/perf/gemm/volta884_gemm_splitK.cu new file mode 100644 index 0000000000..59d6f9d763 --- /dev/null +++ b/tools/test/perf/gemm/volta884_gemm_splitK.cu @@ -0,0 +1,324 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "tools/test/perf/gemm/cutlass_volta884_dispatch_splitK_PI.h" +#include "cutlass/reduction/batched_reduction_traits.h" +#include "cutlass/gemm/device_gemm_traits.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace perf { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +int profile_volta884_gemm_splitkpi_kernel( + TestbenchOutput &output, + TestbenchOptions const &options, + Config const &config, + std::string const &name, + std::string const &algo) { + + int results = 0; + + // compute capability check + if (!options.compute_capability(7, 0)) { + return 0; + } + + typedef typename cutlass::TypeTraits::device_type AccumDevType; + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + typedef perf::GemmProfiler< + cutlass::half_t, + cutlass::half_t, + cutlass::half_t, + AccumHostType, + AccumHostType> GemmProfiler; + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2 + > GemmTraits; + + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + threadReductionShape > + BatchedReductionTraits; + + // create a device gemm + typedef typename cutlass::gemm::SplitkPIGemmTraits deviceGemmTraits; + typedef Volta884GemmDispatchSplitKPI Dispatch; + + results |= profile_gemm(output, name + "_nn", options, config, algo); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2 + > GemmTraits; + + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + threadReductionShape > + BatchedReductionTraits; + + // create a device gemm + typedef typename cutlass::gemm::SplitkPIGemmTraits deviceGemmTraits; + typedef Volta884GemmDispatchSplitKPI Dispatch; + + results |= profile_gemm(output, name + "_nt", options, config, algo); + } + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2 + > GemmTraits; + + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + threadReductionShape > + BatchedReductionTraits; + + // create a device gemm + typedef typename cutlass::gemm::SplitkPIGemmTraits deviceGemmTraits; + typedef Volta884GemmDispatchSplitKPI Dispatch; + + results |= profile_gemm(output, name + "_tn", options, config, algo); +} + + { + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + OutputTile, + cutlass::Shape<32, 64, 64>, + AccumDevType, + AccumDevType, + AccumDevType, + 2 + > GemmTraits; + + /*batched reduction traits*/ + typedef 
cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + threadReductionShape > + BatchedReductionTraits; + + // create a device gemm + typedef typename cutlass::gemm::SplitkPIGemmTraits deviceGemmTraits; + typedef Volta884GemmDispatchSplitKPI Dispatch; + + results |= profile_gemm(output, name + "_tt", options, config, algo); + } + + #endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) + + return results; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int profile_volta884_gemm_splitkpi(TestbenchOutput &output, TestbenchOptions const &options, Config const &config) { + int results = 0; + + //results |= profile_volta884_gemm_kernel, float >(output, options, config, "s884gemm", "128x128"); + + // half accum + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 5 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits5", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 8 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits8", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 10 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits10", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 16 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits16", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 20 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits20", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 24 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits24", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 28 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits28", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 32 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits32", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 40 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits40", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 48 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits48", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 56 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits56", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 64 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits64", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 72 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits72", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 2>, 80 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits80", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 1>, 88 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits88", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 1>, 
96 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits96", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 1>, 104 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits104", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 1>, 112 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits112", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 1>, 120 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits120", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 1>, 128 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits128", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 1>, 136 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits136", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 1>, 144 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits144", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 1>, 152 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits152", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, cutlass::half_t, cutlass::Shape<1, 1, 1>, 160 >(output, options, config, "h884gemm_128x128x32_splitk_pi_splits160", "128x128"); + + // float accum + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 5 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits5", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 8 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits8", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 10 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits10", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 16 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits16", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 20 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits20", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 24 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits24", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 28 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits28", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 32 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits32", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 40 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits40", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 48 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits48", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 56 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits56", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 64 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits64", "128x128"); 
+ results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 72 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits72", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 80 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits80", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 88 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits88", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 96 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits96", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 104 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits104", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 112 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits112", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 120 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits120", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 128 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits128", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 136 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits136", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 144 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits144", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 152 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits152", "128x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 160 >(output, options, config, "s884gemm_128x128x32_splitk_pi_splits160", "128x128"); + +#ifdef EXHAUSTIVE_PROF + // float accum 128x64 + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 5 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits5", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 8 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits8", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 10 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits10", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 16 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits16", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 20 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits20", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 24 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits24", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 28 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits28", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 32 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits32", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 40 >(output, options, config, 
"s884gemm_128x64x32_splitk_pi_splits40", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 48 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits48", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 56 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits56", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 64 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits64", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 72 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits72", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 80 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits80", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 88 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits88", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 96 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits96", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 104 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits104", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 112 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits112", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 120 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits120", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 128 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits128", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 136 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits136", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 144 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits144", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 152 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits152", "128x64"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 160 >(output, options, config, "s884gemm_128x64x32_splitk_pi_splits160", "128x64"); + + // float accum 64x128 + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 5 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits5", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 8 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits8", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 10 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits10", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 16 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits16", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 20 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits20", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 24 >(output, options, config, 
"s884gemm_64x128x32_splitk_pi_splits24", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 28 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits28", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 32 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits32", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 40 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits40", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 48 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits48", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 56 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits56", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 64 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits64", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 72 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits72", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 2>, 80 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits80", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 88 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits88", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 96 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits96", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 104 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits104", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 112 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits112", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 120 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits120", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 128 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits128", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 136 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits136", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 144 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits144", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 152 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits152", "64x128"); + results |= profile_volta884_gemm_splitkpi_kernel, float, cutlass::Shape<1, 1, 1>, 160 >(output, options, config, "s884gemm_64x128x32_splitk_pi_splits160", "64x128"); +#endif //#ifdef EXHAUSTIVE_PROF + return results; +} + +struct Volta884GemmSplitKPIRegistrar { + Volta884GemmSplitKPIRegistrar() { RegisterGemmProfileFunc(profile_volta884_gemm_splitkpi); } +}; + +volatile Volta884GemmSplitKPIRegistrar _Volta884GemmSplitKPIRegistrar; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace perf + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/perf/gemm/wmma_binary_gemm.cu b/tools/test/perf/gemm/wmma_binary_gemm.cu index 81ee4fab63..6083c4c5a0 100644 --- a/tools/test/perf/gemm/wmma_binary_gemm.cu +++ b/tools/test/perf/gemm/wmma_binary_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -30,7 +30,7 @@ #include "cutlass/wmma_matrix.h" #ifdef CUTLASS_USE_WMMA_API #ifdef CUTLASS_USE_SUBBYTE_WMMA - +#pragma warning( disable : 4503) //////////////////////////////////////////////////////////////////////////////////////////////////// #include "cutlass/gemm/gemm.h" diff --git a/tools/test/perf/gemm/wmma_gemm.cu b/tools/test/perf/gemm/wmma_gemm.cu index 15c0e6eb7d..60a2c98ab1 100644 --- a/tools/test/perf/gemm/wmma_gemm.cu +++ b/tools/test/perf/gemm/wmma_gemm.cu @@ -1,27 +1,27 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. -* -* Redistribution and use in source and binary forms, with or without modification, are permitted -* provided that the following conditions are met: -* * Redistributions of source code must retain the above copyright notice, this list of -* conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, this list of -* conditions and the following disclaimer in the documentation and/or other materials -* provided with the distribution. -* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used -* to endorse or promote products derived from this software without specific prior written -* permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -* -**************************************************************************************************/ + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ #include "cutlass/wmma_matrix.h" #ifdef CUTLASS_USE_WMMA_API @@ -49,7 +49,11 @@ struct WmmaGemmDispatch { /// Indicate warp-level GEMM static bool const kThreadMultiplyAdd = false; + #if CUTLASS_ENABLE_CUBLAS static bool const kRunCuBLAS = true; + #else + static bool const kRunCuBLAS = false; + #endif static cutlass::MatrixLayout::Kind const kLayoutA = Traits::kLayoutA; static cutlass::MatrixLayout::Kind const kLayoutB = Traits::kLayoutB; @@ -138,8 +142,8 @@ int profile_wmma_gemm_f32(TestbenchOutput &output, TestbenchOptions { typedef cutlass::gemm::WmmaGemmTraits - WmmaGemmTraits; + cutlass::MatrixLayout::kRowMajor> + WmmaGemmTraits; typedef WmmaGemmDispatch Dispatch; @@ -148,8 +152,8 @@ int profile_wmma_gemm_f32(TestbenchOutput &output, TestbenchOptions { typedef cutlass::gemm::WmmaGemmTraits - WmmaGemmTraits; + cutlass::MatrixLayout::kColumnMajor> + WmmaGemmTraits; typedef WmmaGemmDispatch Dispatch; @@ -158,7 +162,7 @@ int profile_wmma_gemm_f32(TestbenchOutput &output, TestbenchOptions { typedef cutlass::gemm::WmmaGemmTraits + cutlass::MatrixLayout::kColumnMajor> WmmaGemmTraits; typedef WmmaGemmDispatch Dispatch; @@ -168,7 +172,7 @@ int profile_wmma_gemm_f32(TestbenchOutput &output, TestbenchOptions { typedef cutlass::gemm::WmmaGemmTraits + cutlass::MatrixLayout::kRowMajor> WmmaGemmTraits; typedef WmmaGemmDispatch Dispatch; @@ -183,9 +187,9 @@ int profile_wmma_gemm_f32(TestbenchOutput &output, TestbenchOptions template int profile_wmma_gemm_f16( - TestbenchOutput &output, - TestbenchOptions const &options, - Config const &config) { + TestbenchOutput &output, + TestbenchOptions const &options, + Config const &config) { typedef perf::GemmProfiler< cutlass::half_t, @@ -278,9 +282,9 @@ int profile_wmma_gemm_f16( template int profile_wmma_4_gemm_f16( - TestbenchOutput &output, - TestbenchOptions const &options, - Config const &config) { + TestbenchOutput &output, + TestbenchOptions const &options, + Config const &config) { typedef perf::GemmProfiler< cutlass::half_t, @@ -547,11 +551,11 @@ struct WmmaGemmRegistrar { RegisterGemmProfileFunc(profile_wmma_gemm_f32); RegisterGemmProfileFunc(profile_wmma_gemm_f16); - //#ifdef EXHAUSTIVE_PROF +//#ifdef EXHAUSTIVE_PROF RegisterGemmProfileFunc(profile_wmma_4_gemm_f16); //fp32 accum with fp16 input and output RegisterGemmProfileFunc(profile_wmma_4_fp16_sgemm_fp16); - //#endif // defined EXHAUSTIVE_PROF +//#endif // defined EXHAUSTIVE_PROF } }; @@ -564,4 +568,3 @@ volatile WmmaGemmRegistrar 
_WmmaGemmRegistrar; //////////////////////////////////////////////////////////////////////////////////////////////////// #endif // defined CUTLASS_USE_WMMA_API - diff --git a/tools/test/perf/gemm/wmma_integer_gemm.cu b/tools/test/perf/gemm/wmma_integer_gemm.cu index 848b28eaed..854d170592 100644 --- a/tools/test/perf/gemm/wmma_integer_gemm.cu +++ b/tools/test/perf/gemm/wmma_integer_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -29,8 +29,8 @@ #include "cutlass/wmma_matrix.h" #ifdef CUTLASS_USE_WMMA_API -#ifdef CUTLASS_USE_SUBBYTE_WMMA - +#ifdef CUTLASS_USE_INT_WMMA +#pragma warning( disable : 4503) #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/wmma_gemm_traits.h" #include "tools/test/perf/gemm/cutlass_dispatch.h" @@ -92,6 +92,7 @@ struct WmmaIntegerGemmDispatch { //////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef CUTLASS_USE_SUBBYTE_WMMA template struct WmmaIntegerGemmDispatch, @@ -209,6 +210,7 @@ struct WmmaIntegerGemmDispatch &output, TestbenchOpt int results = 0; // compute capability check - if (!options.compute_capability(7, 5)) { + if (!options.compute_capability(7, 2)) { return 0; } @@ -398,6 +400,7 @@ int profile_wmma_integer_gemm(TestbenchOutput &output, TestbenchOpt return 0; } +#ifdef CUTLASS_USE_SUBBYTE_WMMA { typedef cutlass::gemm::WmmaGemmTraits &output, TestbenchOpt results |= profile_gemm(output, "wmma_integer_gemm_u4_tn", options, config); } +#endif //ifdef CUTLASS_USE_SUBBYTE_WMMA return results; } @@ -461,7 +465,7 @@ int profile_wmma_integer_gemm(TestbenchOutput &output, TestbenchOpt //////////////////////////////////////////////////////////////////////////////////////////////////// -#else // ! CUTLASS_USE_SUBBYTE_WMMA +#else // ! CUTLASS_USE_INT_WMMA namespace perf { diff --git a/tools/test/perf/performance_result.h b/tools/test/perf/performance_result.h index 4906788ba5..99dc40f8c9 100644 --- a/tools/test/perf/performance_result.h +++ b/tools/test/perf/performance_result.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -28,7 +28,6 @@ #include "cutlass/matrix_traits.h" #include "tools/util/command_line.h" #include "tools/test/perf/provider.h" - //////////////////////////////////////////////////////////////////////////////////////////////////// namespace perf { @@ -175,7 +174,6 @@ inline std::ostream &operator<<(std::ostream &out, GemmProblem const &problem) { return out; } - //////////////////////////////////////////////////////////////////////////////////////////////////// /// Result object diff --git a/tools/test/perf/provider.h b/tools/test/perf/provider.h index 544ee3fbb1..06569b44b9 100644 --- a/tools/test/perf/provider.h +++ b/tools/test/perf/provider.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/perf/testbench_configs.h b/tools/test/perf/testbench_configs.h index a7036aba86..1147ba19ba 100644 --- a/tools/test/perf/testbench_configs.h +++ b/tools/test/perf/testbench_configs.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -45,9 +45,9 @@ struct Config { // kernel to run std::vector kernels; - /// Range of problem sizes - GemmProblemRange problem_range; - + /// Range of problem sizes for GEMM + GemmProblemRange gemm_problem_range; + // Reference GFLOPs double gflops_ref; @@ -121,8 +121,12 @@ struct TestbenchConfigs { if (item.compare("Kernel") == 0) kernel_idx = idx; if (item.compare("Beta") == 0) beta_idx = idx; if (item.compare("M") == 0) m_idx = idx; - if (item.compare("N") == 0) n_idx = idx; - if (item.compare("K") == 0) k_idx = idx; + if (item.compare("N") == 0) { + n_idx = idx; + } + if (item.compare("K") == 0) { + k_idx = idx; + } if (item.compare("GFLOPs") == 0) gflops_idx = idx; if (item.compare("Runtime") == 0) runtime_idx = idx; if (item.compare("SOL") == 0) peak_throughput_idx = idx; @@ -150,9 +154,9 @@ struct TestbenchConfigs { config.alpha = options.alpha; config.beta = strtod(tokens[beta_idx].c_str(), NULL); config.kernels.push_back(tokens[kernel_idx]); - config.problem_range.M = Range((int)strtol(tokens[m_idx].c_str(), NULL, 10)); - config.problem_range.N = Range((int)strtol(tokens[n_idx].c_str(), NULL, 10)); - config.problem_range.K = Range((int)strtol(tokens[k_idx].c_str(), NULL, 10)); + config.gemm_problem_range.M = Range(tokens[m_idx]); + config.gemm_problem_range.N = Range(tokens[n_idx]); + config.gemm_problem_range.K = Range(tokens[k_idx]); config.gflops_ref = strtod(tokens[gflops_idx].c_str(), NULL); config.runtime_ref = strtod(tokens[runtime_idx].c_str(), NULL); config.peak_throughput_ref = strtod(tokens[peak_throughput_idx].c_str(), NULL); @@ -172,7 +176,7 @@ struct TestbenchConfigs { for (int i = 0; i < options.kernels.size(); i++) { config.kernels.push_back(options.kernels[i]); } - 
config.problem_range = options.problem_range; + config.gemm_problem_range = options.gemm_problem_range; configs.push_back(config); } diff --git a/tools/test/perf/testbench_options.h b/tools/test/perf/testbench_options.h index 4b1fd899f4..d7ec89b9a4 100644 --- a/tools/test/perf/testbench_options.h +++ b/tools/test/perf/testbench_options.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -66,6 +66,11 @@ struct Range { Range(int _start, int _end, int _increment = 1, Operator _op = Add) : start(_start), end(_end), increment(_increment), increment_op(_op) {} + Range(std::string _start) : increment(1), increment_op(Add) { + start = end = (int)strtol(_start.c_str(), NULL, 10); + } + + /// Returns the next item in series int next(int val) const { switch (increment_op) { @@ -162,8 +167,6 @@ struct GemmProblemRange { } }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - /// Defines a vector of string pairs typedef std::vector > KeyValueVector; @@ -391,8 +394,8 @@ struct TestbenchOptions { /// Scalar value for GEMM double beta; - /// Range of problem sizes - GemmProblemRange problem_range; + /// Range of GEMM problem sizes + GemmProblemRange gemm_problem_range; /// If true, kernels are not executed, and no sleep waits are inserted bool dry_run; @@ -418,7 +421,7 @@ struct TestbenchOptions { : initial_distribution(args), execution_mode(ExecutionMode::Profile), save_workspace_mode(WorkspaceSaveMode::Never), - problem_range(args), + gemm_problem_range(args), dry_run(false), sleep_time(1) { @@ -473,6 +476,8 @@ struct TestbenchOptions { "igemm", "wmma_gemm", "wmma_gemm_f16", + "s884gemm", + "h884gemm", "wmma_binary_gemm", "wmma_integer_gemm", 0 @@ -480,7 +485,8 @@ struct TestbenchOptions { char const *layouts[] = {"nn", "nt", "tn", "tt", 0}; for (int i = 0; gemms[i]; ++i) { for (int j = 0; layouts[j]; ++j) { - if ((std::string(gemms[i]).compare("wmma_binary_gemm") == 0 || + if (( + std::string(gemms[i]).compare("wmma_binary_gemm") == 0 || std::string(gemms[i]).compare("wmma_integer_gemm") == 0) && std::string(layouts[j]).compare("tn") != 0) { continue; @@ -488,7 +494,7 @@ struct TestbenchOptions { kernels.push_back(std::string(gemms[i]) + "_" + layouts[j]); } } - + } } @@ -596,15 +602,14 @@ struct TestbenchOptions { << " Height of GEMM problem (number of rows of C). May specify a range with optional " "step size.\n" - << " --n=[:max width[:step]] " + << " --n=[:max width[:step]] (GEMM-specific)" << " Width of GEMM problem (number of columns of C). May specify a range with optional " "step size.\n" - << " --k=[:max depth[:step]] " + << " --k=[:max depth[:step]] (GEMM-specific)" << " Size of inner dimension of A and B. May specify a range with optional step size.\n" - << " --batch= " - << " Number of batches for a bached gemm. 
" + << " Number of batches for a batched gemm.\n" << " --kernels=<{s|d|h|i|wmma_|wmma_binary_|wmma_integer_}gemm_{nn,nt,tn,tt}>\n" << " " @@ -641,13 +646,14 @@ struct TestbenchOptions { out << "\n\n" << "Example usage:\n\n" - << "# Runs one problem size for all kernels\n" + << "# Runs one problem size for all GEMM kernels\n" << "./tools/test/perf/cutlass_perf_test --m=10240 --n=1024 --k=1024\n\n" << "# Varies GEMM K dimension for SGEMM and IGEMM with column-major multiplicands\n" << "./tools/test/perf/cutlass_perf_test --m=10240 --n=4096 --k=1024:8192:128 " "--kernels=sgemm_nn,igemm_nn\n\n" - + << " # Executes GEMM kernel on Volta Tensor Cores\n" + << " $ ./tools/test/perf/cutlass_perf_test --kernels=s884gemm_nt\n\n" << std::flush; } }; diff --git a/tools/test/perf/testbench_output.h b/tools/test/perf/testbench_output.h index 297f02f896..c3619293a3 100644 --- a/tools/test/perf/testbench_output.h +++ b/tools/test/perf/testbench_output.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/CMakeLists.txt b/tools/test/unit/CMakeLists.txt index 795770c2db..1e8738c2bc 100644 --- a/tools/test/unit/CMakeLists.txt +++ b/tools/test/unit/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are permitted # provided that the following conditions are met: @@ -43,11 +43,6 @@ set(CUTLASS_UNIT_TEST_HEADERS reduction/test_batched_reduction.h ) -set(CUTLASS_UNIT_TEST_SOURCES_BACKUP - cutlass_unit_test.cpp - gemm/batched_strided_sgemm_128x128x8.cu -) - set(CUTLASS_UNIT_TEST_SOURCES cutlass_unit_test.cpp tile_iterator_test.cu @@ -85,6 +80,7 @@ set(CUTLASS_UNIT_TEST_SOURCES gemm/sgemm_64x64x16.cu gemm/sgemm_64x32x8.cu gemm/sgemm_64x32x16.cu + gemm/sgemm_32x128x8.cu gemm/fp16_sgemm_fp32_128x128x16.cu gemm/fp16_sgemm_fp16_128x128x16.cu gemm/wmma_gemm.cu @@ -102,7 +98,13 @@ set(CUTLASS_UNIT_TEST_SOURCES gemm/batched_strided_hgemm_128x128x8.cu gemm/batched_strided_wmma_gemm.cu gemm/batched_strided_fp16_wmma_gemm_fp16.cu + gemm/batched_strided_volta884_hgemm.cu gemm/epilogue_functor.cu + gemm/volta884_gemm_epilogue.cu + gemm/volta884_h884gemm.cu + gemm/volta884_gemm.cu + gemm/volta884_gemm_threadblock_swizzle.cu + gemm/volta884_h884gemm_threadblock_swizzle.cu reduction/batched_reduction.cu reduction/mixed_batched_reduction.cu gemm/splitK_sgemm.cu @@ -111,7 +113,9 @@ set(CUTLASS_UNIT_TEST_SOURCES gemm/splitK_dgemm.cu gemm/splitK_hgemm.cu gemm/splitK_wmma_gemm.cu + gemm/splitK_volta884_hgemm.cu gemm/partitionedK_sgemm_128x128x8.cu + gemm/partitionedK_volta884_hgemm.cu ) if (CUTLASS_NVRTC_ENABLE) @@ -144,6 +148,7 @@ if (CUTLASS_NVRTC_ENABLE) endif() endif() -target_link_libraries(cutlass_unit_test ${CUBLAS_LIBRARY}) - +if(CUTLASS_ENABLE_CUBLAS) + target_link_libraries(cutlass_unit_test ${CUBLAS_LIBRARY}) +endif() diff --git a/tools/test/unit/core/layout_verification.cu b/tools/test/unit/core/layout_verification.cu index c043ced090..76c1d7c67d 100644 --- a/tools/test/unit/core/layout_verification.cu +++ 
b/tools/test/unit/core/layout_verification.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/core/layout_verification.h b/tools/test/unit/core/layout_verification.h index a0716131de..86222acacd 100644 --- a/tools/test/unit/core/layout_verification.h +++ b/tools/test/unit/core/layout_verification.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/core/predicate_vector.cu b/tools/test/unit/core/predicate_vector.cu index ea3a359d31..f8c9294026 100644 --- a/tools/test/unit/core/predicate_vector.cu +++ b/tools/test/unit/core/predicate_vector.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/core/tensor_ref.cu b/tools/test/unit/core/tensor_ref.cu index ee16f92f1f..fef042a802 100644 --- a/tools/test/unit/core/tensor_ref.cu +++ b/tools/test/unit/core/tensor_ref.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/core/tensor_view.cu b/tools/test/unit/core/tensor_view.cu index 8090f468d9..288228a629 100644 --- a/tools/test/unit/core/tensor_view.cu +++ b/tools/test/unit/core/tensor_view.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/core/tile_iterator.cu b/tools/test/unit/core/tile_iterator.cu index c7f9598121..eabc234903 100644 --- a/tools/test/unit/core/tile_iterator.cu +++ b/tools/test/unit/core/tile_iterator.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/core/zip_tile_iterator.cu b/tools/test/unit/core/zip_tile_iterator.cu index 2117e012d5..c57be86ce2 100644 --- a/tools/test/unit/core/zip_tile_iterator.cu +++ b/tools/test/unit/core/zip_tile_iterator.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/cutlass_unit_test.cpp b/tools/test/unit/cutlass_unit_test.cpp index 355235aa35..a6cc8ccb1c 100644 --- a/tools/test/unit/cutlass_unit_test.cpp +++ b/tools/test/unit/cutlass_unit_test.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -84,9 +84,9 @@ void set_gtest_flag() { { "*wmma*", 70, false }, { "WmmaInt8*", 72, false }, { "*wmmaInt8*", 72, false }, - { "WmmaInt4*", 75, true }, + { "WmmaInt4*", 75, true }, { "*wmmaInt4*", 75, true }, - { "WmmaBinary*", 75, true }, + { "WmmaBinary*", 75, true }, { "*wmmaBinary*", 75, true }, { 0, 0, false } }; diff --git a/tools/test/unit/cutlass_unit_test.h b/tools/test/unit/cutlass_unit_test.h index 2ffced5828..b52f7f6041 100644 --- a/tools/test/unit/cutlass_unit_test.h +++ b/tools/test/unit/cutlass_unit_test.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -29,3 +29,7 @@ #include #pragma diag_warning boolean_controlling_expr_is_constant #pragma warning( disable : 4503) + +#if !defined(CUTLASS_ENABLE_CUBLAS) +#define CUTLASS_ENABLE_CUBLAS 0 +#endif diff --git a/tools/test/unit/gemm/batched_strided_dgemm_128x128x8.cu b/tools/test/unit/gemm/batched_strided_dgemm_128x128x8.cu index 8b0bc16358..5e0538025c 100644 --- a/tools/test/unit/gemm/batched_strided_dgemm_128x128x8.cu +++ b/tools/test/unit/gemm/batched_strided_dgemm_128x128x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/batched_strided_fp16_wmma_gemm_fp16.cu b/tools/test/unit/gemm/batched_strided_fp16_wmma_gemm_fp16.cu index fef9e70c96..b408c2c693 100644 --- a/tools/test/unit/gemm/batched_strided_fp16_wmma_gemm_fp16.cu +++ b/tools/test/unit/gemm/batched_strided_fp16_wmma_gemm_fp16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/batched_strided_hgemm_128x128x8.cu b/tools/test/unit/gemm/batched_strided_hgemm_128x128x8.cu index 4738d29f92..85406eb7eb 100644 --- a/tools/test/unit/gemm/batched_strided_hgemm_128x128x8.cu +++ b/tools/test/unit/gemm/batched_strided_hgemm_128x128x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/batched_strided_sgemm_128x128x8.cu b/tools/test/unit/gemm/batched_strided_sgemm_128x128x8.cu index ffeba34f40..2f17f9f44b 100644 --- a/tools/test/unit/gemm/batched_strided_sgemm_128x128x8.cu +++ b/tools/test/unit/gemm/batched_strided_sgemm_128x128x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/batched_strided_volta884_hgemm.cu b/tools/test/unit/gemm/batched_strided_volta884_hgemm.cu new file mode 100644 index 0000000000..3cce0025da --- /dev/null +++ b/tools/test/unit/gemm/batched_strided_volta884_hgemm.cu @@ -0,0 +1,114 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include "cutlass_unit_test.h" + +#include "tools/util/half.h" +#include "tools/util/host_tensor.h" +#include "tools/util/tensor_view_io.h" + +#include "cutlass/gemm/volta884_gemm_traits.h" +#include "cutlass/gemm/gemm.h" + +#include "tools/test/unit/gemm/gemm_testbed.h" +#include "tools/test/unit/gemm/run_gemm.h" + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_batched_strided_64x64x32_nt, 64x128x64x3) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_batched_strided_gemm(64, 128, 64, 3); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_batched_strided_64x64x32_nn, 64x128x64x3) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_batched_strided_gemm(64, 128, 64, 3); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_batched_strided_64x64x32_tn, 64x128x64x3) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_batched_strided_gemm(64, 128, 64, 3); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_batched_strided_64x64x32_tt, 64x128x64x3) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_batched_strided_gemm(64, 128, 64, 3); +} + +#endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) diff --git a/tools/test/unit/gemm/batched_strided_wmma_gemm.cu b/tools/test/unit/gemm/batched_strided_wmma_gemm.cu index 0e47d98e02..fcaa5fbbf9 100644 --- a/tools/test/unit/gemm/batched_strided_wmma_gemm.cu +++ b/tools/test/unit/gemm/batched_strided_wmma_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. 
All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/dgemm.cu b/tools/test/unit/gemm/dgemm.cu index ebfeba9205..6681915d2f 100644 --- a/tools/test/unit/gemm/dgemm.cu +++ b/tools/test/unit/gemm/dgemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/epilogue_functor.cu b/tools/test/unit/gemm/epilogue_functor.cu index cc03735164..6aa68257ba 100644 --- a/tools/test/unit/gemm/epilogue_functor.cu +++ b/tools/test/unit/gemm/epilogue_functor.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -83,7 +83,7 @@ TEST(Sgemm_epilogue_functor, device_ptr_mode_sgemm_1024x512x128_nt) { // // Construct a CUTLASS GEMM and initialize parameters // - typedef typename SgemmTraits::KernelClass Gemm; + typedef cutlass::gemm::Gemm Gemm; typename Gemm::Params params; params.initialize(testbed.M(), diff --git a/tools/test/unit/gemm/fp16_sgemm_fp16_128x128x16.cu b/tools/test/unit/gemm/fp16_sgemm_fp16_128x128x16.cu index a3db605e99..acf88473a7 100644 --- a/tools/test/unit/gemm/fp16_sgemm_fp16_128x128x16.cu +++ b/tools/test/unit/gemm/fp16_sgemm_fp16_128x128x16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/fp16_sgemm_fp32_128x128x16.cu b/tools/test/unit/gemm/fp16_sgemm_fp32_128x128x16.cu index 21b6c40451..f604b387a3 100644 --- a/tools/test/unit/gemm/fp16_sgemm_fp32_128x128x16.cu +++ b/tools/test/unit/gemm/fp16_sgemm_fp32_128x128x16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/fp16_wmma_gemm_fp16.cu b/tools/test/unit/gemm/fp16_wmma_gemm_fp16.cu index 2d3617a6cd..0be620d7d3 100644 --- a/tools/test/unit/gemm/fp16_wmma_gemm_fp16.cu +++ b/tools/test/unit/gemm/fp16_wmma_gemm_fp16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. 
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/fp16_wmma_gemm_fp16_non_multiple16.cu b/tools/test/unit/gemm/fp16_wmma_gemm_fp16_non_multiple16.cu index d819351cb2..0f6596286b 100644 --- a/tools/test/unit/gemm/fp16_wmma_gemm_fp16_non_multiple16.cu +++ b/tools/test/unit/gemm/fp16_wmma_gemm_fp16_non_multiple16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/gemm_nvrtc.h b/tools/test/unit/gemm/gemm_nvrtc.h index fae1e7d6ff..493eafb4aa 100644 --- a/tools/test/unit/gemm/gemm_nvrtc.h +++ b/tools/test/unit/gemm/gemm_nvrtc.h @@ -88,8 +88,6 @@ static __host__ void run_gemm_nvrtc( std::string type_name; #if 0 - // TODO Ideally we'd use nvrtcGetTypeName to determine the type, but it cannot resolve enum symbol names - // As altername solution we might want to implement to_string() to get the traits string. nvrtcGetTypeName(&type_name); #else type_name = gemm_traits; diff --git a/tools/test/unit/gemm/gemm_testbed.h b/tools/test/unit/gemm/gemm_testbed.h index 40399144cf..7cf6388819 100644 --- a/tools/test/unit/gemm/gemm_testbed.h +++ b/tools/test/unit/gemm/gemm_testbed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -204,7 +204,7 @@ struct GemmTestbed { struct RandomBitGenerator { RandomBitGenerator(int seed = -1) { srand(seed); } - T operator()() { + T operator()() { uint32_t val = 0; for (int i = 0; i < 32; i++) { val |= rand() % 2; @@ -268,7 +268,10 @@ struct GemmTestbed { /// partitionK count int partitionK_count; - + + /// each partition should be mulitples of partitionK_multiple + int partitionK_multiple; + /// distance between A[i] and A[i+1] for strided batched gemm long long int batch_stride_A; @@ -316,13 +319,19 @@ struct GemmTestbed { algorithm(algorithm_), batch_count(1), partitionK_count(1), + partitionK_multiple(1), batch_stride_A(static_cast(0)), batch_stride_B(static_cast(0)), batch_stride_C(static_cast(0)) { + + #if CUTLASS_ENABLE_CUBLAS status = cublasCreate(&handle); if (status != CUBLAS_STATUS_SUCCESS) { throw cutlass::cuda_exception("Failed to create CUBLAS handle"); } + #else + status = CUBLAS_STATUS_NOT_INITIALIZED; + #endif resize(A, M_, K_, layout_a); resize(B, K_, N_, layout_b); @@ -355,6 +364,7 @@ struct GemmTestbed { algorithm(algorithm_), batch_count(1), partitionK_count(1), + partitionK_multiple(1), batch_stride_A(static_cast(0)), batch_stride_B(static_cast(0)), batch_stride_C(static_cast(0)) { @@ -389,13 +399,19 @@ struct GemmTestbed { algorithm(algorithm_), batch_count(1), partitionK_count(1), + partitionK_multiple(1), batch_stride_A(static_cast(0)), batch_stride_B(static_cast(0)), batch_stride_C(static_cast(0)) { + + #if CUTLASS_ENABLE_CUBLAS status = cublasCreate(&handle); if (status != CUBLAS_STATUS_SUCCESS) { throw cutlass::cuda_exception("Failed to create CUBLAS handle"); } + #else + status = CUBLAS_STATUS_NOT_INITIALIZED; + #endif resize(A, M_, K_, layout_a, lda); resize(B, K_, N_, layout_b, ldb); @@ -428,6 +444,7 @@ struct GemmTestbed { algorithm(algorithm_), batch_count(1), partitionK_count(1), + partitionK_multiple(1), batch_stride_A(static_cast(0)), batch_stride_B(static_cast(0)), batch_stride_C(static_cast(0)) { @@ -462,12 +479,17 @@ struct GemmTestbed { beta(beta_), algorithm(algorithm_), batch_count(batch_count_), - partitionK_count(1) { + partitionK_count(1), + partitionK_multiple(1) { + #if CUTLASS_ENABLE_CUBLAS status = cublasCreate(&handle); if (status != CUBLAS_STATUS_SUCCESS) { throw cutlass::cuda_exception("Failed to create CUBLAS handle"); } + #else + status = CUBLAS_STATUS_NOT_INITIALIZED; + #endif resize(A, M_, K_ * batch_count, layout_a); resize(B, K_ * batch_count, N_, layout_b); @@ -491,6 +513,7 @@ struct GemmTestbed { GemmTestbed(int M_, int N_, std::pair K_pair_, /*(k, partitionK_count)*/ + int partitionK_multiple_, /*each partition should be mulitiple of partitionK_multiple*/ cublasOperation_t layout_a, cublasOperation_t layout_b, Scalar alpha_ = Scalar(1), @@ -504,12 +527,18 @@ struct GemmTestbed { beta(beta_), algorithm(algorithm_), batch_count(1), - partitionK_count(K_pair_.second) { + partitionK_count(K_pair_.second), + partitionK_multiple(partitionK_multiple_) { + #if CUTLASS_ENABLE_CUBLAS status = cublasCreate(&handle); if (status != CUBLAS_STATUS_SUCCESS) { throw cutlass::cuda_exception("Failed to create CUBLAS handle"); } + #else + status = CUBLAS_STATUS_NOT_INITIALIZED; + #endif + resize(A, M_, K_pair_.first, layout_a); resize(B, K_pair_.first, N_, layout_b); resize(C_initial, M_, N_ * partitionK_count, layout_c); @@ -521,6 +550,7 @@ struct GemmTestbed { // we can use a combination 
of batched stried gemm and regular gemm // to simulation partitionedK, which is what we will do for reference code int partitionK_size = K() / partitionK_count; + partitionK_size = partitionK_size - (partitionK_size % partitionK_multiple); batch_stride_A = (layout_a == CUBLAS_OP_N) ? M_ * partitionK_size : partitionK_size; batch_stride_B = (layout_b == CUBLAS_OP_N) ? partitionK_size : partitionK_size * N_; batch_stride_C = M_ * N_; @@ -528,9 +558,11 @@ struct GemmTestbed { /// Destructs the GEMM testbed ~GemmTestbed() { + #if CUTLASS_ENABLE_CUBLAS if (status != CUBLAS_STATUS_NOT_INITIALIZED) { status = cublasDestroy(handle); } + #endif } /// Returns true if the last CUBLAS call returned successfully @@ -623,15 +655,16 @@ struct GemmTestbed { // Initialize the source matrix with a uniform distribution cutlass::Distribution dist; dist.set_uniform(-8, 8); - + cutlass::reference::host::TensorInitialize(A.host_view(), seed, dist); cutlass::reference::host::TensorInitialize(B.host_view(), seed + 11, dist); cutlass::reference::host::TensorInitialize(C_initial.host_view(), seed + 13, dist); - + A.sync_device(); B.sync_device(); C_initial.sync_device(); + computed.fill(0); } /// Initializes binary data @@ -673,6 +706,7 @@ struct GemmTestbed { /// Excutes an equivalent GEMM using cuBLAS bool execute_cublas() { + #if CUTLASS_ENABLE_CUBLAS if (partitionK_count == 1) { if (batch_count == 1) { status = cublasGemmEx(handle, @@ -727,6 +761,7 @@ struct GemmTestbed { //first call strided batched gemm int partitionK_size = K() / partitionK_count; + partitionK_size = partitionK_size - (partitionK_size % partitionK_multiple); //int lastK_size = (K() % partitionK_size) + partitionK_size; int lastK_size = K() - partitionK_size * (partitionK_count - 1); status = cublasGemmStridedBatchedTemplate(handle, @@ -770,6 +805,9 @@ struct GemmTestbed { return status == CUBLAS_STATUS_SUCCESS; } + #else + return false; + #endif } /// Helper function to use cublasGemmStridedBatched @@ -892,49 +930,65 @@ struct GemmTestbed { /// Verifies the contents of computed equal cuBLAS bool verify_with_cublas(bool save_on_error = true, bool always_print = false) { + + bool passed = false; + + #if CUTLASS_ENABLE_CUBLAS compute_cublas(); ref_cublas.sync_host(); computed.sync_host(); - - bool passed = computed.bit_equals(ref_cublas); + passed = computed.bit_equals(ref_cublas); if ((!passed && save_on_error) || always_print) { save_workspace(computed, ref_cublas); } + + #endif return passed; } /// Verifies the host computation with cuBLAS bool verify_host_with_cublas(bool save_on_error = true, bool always_print = false) { + + bool passed = false; + + #if CUTLASS_ENABLE_CUBLAS + compute_host(); compute_cublas(); ref_cublas.sync_host(); - bool passed = ref_host.bit_equals(ref_cublas); + passed = ref_host.bit_equals(ref_cublas); if ((!passed && save_on_error) || always_print) { save_workspace(ref_host, ref_cublas); } + #endif + return passed; } /// Verifies the reference implementation with cuBLAS bool verify_reference_with_cublas(bool save_on_error = true, bool always_print = false) { + bool passed = false; + + #if CUTLASS_ENABLE_CUBLAS compute_device_reference(); ref_device.sync_host(); compute_cublas(); ref_cublas.sync_host(); - bool passed = ref_device.bit_equals(ref_cublas); + passed = ref_device.bit_equals(ref_cublas); if ((!passed && save_on_error) || always_print) { save_workspace(ref_device, ref_cublas); } + #endif return passed; } @@ -948,15 +1002,26 @@ struct GemmTestbed { // verify on host passed = (passed && verify_with_host()); + 
#if CUTLASS_ENABLE_CUBLAS // verify with cublas passed = (passed && verify_with_cublas()); + #endif return passed; } - bool has_cublas_support() const { return cutlass::platform::is_same::value; } + bool has_cublas_support() const { + #if CUTLASS_ENABLE_CUBLAS + return cutlass::platform::is_same::value; + #else + return false; + #endif + } }; +////////////////////////////////////////////////////////////////////////////////////////// + +////////////////////////////////////////////////////////////////////////////////////////// // //specialization for cublasGemmStridedBatchedTemplate template<> inline cublasStatus_t GemmTestbed::cublasGemmStridedBatchedTemplate(cublasHandle_t handle, @@ -977,6 +1042,7 @@ template<> inline cublasStatus_t GemmTestbed: int ldc, long long int stride_C, int batchCount) { + #if CUTLASS_ENABLE_CUBLAS return cublasSgemmStridedBatched(handle, transa, transb, @@ -993,6 +1059,9 @@ template<> inline cublasStatus_t GemmTestbed: ldc, stride_C, batchCount); + #else + return CUBLAS_STATUS_NOT_SUPPORTED; + #endif } template<> inline cublasStatus_t GemmTestbed::cublasGemmStridedBatchedTemplate(cublasHandle_t handle, @@ -1013,6 +1082,7 @@ template<> inline cublasStatus_t GemmTestbed inline cublasStatus_t GemmTestbed inline cublasStatus_t GemmTestbed::cublasGemmStridedBatchedTemplate(cublasHandle_t handle, @@ -1049,6 +1122,7 @@ template<> inline cublasStatus_t GemmTestbedoperator half(); half temp_beta = beta->operator half(); return cublasHgemmStridedBatched(handle, @@ -1067,6 +1141,9 @@ template<> inline cublasStatus_t GemmTestbed inline cublasStatus_t GemmTestbed::cublasGemmStridedBatchedTemplate(cublasHandle_t handle, @@ -1087,6 +1164,7 @@ template<> inline cublasStatus_t GemmTestbed inline cublasStatus_t GemmTestbed::cublas_type, CUBLAS_GEMM_DEFAULT); + #else + return CUBLAS_STATUS_NOT_SUPPORTED; + #endif } } // namespace test diff --git a/tools/test/unit/gemm/hgemm_128x128x16.cu b/tools/test/unit/gemm/hgemm_128x128x16.cu index 1dd1c92e79..b3743b78c1 100644 --- a/tools/test/unit/gemm/hgemm_128x128x16.cu +++ b/tools/test/unit/gemm/hgemm_128x128x16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/hgemm_128x128x8.cu b/tools/test/unit/gemm/hgemm_128x128x8.cu index f8184f2723..fa8f2cf6f9 100644 --- a/tools/test/unit/gemm/hgemm_128x128x8.cu +++ b/tools/test/unit/gemm/hgemm_128x128x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/hgemm_128x32x8.cu b/tools/test/unit/gemm/hgemm_128x32x8.cu index 34e2ba1ea1..181a0dd036 100644 --- a/tools/test/unit/gemm/hgemm_128x32x8.cu +++ b/tools/test/unit/gemm/hgemm_128x32x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. 
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/hgemm_128x64x8.cu b/tools/test/unit/gemm/hgemm_128x64x8.cu index 001b222422..8b96570d2a 100644 --- a/tools/test/unit/gemm/hgemm_128x64x8.cu +++ b/tools/test/unit/gemm/hgemm_128x64x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/igemm_128x128x32.cu b/tools/test/unit/gemm/igemm_128x128x32.cu index 6c891a45c4..91b0aa8cf5 100644 --- a/tools/test/unit/gemm/igemm_128x128x32.cu +++ b/tools/test/unit/gemm/igemm_128x128x32.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -22,6 +22,9 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) + #include "cutlass_unit_test.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/igemm_traits.h" @@ -355,3 +358,5 @@ TEST(Igemm_128x128x32, igemm_256x256x64_tt) { } //////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) diff --git a/tools/test/unit/gemm/igemm_128x128x32_float.cu b/tools/test/unit/gemm/igemm_128x128x32_float.cu index 08b7dbff23..743be2bc1e 100644 --- a/tools/test/unit/gemm/igemm_128x128x32_float.cu +++ b/tools/test/unit/gemm/igemm_128x128x32_float.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -22,6 +22,9 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ + +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) + #include "cutlass_unit_test.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/igemm_traits.h" @@ -356,3 +359,5 @@ TEST(Igemm_128x128x32_float, igemm_256x256x64_tt) { } //////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) diff --git a/tools/test/unit/gemm/igemm_128x128x32_int8.cu b/tools/test/unit/gemm/igemm_128x128x32_int8.cu index fbf5ca406d..7e81489eed 100644 --- a/tools/test/unit/gemm/igemm_128x128x32_int8.cu +++ b/tools/test/unit/gemm/igemm_128x128x32_int8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -22,6 +22,9 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) + #include "cutlass_unit_test.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/igemm_traits.h" @@ -357,3 +360,5 @@ TEST(Igemm_128x128x32_int8, igemm_256x256x64_tt) { } //////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) diff --git a/tools/test/unit/gemm/igemm_128x32x32.cu b/tools/test/unit/gemm/igemm_128x32x32.cu index dabeb07dfc..298efaf921 100644 --- a/tools/test/unit/gemm/igemm_128x32x32.cu +++ b/tools/test/unit/gemm/igemm_128x32x32.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -22,6 +22,9 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) + #include "cutlass_unit_test.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/igemm_traits.h" @@ -358,3 +361,5 @@ TEST(Igemm_128x32x32, igemm_256x128x32_tt) { } //////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) diff --git a/tools/test/unit/gemm/igemm_128x64x32.cu b/tools/test/unit/gemm/igemm_128x64x32.cu index 279daafec4..af33708973 100644 --- a/tools/test/unit/gemm/igemm_128x64x32.cu +++ b/tools/test/unit/gemm/igemm_128x64x32.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -22,6 +22,9 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) + #include "cutlass_unit_test.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/igemm_traits.h" @@ -358,3 +361,5 @@ TEST(Igemm_128x64x32, igemm_256x128x64_tt) { } //////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) diff --git a/tools/test/unit/gemm/igemm_32x32x128.cu b/tools/test/unit/gemm/igemm_32x32x128.cu index 02434572f8..542392f019 100644 --- a/tools/test/unit/gemm/igemm_32x32x128.cu +++ b/tools/test/unit/gemm/igemm_32x32x128.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -22,6 +22,9 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) + #include "cutlass_unit_test.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/igemm_traits.h" @@ -236,3 +239,5 @@ TEST(Igemm_32x32x128, igemm_32x32x128_tt) { } //////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) diff --git a/tools/test/unit/gemm/partitionedK_sgemm_128x128x8.cu b/tools/test/unit/gemm/partitionedK_sgemm_128x128x8.cu index 0d4587811f..58631e8a3b 100644 --- a/tools/test/unit/gemm/partitionedK_sgemm_128x128x8.cu +++ b/tools/test/unit/gemm/partitionedK_sgemm_128x128x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/partitionedK_volta884_hgemm.cu b/tools/test/unit/gemm/partitionedK_volta884_hgemm.cu new file mode 100644 index 0000000000..211aafd8ab --- /dev/null +++ b/tools/test/unit/gemm/partitionedK_volta884_hgemm.cu @@ -0,0 +1,293 @@ +/*************************************************************************************************** +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +* +* Redistribution and use in source and binary forms, with or without modification, are permitted +* provided that the following conditions are met: +* * Redistributions of source code must retain the above copyright notice, this list of +* conditions and the following disclaimer. 
+* * Redistributions in binary form must reproduce the above copyright notice, this list of +* conditions and the following disclaimer in the documentation and/or other materials +* provided with the distribution. +* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +* to endorse or promote products derived from this software without specific prior written +* permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ +#include +#include +#include "cutlass_unit_test.h" + +#include "tools/util/half.h" +#include "tools/util/host_tensor.h" +#include "tools/util/tensor_view_io.h" + +#include "cutlass/gemm/volta884_gemm_traits.h" +#include "cutlass/gemm/gemm.h" + +#include "tools/test/unit/gemm/gemm_testbed.h" +#include "tools/test/unit/gemm/run_gemm.h" + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_partitionedK_64x64x32, volta884_h884gemm_128x256x88x10_nn) { + /* + for example + partitionedK gemm, m = 128, n = 256, overall_K = 88, partitionK_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 8 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 16 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + + int m = 128; + int n = 256; + int overall_k = 88; + int partitionK_count = 10; + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_partitioned_k_gemm(m, n, overall_k, partitionK_count); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_partitionedK_64x64x32, volta884_h884gemm_128x256x88x10_nt) { + /* + for example + partitionedK gemm, m = 128, n = 256, overall_K = 88, partitionK_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 8 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 16 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + + int m = 128; + int n = 256; + int overall_k = 88; + int partitionK_count = 10; + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_partitioned_k_gemm(m, n, overall_k, partitionK_count); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
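The comments in these partitionedK tests walk through the partition-size arithmetic; for reference, a standalone sketch of that computation (mirroring the partitionK_multiple rounding this change adds to gemm_testbed.h) is shown below. The helper name and struct are illustrative only and are not part of the CUTLASS API.

#include <cassert>

struct PartitionedK {
  int partition_k;       // K covered by each of the first (partition_count - 1) partitions
  int last_partition_k;  // K covered by the final partition, which absorbs the remainder
};

// Each partition is truncated down to a multiple of 'multiple'; the last partition
// picks up whatever K remains.
inline PartitionedK compute_partitioned_k(int overall_k, int partition_count, int multiple = 1) {
  int k = overall_k / partition_count;
  k -= k % multiple;
  PartitionedK result;
  result.partition_k = k;
  result.last_partition_k = overall_k - k * (partition_count - 1);
  return result;
}

int main() {
  // The two cases worked through in the test comments above.
  PartitionedK a = compute_partitioned_k(88, 10);      // 8 per partition, 16 in the last
  assert(a.partition_k == 8 && a.last_partition_k == 16);
  PartitionedK b = compute_partitioned_k(128, 10, 8);  // 8 per partition, 56 in the last
  assert(b.partition_k == 8 && b.last_partition_k == 56);
  return 0;
}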
+TEST(Volta884_h884gemm_partitionedK_64x64x32, volta884_h884gemm_128x256x88x10_tn) { + /* + for example + partitionedK gemm, m = 128, n = 256, overall_K = 88, partitionK_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 8 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 16 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + + int m = 128; + int n = 256; + int overall_k = 88; + int partitionK_count = 10; + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_partitioned_k_gemm(m, n, overall_k, partitionK_count); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_partitionedK_64x64x32, volta884_h884gemm_128x256x88x10_tt) { + /* + for example + partitionedK gemm, m = 128, n = 256, overall_K = 88, partitionK_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 8 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 16 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + + int m = 128; + int n = 256; + int overall_k = 88; + int partitionK_count = 10; + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_partitioned_k_gemm(m, n, overall_k, partitionK_count); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_partitionedK_64x64x32, volta884_h884gemm_128x256x128x10_nn) { + /* + for example + partitionedK gemm, m = 128, n = 256, overall_K = 128, partitionK_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 12. + But if we require the partition mulitple to be 8, the first 9 partition + k = k - (k % partition_mulitiple) = 8 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 56 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + + int m = 128; + int n = 256; + int overall_k = 128; + int partitionK_count = 10; + int partitionK_multiple = 8; + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_partitioned_k_gemm(m, n, overall_k, partitionK_count, partitionK_multiple); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_partitionedK_64x64x32, volta884_h884gemm_128x256x128x10_nt) { + /* + for example + partitionedK gemm, m = 128, n = 256, overall_K = 128, partitionK_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 12. 
+ But if we require the partition mulitple to be 8, the first 9 partition + k = k - (k % partition_mulitiple) = 8 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 56 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + + int m = 128; + int n = 256; + int overall_k = 128; + int partitionK_count = 10; + int partitionK_multiple = 8; + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_partitioned_k_gemm(m, n, overall_k, partitionK_count, partitionK_multiple); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_partitionedK_64x64x32, volta884_h884gemm_128x256x128x10_tn) { + /* + for example + partitionedK gemm, m = 128, n = 256, overall_K = 128, partitionK_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 12. + But if we require the partition mulitple to be 8, the first 9 partition + k = k - (k % partition_mulitiple) = 8 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 56 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + + int m = 128; + int n = 256; + int overall_k = 128; + int partitionK_count = 10; + int partitionK_multiple = 8; + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_partitioned_k_gemm(m, n, overall_k, partitionK_count, partitionK_multiple); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_partitionedK_64x64x32, volta884_h884gemm_128x256x128x10_tt) { + /* + for example + partitionedK gemm, m = 128, n = 256, overall_K = 128, partitionK_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 12. 
+ But if we require the partition mulitple to be 8, the first 9 partition + k = k - (k % partition_mulitiple) = 8 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 56 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + + int m = 128; + int n = 256; + int overall_k = 128; + int partitionK_count = 10; + int partitionK_multiple = 8; + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_partitioned_k_gemm(m, n, overall_k, partitionK_count, partitionK_multiple); +} + +#endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) diff --git a/tools/test/unit/gemm/run_gemm.h b/tools/test/unit/gemm/run_gemm.h index aabc55718f..3616d59142 100644 --- a/tools/test/unit/gemm/run_gemm.h +++ b/tools/test/unit/gemm/run_gemm.h @@ -44,7 +44,9 @@ static void run_gemm( typename test::GemmTestbedTraits::host_type beta = typename test::GemmTestbedTraits::host_type(0.0f)) { - typedef typename GemmTraits_::KernelClass Gemm; + //typedef typename GemmTraits_::KernelClass Gemm; + typedef cutlass::gemm::Gemm Gemm; + typename Gemm::Params params; test::GemmTestbed< @@ -106,6 +108,8 @@ static void run_gemm( //////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + template static void run_gemm( int m, @@ -115,8 +119,8 @@ static void run_gemm( typename test::GemmTestbedTraits::host_type(1.0f), typename test::GemmTestbedTraits::host_type beta = typename test::GemmTestbedTraits::host_type(0.0f)) { - //typedef cutlass::gemm::Gemm Gemm; - typedef typename GemmTraits_::KernelClass Gemm; + typedef cutlass::gemm::Gemm Gemm; + //typedef typename GemmTraits_::KernelClass Gemm; typename Gemm::Params params; typedef test::GemmTestbed< @@ -185,8 +189,8 @@ static void run_batched_strided_gemm( typename test::GemmTestbedTraits::host_type(1.0f), typename test::GemmTestbedTraits::host_type beta = typename test::GemmTestbedTraits::host_type(0.0f)) { - //typedef cutlass::gemm::Gemm Gemm; - typedef typename GemmTraits_::KernelClass Gemm; + typedef cutlass::gemm::Gemm Gemm; + //typedef typename GemmTraits_::KernelClass Gemm; typename Gemm::Params params; test::GemmTestbed< typename test::GemmTestbedTraits< @@ -254,6 +258,7 @@ template static void run_splitK_gemm(int m, int n, int k, + int partitionK_multiple = 1, /*requires each partition to be mulitple of partitionK_multiple*/ typename test::GemmTestbedTraits::host_type alpha = typename test::GemmTestbedTraits::host_type(1.0f), typename test::GemmTestbedTraits::host_type beta = @@ -283,11 +288,12 @@ static void run_splitK_gemm(int m, // create a device gemm typedef cutlass::gemm::SplitkPIGemmTraits deviceGemmTraits; - typedef typename deviceGemmTraits::KernelClass deviceGemm; + //typedef typename deviceGemmTraits::KernelClass deviceGemm; + typedef typename cutlass::gemm::DeviceGemm deviceGemm; typename deviceGemm::Params deviceGemmParams(testbed.M(), testbed.N(), testbed.K()); // query if workspace is needed - int workspace_size = deviceGemmParams.required_workspace_memory_in_byte(); + size_t workspace_size = deviceGemmParams.required_workspace_memory_in_byte(); typename test::GemmTestbedTraits::device_type *workspace_ptr = 0; if (workspace_size != 0) { @@ -306,7 +312,8 @@ static void run_splitK_gemm(int m, testbed.ldc(), 
testbed.ptr_computed(), testbed.ldc(), - workspace_ptr); + workspace_ptr, + partitionK_multiple); deviceGemm::launch(deviceGemmParams); @@ -337,12 +344,13 @@ static void run_partitioned_k_gemm( int n, int k, int partitionK_count, + int partitionK_multiple = 1, //requires each partition to be multiples of partitionK_multiple typename test::GemmTestbedTraits::host_type alpha = typename test::GemmTestbedTraits::host_type(1.0f), typename test::GemmTestbedTraits::host_type beta = typename test::GemmTestbedTraits::host_type(0.0f)) { - //typedef cutlass::gemm::Gemm Gemm; - typedef typename GemmTraits_::KernelClass Gemm; + typedef cutlass::gemm::Gemm Gemm; + //typedef typename GemmTraits_::KernelClass Gemm; typename Gemm::Params params; test::GemmTestbed< typename test::GemmTestbedTraits< @@ -358,6 +366,7 @@ static void run_partitioned_k_gemm( testbed(m, n, std::make_pair(k, partitionK_count), + partitionK_multiple, test::convert(GemmTraits_::kLayoutA), test::convert(GemmTraits_::kLayoutB), alpha, @@ -383,7 +392,8 @@ static void run_partitioned_k_gemm( testbed.ldc(), testbed.ptr_computed(), testbed.ldc(), - partitionK_count); + partitionK_count, + partitionK_multiple); Gemm::launch(params); diff --git a/tools/test/unit/gemm/sgemm_128x128x16.cu b/tools/test/unit/gemm/sgemm_128x128x16.cu index 40e49980d3..e54cba3dc3 100644 --- a/tools/test/unit/gemm/sgemm_128x128x16.cu +++ b/tools/test/unit/gemm/sgemm_128x128x16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_128x128x8.cu b/tools/test/unit/gemm/sgemm_128x128x8.cu index a9931f3404..8dcfee2943 100644 --- a/tools/test/unit/gemm/sgemm_128x128x8.cu +++ b/tools/test/unit/gemm/sgemm_128x128x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_128x32x16.cu b/tools/test/unit/gemm/sgemm_128x32x16.cu index 2886eef5c1..c1b693cb90 100644 --- a/tools/test/unit/gemm/sgemm_128x32x16.cu +++ b/tools/test/unit/gemm/sgemm_128x32x16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_128x32x8.cu b/tools/test/unit/gemm/sgemm_128x32x8.cu index 5e7a9f75b5..d65bae9d11 100644 --- a/tools/test/unit/gemm/sgemm_128x32x8.cu +++ b/tools/test/unit/gemm/sgemm_128x32x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. 
All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_128x64x16.cu b/tools/test/unit/gemm/sgemm_128x64x16.cu index 5852a6b178..4b78330f5c 100644 --- a/tools/test/unit/gemm/sgemm_128x64x16.cu +++ b/tools/test/unit/gemm/sgemm_128x64x16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_128x64x8.cu b/tools/test/unit/gemm/sgemm_128x64x8.cu index e07c38db34..df74a1b576 100644 --- a/tools/test/unit/gemm/sgemm_128x64x8.cu +++ b/tools/test/unit/gemm/sgemm_128x64x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_32x128x8.cu b/tools/test/unit/gemm/sgemm_32x128x8.cu new file mode 100644 index 0000000000..52cc362ac4 --- /dev/null +++ b/tools/test/unit/gemm/sgemm_32x128x8.cu @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#include "cutlass_unit_test.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/sgemm_traits.h" +#include "tools/test/unit/gemm/gemm_testbed.h" +#include "tools/test/unit/gemm/run_gemm.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x128x1_nt) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 128, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x128x8_nt) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 128, 8); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x128x16_nt) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x256x16_nt) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_64x256x16_nt) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(64, 256, 16); +} + +//NN +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x128x1_nn) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 128, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x128x8_nn) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 128, 8); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x128x16_nn) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x256x16_nn) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_64x256x16_nn) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(64, 256, 16); +} + +//TN +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x128x1_tn) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 128, 1); +} + 
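Note on the `SgemmTraits` typedefs in the new `sgemm_32x128x8.cu` tests above: in this capture of the diff their leading template arguments (the matrix layouts and the threadblock output tile) have been dropped, leaving only the epilogue functor and the per-thread accumulator shape `cutlass::Shape<8, 8, 4>`. Below is a minimal sketch of what a complete instantiation for one of the `_nt` cases plausibly looks like. The layout arguments follow the same `_nt` pattern used by the Volta884 tests later in this diff, while the output tile shape and `LinearScaling<float>` are assumptions inferred from the file and test names, not taken verbatim from the diff.

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/sgemm_traits.h"

// Hypothetical reconstruction; arguments marked "assumed" are not present in the captured diff.
typedef cutlass::gemm::SgemmTraits<
    cutlass::MatrixLayout::kColumnMajor,  // layout of A (assumed: column-major for "_nt")
    cutlass::MatrixLayout::kRowMajor,     // layout of B (assumed: row-major for "_nt")
    cutlass::Shape<8, 128, 32>,           // threadblock output tile (assumed from the "32x128x8" name)
    cutlass::gemm::LinearScaling<float>,  // epilogue functor ("<float>" assumed)
    cutlass::Shape<8, 8, 4>               // accumulators per thread (present in the diff)
  > SgemmTraits_nt;

// A test would then dispatch through the existing harness, e.g.:
//   run_gemm<SgemmTraits_nt>(32, 128, 8);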
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x128x8_tn) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 128, 8); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x128x16_tn) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x256x16_tn) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_64x256x16_tn) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(64, 256, 16); +} + +//TT +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x128x1_tt) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 128, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x128x8_tt) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 128, 8); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x128x16_tt) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_32x256x16_tt) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(32, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_32x128x8, sgemm_64x256x16_tt) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(64, 256, 16); +} diff --git a/tools/test/unit/gemm/sgemm_64x128x16.cu b/tools/test/unit/gemm/sgemm_64x128x16.cu index c4afa3414c..0501585684 100644 --- a/tools/test/unit/gemm/sgemm_64x128x16.cu +++ b/tools/test/unit/gemm/sgemm_64x128x16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_64x128x8.cu b/tools/test/unit/gemm/sgemm_64x128x8.cu index e87abd2fba..03869abfb4 100644 --- a/tools/test/unit/gemm/sgemm_64x128x8.cu +++ b/tools/test/unit/gemm/sgemm_64x128x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_64x32x16.cu b/tools/test/unit/gemm/sgemm_64x32x16.cu index 0cb0f2b760..0ffd0b6200 100644 --- a/tools/test/unit/gemm/sgemm_64x32x16.cu +++ b/tools/test/unit/gemm/sgemm_64x32x16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_64x32x8.cu b/tools/test/unit/gemm/sgemm_64x32x8.cu index 3e8c60aaf8..7d853666d0 100644 --- a/tools/test/unit/gemm/sgemm_64x32x8.cu +++ b/tools/test/unit/gemm/sgemm_64x32x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_64x64x16.cu b/tools/test/unit/gemm/sgemm_64x64x16.cu index 45619cef81..2bbbbe8160 100644 --- a/tools/test/unit/gemm/sgemm_64x64x16.cu +++ b/tools/test/unit/gemm/sgemm_64x64x16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_64x64x8.cu b/tools/test/unit/gemm/sgemm_64x64x8.cu index 7b02c46db5..8a92124bbb 100644 --- a/tools/test/unit/gemm/sgemm_64x64x8.cu +++ b/tools/test/unit/gemm/sgemm_64x64x8.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_threadblock_swizzle_nn.cu b/tools/test/unit/gemm/sgemm_threadblock_swizzle_nn.cu index fab5906608..f7e3bd4246 100644 --- a/tools/test/unit/gemm/sgemm_threadblock_swizzle_nn.cu +++ b/tools/test/unit/gemm/sgemm_threadblock_swizzle_nn.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_threadblock_swizzle_nt.cu b/tools/test/unit/gemm/sgemm_threadblock_swizzle_nt.cu index c436cdf539..5fc670afcc 100644 --- a/tools/test/unit/gemm/sgemm_threadblock_swizzle_nt.cu +++ b/tools/test/unit/gemm/sgemm_threadblock_swizzle_nt.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_threadblock_swizzle_tn.cu b/tools/test/unit/gemm/sgemm_threadblock_swizzle_tn.cu index b8b9f7fdc8..847ae57a72 100644 --- a/tools/test/unit/gemm/sgemm_threadblock_swizzle_tn.cu +++ b/tools/test/unit/gemm/sgemm_threadblock_swizzle_tn.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/sgemm_threadblock_swizzle_tt.cu b/tools/test/unit/gemm/sgemm_threadblock_swizzle_tt.cu index e1ceae68f7..b7d98a6fca 100644 --- a/tools/test/unit/gemm/sgemm_threadblock_swizzle_tt.cu +++ b/tools/test/unit/gemm/sgemm_threadblock_swizzle_tt.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/splitK_dgemm.cu b/tools/test/unit/gemm/splitK_dgemm.cu index 19e1e38efa..7493228649 100644 --- a/tools/test/unit/gemm/splitK_dgemm.cu +++ b/tools/test/unit/gemm/splitK_dgemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -54,7 +54,7 @@ TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x512_nn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,7 +81,7 @@ TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x512_nt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -108,7 +108,7 @@ TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x512_tn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -135,7 +135,7 @@ TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x512_tt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -162,7 +162,7 @@ TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x500_nn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -189,7 +189,7 @@ TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x500_nt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -216,7 +216,7 @@ TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x500_tn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -243,5 +243,5 @@ TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x500_tt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } diff --git a/tools/test/unit/gemm/splitK_fp16_sgemm_fp16.cu b/tools/test/unit/gemm/splitK_fp16_sgemm_fp16.cu index 7ed1138286..7a2968ff70 100644 --- a/tools/test/unit/gemm/splitK_fp16_sgemm_fp16.cu +++ b/tools/test/unit/gemm/splitK_fp16_sgemm_fp16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -69,7 +69,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x512 cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -103,7 +103,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x512 cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -136,7 +136,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x512 cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -170,7 +170,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x512 cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -204,7 +204,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x500 cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -238,7 +238,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x500 cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -271,7 +271,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x500 cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -305,7 +305,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x500 cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -339,7 +339,7 @@ 
TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x512_ cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -373,7 +373,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x512_ cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -406,7 +406,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x512_ cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -440,7 +440,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x512_ cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -474,7 +474,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x500_ cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -508,7 +508,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x500_ cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -541,7 +541,7 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x500_ cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -575,5 +575,5 @@ TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x500_ cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } diff --git a/tools/test/unit/gemm/splitK_hgemm.cu b/tools/test/unit/gemm/splitK_hgemm.cu index 5af20936be..d3392940eb 100644 --- a/tools/test/unit/gemm/splitK_hgemm.cu +++ b/tools/test/unit/gemm/splitK_hgemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 
2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -54,7 +54,7 @@ TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x64_nn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1.0f, 0.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1.0f, 0.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,7 +81,7 @@ TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x64_nt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -108,7 +108,7 @@ TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x64_tn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -135,7 +135,7 @@ TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x64_tt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -162,7 +162,7 @@ TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x66_nn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1.0f, 0.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1.0f, 0.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -189,7 +189,7 @@ TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x66_nt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -216,7 +216,7 @@ TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x66_tn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -243,6 +243,6 @@ TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x66_tt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } diff --git a/tools/test/unit/gemm/splitK_igemm.cu b/tools/test/unit/gemm/splitK_igemm.cu index f788dc9739..79567403ec 100644 --- a/tools/test/unit/gemm/splitK_igemm.cu +++ b/tools/test/unit/gemm/splitK_igemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -22,6 +22,9 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) + #include "cutlass_unit_test.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/igemm_traits.h" @@ -83,7 +86,7 @@ TEST(SplitK_igemm_128x128x32_splits16, igemm_128x256x512_nt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2, 1, true /*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2, 1, true /*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -111,7 +114,7 @@ TEST(SplitK_igemm_128x128x32_splits16, igemm_128x256x512_tn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2, 1, true /*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2, 1, true /*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -139,7 +142,7 @@ TEST(SplitK_igemm_128x128x32_splits16, igemm_128x256x512_tt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2, 1, true /*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2, 1, true /*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -167,7 +170,7 @@ TEST(SplitK_igemm_128x128x32_splits16, igemm_1024x64x4096_nn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1, 0, false /*not use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1, 0, false /*not use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -195,7 +198,7 @@ TEST(SplitK_igemm_128x128x32_splits16, igemm_1024x64x4096_nt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1, 0, false /*not use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1, 0, false /*not use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -223,7 +226,7 @@ TEST(SplitK_igemm_128x128x32_splits16, igemm_1024x64x4096_tn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1, 0, false /*not use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1, 0, false /*not use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -251,7 +254,7 @@ TEST(SplitK_igemm_128x128x32_splits16, igemm_1024x64x4096_tt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1, 0, false /*not use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1, 0, false /*not use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -279,7 +282,7 @@ TEST(SplitK_igemm_128x32x32_splits16, igemm_1024x64x4096_nn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1, 0, false /*not use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1, 0, false /*not use host reference*/); 
} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -307,7 +310,7 @@ TEST(SplitK_igemm_128x32x32_splits16, igemm_1024x64x4096_nt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1, 0, false /*not use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1, 0, false /*not use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -335,7 +338,7 @@ TEST(SplitK_igemm_128x32x32_splits16, igemm_1024x64x4096_tn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1, 0, false /*not use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1, 0, false /*not use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -363,5 +366,7 @@ TEST(SplitK_igemm_128x32x32_splits16, igemm_1024x64x4096_tt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1, 0, false /*not use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1, 0, false /*not use host reference*/); } + +#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610)) diff --git a/tools/test/unit/gemm/splitK_sgemm.cu b/tools/test/unit/gemm/splitK_sgemm.cu index 5e1885f81b..19b4a9f445 100644 --- a/tools/test/unit/gemm/splitK_sgemm.cu +++ b/tools/test/unit/gemm/splitK_sgemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -54,7 +54,7 @@ TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x512_nn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,7 +81,7 @@ TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x512_nt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -108,7 +108,7 @@ TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x512_tn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -135,7 +135,7 @@ TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x512_tt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -162,7 +162,7 @@ TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x500_nn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } 
//////////////////////////////////////////////////////////////////////////////////////////////////// @@ -189,7 +189,7 @@ TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x500_nt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -216,7 +216,7 @@ TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x500_tn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -243,7 +243,7 @@ TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x500_tt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -270,7 +270,7 @@ TEST(SplitK_sgemm_128x128x8_splits16, sgemm_1024x64x4096_nn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -297,7 +297,7 @@ TEST(SplitK_sgemm_128x128x8_splits16, sgemm_1024x64x4096_nt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -324,7 +324,7 @@ TEST(SplitK_sgemm_128x128x8_splits16, sgemm_1024x64x4096_tn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -351,5 +351,5 @@ TEST(SplitK_sgemm_128x128x8_splits16, sgemm_1024x64x4096_tt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f); } diff --git a/tools/test/unit/gemm/splitK_volta884_hgemm.cu b/tools/test/unit/gemm/splitK_volta884_hgemm.cu new file mode 100644 index 0000000000..190b753dd9 --- /dev/null +++ b/tools/test/unit/gemm/splitK_volta884_hgemm.cu @@ -0,0 +1,507 @@ +/*************************************************************************************************** +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +* +* Redistribution and use in source and binary forms, with or without modification, are permitted +* provided that the following conditions are met: +* * Redistributions of source code must retain the above copyright notice, this list of +* conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, this list of +* conditions and the following disclaimer in the documentation and/or other materials +* provided with the distribution. +* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +* to endorse or promote products derived from this software without specific prior written +* permission. 
+* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ +#include +#include +#include "cutlass_unit_test.h" + +#include "tools/util/half.h" +#include "tools/util/host_tensor.h" +#include "tools/util/tensor_view_io.h" + +#include "cutlass/gemm/volta884_gemm_traits.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/reduction/batched_reduction_traits.h" + +#include "tools/test/unit/gemm/gemm_testbed.h" +#include "tools/test/unit/gemm/run_gemm.h" + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_splitK_h884gemm_64x64x32_splits16, volta884_h884gemm_128x256x512_nn) { + const int splits_count = 16; + const int m = 128; + const int n = 256; + const int k = 512; + + /*gemm traits*/ + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + cutlass::Shape<1, 1, 2> > + BatchedReductionTraits; + + run_splitK_gemm(m, n, k, 8/*partitionK_multiple*/, 1.0f, 0.0f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_splitK_h884gemm_64x64x32_splits16, volta884_h884gemm_128x256x512_nt) { + const int splits_count = 16; + const int m = 128; + const int n = 256; + const int k = 512; + + /*gemm traits*/ + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + cutlass::Shape<1, 1, 2> > + BatchedReductionTraits; + + run_splitK_gemm(m, n, k, 8/*partitionK_multiple*/, 1.0f, 0.0f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_splitK_h884gemm_64x64x32_splits16, volta884_h884gemm_128x256x512_tn) { + const int splits_count = 16; + const int m = 128; + const int n = 256; + const int k = 512; + + /*gemm traits*/ + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + cutlass::Shape<1, 1, 2> > + BatchedReductionTraits; + + 
run_splitK_gemm(m, n, k, 8/*partitionK_multiple*/, 1.0f, 0.0f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_splitK_h884gemm_64x64x32_splits16, volta884_h884gemm_128x256x512_tt) { + const int splits_count = 16; + const int m = 128; + const int n = 256; + const int k = 512; + + /*gemm traits*/ + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + cutlass::Shape<1, 1, 2> > + BatchedReductionTraits; + + run_splitK_gemm(m, n, k, 8/*partitionK_multiple*/, 1.0f, 0.0f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_splitK_h884gemm_64x64x32_splits10, volta884_h884gemm_128x256x88_nn) { + /* + m = 128, n = 256, overall_K = 88, splits_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 8 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 16 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + + const int splits_count = 10; + const int m = 128; + const int n = 256; + const int k = 88; + + /*gemm traits*/ + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + cutlass::Shape<1, 1, 2> > + BatchedReductionTraits; + + run_splitK_gemm(m, n, k, 8/*partitionK_multiple*/, 1.0f, 0.0f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_splitK_h884gemm_64x64x32_splits10, volta884_h884gemm_128x256x88_nt) { + /* + m = 128, n = 256, overall_K = 88, splits_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 8 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 16 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + const int splits_count = 10; + const int m = 128; + const int n = 256; + const int k = 88; + + /*gemm traits*/ + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + cutlass::Shape<1, 1, 2> > + BatchedReductionTraits; + + run_splitK_gemm(m, n, k, 8/*partitionK_multiple*/, 1.0f, 0.0f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_splitK_h884gemm_64x64x32_splits10, volta884_h884gemm_128x256x88_tn) { + /* + m = 128, n = 256, overall_K = 88, splits_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 8 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 16 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + const int splits_count = 10; + const int m = 128; + const int n = 256; + const int k = 88; + + /*gemm traits*/ + typedef cutlass::gemm::Volta884GemmTraits< + 
cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + cutlass::Shape<1, 1, 2> > + BatchedReductionTraits; + + run_splitK_gemm(m, n, k, 8/*partitionK_multiple*/, 1.0f, 0.0f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_splitK_h884gemm_64x64x32_splits10, volta884_h884gemm_128x256x88_tt) { + /* + m = 128, n = 256, overall_K = 88, splits_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 8 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 16 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + const int splits_count = 10; + const int m = 128; + const int n = 256; + const int k = 88; + + /*gemm traits*/ + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + cutlass::Shape<1, 1, 2> > + BatchedReductionTraits; + + run_splitK_gemm(m, n, k, 8/*partitionK_multiple*/, 1.0f, 0.0f); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_splitK_h884gemm_64x64x32_splits10, volta884_h884gemm_128x256x256_nn) { + /* + m = 128, n = 256, overall_K = 256, splits_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 25 + But if we require the partition mulitple to be 8, the first 9 partition + k = k - (k % partition_mulitiple) = 24 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 40 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + const int splits_count = 10; + const int m = 128; + const int n = 256; + const int k = 256; + + /*gemm traits*/ + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + cutlass::Shape<1, 1, 2> > + BatchedReductionTraits; + + run_splitK_gemm(m, n, k, 8/*partitionK_multiple*/, 1.0f, 0.0f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_splitK_h884gemm_64x64x32_splits10, volta884_h884gemm_128x256x256_nt) { + /* + m = 128, n = 256, overall_K = 256, splits_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 25 + But if we require the partition mulitple to be 8, the first 9 partition + k = k - (k % partition_mulitiple) = 24 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 40 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + const int splits_count = 10; + const int m = 128; + const int n = 256; + const int k = 256; + + /*gemm traits*/ + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + /*batched 
reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + cutlass::Shape<1, 1, 2> > + BatchedReductionTraits; + + run_splitK_gemm(m, n, k, 8/*partitionK_multiple*/, 1.0f, 0.0f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_splitK_h884gemm_64x64x32_splits10, volta884_h884gemm_128x256x256_tn) { + /* + m = 128, n = 256, overall_K = 256, splits_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 25 + But if we require the partition mulitple to be 8, the first 9 partition + k = k - (k % partition_mulitiple) = 24 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 40 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + const int splits_count = 10; + const int m = 128; + const int n = 256; + const int k = 256; + + /*gemm traits*/ + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + cutlass::Shape<1, 1, 2> > + BatchedReductionTraits; + + run_splitK_gemm(m, n, k, 8/*partitionK_multiple*/, 1.0f, 0.0f); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_splitK_h884gemm_64x64x32_splits10, volta884_h884gemm_128x256x256_tt) { + /* + m = 128, n = 256, overall_K = 256, splits_count = 10 + for the first 9 partition k = overall_k / partitionK_count = 25 + But if we require the partition mulitple to be 8, the first 9 partition + k = k - (k % partition_mulitiple) = 24 + for the last partition last_k = overall_k - (partitionK_count - 1) * k = 40 + for volta884 it is safe to make sure leading dim are multiple of 8 + */ + const int splits_count = 10; + const int m = 128; + const int n = 256; + const int k = 256; + + /*gemm traits*/ + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + /*batched reduction traits*/ + typedef cutlass::reduction::BatchedReductionTraits, + cutlass::Shape<1, 1, 64>, + cutlass::Shape<1, 1, 2> > + BatchedReductionTraits; + + run_splitK_gemm(m, n, k, 8/*partitionK_multiple*/, 1.0f, 0.0f); +} + +#endif diff --git a/tools/test/unit/gemm/splitK_wmma_gemm.cu b/tools/test/unit/gemm/splitK_wmma_gemm.cu index 7b035b4512..af9887fc4f 100644 --- a/tools/test/unit/gemm/splitK_wmma_gemm.cu +++ b/tools/test/unit/gemm/splitK_wmma_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
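The comments in the new `splitK_volta884_hgemm.cu` tests above describe how the new `partitionK_multiple` argument constrains split-K partition sizes: each of the first `partitionK_count - 1` partitions gets `overall_k / partitionK_count` rounded down to a multiple of `partitionK_multiple`, and the last partition absorbs whatever K remains. The sketch below only illustrates that sizing rule as stated in those comments; the helper function and its name are hypothetical and not part of the diff.

#include <cassert>
#include <cstdio>

// Illustrative sketch of the partition sizing described in the test comments.
// first_k applies to partitions 0 .. partitionK_count - 2; last_k to the final partition.
void partition_k_sizes(int overall_k, int partitionK_count, int partitionK_multiple,
                       int* first_k, int* last_k) {
  int k = overall_k / partitionK_count;              // e.g. 256 / 10 = 25
  k -= k % partitionK_multiple;                      // round down to a multiple of 8 -> 24
  *first_k = k;
  *last_k = overall_k - (partitionK_count - 1) * k;  // 256 - 9 * 24 = 40
  assert(*last_k > 0);
}

int main() {
  int first_k = 0, last_k = 0;
  partition_k_sizes(256, 10, 8, &first_k, &last_k);  // the volta884_h884gemm_128x256x256 cases
  std::printf("first partitions: %d, last partition: %d\n", first_k, last_k);  // prints 24, 40
  partition_k_sizes(88, 10, 8, &first_k, &last_k);   // the volta884_h884gemm_128x256x88 cases
  std::printf("first partitions: %d, last partition: %d\n", first_k, last_k);  // prints 8, 16
  return 0;
}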
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -64,7 +64,7 @@ TEST(SplitK_wmma_gemm_16x16x32_splits16, wmma_gemm_128x256x512_nn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 2.0f, 1.0f, true/*use host reference*/); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 2.0f, 1.0f, true/*use host reference*/); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -99,7 +99,7 @@ TEST(SplitK_wmma_gemm_16x16x32_splits16, wmma_gemm_128x256x512_nt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1.0f, 0.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1.0f, 0.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -134,7 +134,7 @@ TEST(SplitK_wmma_gemm_16x16x32_splits16, wmma_gemm_128x256x512_tn) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1.0f, 0.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1.0f, 0.0f); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -169,7 +169,7 @@ TEST(SplitK_wmma_gemm_16x16x32_splits16, wmma_gemm_128x256x512_tt) { cutlass::Shape<1, 1, 2> > BatchedReductionTraits; - run_splitK_gemm(m, n, k, 1.0f, 0.0f); + run_splitK_gemm(m, n, k, 1/*partitionK_multiple*/, 1.0f, 0.0f); } #endif diff --git a/tools/test/unit/gemm/volta884_gemm.cu b/tools/test/unit/gemm/volta884_gemm.cu new file mode 100644 index 0000000000..707c6d3743 --- /dev/null +++ b/tools/test/unit/gemm/volta884_gemm.cu @@ -0,0 +1,1287 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#include +#include +#include "cutlass_unit_test.h" + +#include "tools/util/half.h" +#include "tools/util/host_tensor.h" +#include "tools/util/tensor_view_io.h" + +#include "cutlass/gemm/volta884_gemm_traits.h" +#include "cutlass/gemm/gemm.h" + +#include "tools/test/unit/gemm/gemm_testbed.h" +#include "tools/test/unit/gemm/run_gemm.h" + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Very small warp sizes +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_nn, short_480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_nt, short_480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_tn, short_480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_tt, short_480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Short compile time - s884gemm +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nn, short_480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_128x128x32_nt, short_480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_128x128x32_tn, short_480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_128x128x32_tt, short_480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + 
float, + float, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Contiguous - s884gemm +// +//////////////////////////////////////////////////////////////////////////////////////////////////// +#if 0 +TEST(Volta884_f16_s884gemm_64x64x32_nt, 64x64x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(64, 64, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_nt, 64x64x30_residue) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(64, 64, 30); +} + +#if 0 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_nt, 64x64x64) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(64, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_nt, 128x64x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 64, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_nt, 128x128x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 128, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x64x32_nt, 128x64x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 64, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x64x32_nt, 128x64x128) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 64, 128); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nt, 128x128x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + 
+ run_gemm(128, 128, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nt, 128x128x128) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 128, 128); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nt, 384x256x64) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(384, 256, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nt, 392x264x64) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(392, 264, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nt, 392x264x192) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(392, 264, 192); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nt, 480x280x223) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(480, 280, 223); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Crosswise +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_tn, 64x64x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(64, 64, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_tn, 64x64x64) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(64, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_tn, 64x64x24_residue) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(64, 64, 24); +} + 
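Editor's note (not part of the patch): every test in this file follows the same pattern — a `Volta884GemmTraits` specialization fixes the tile configuration at compile time, and the `run_gemm` helper launches one problem size and verifies it against a reference. The annotated sketch below spells out the role of each template argument as it appears to be used throughout these tests; the parameter names in the comments are inferred from how the tests vary them, not taken from the header itself.

    // Illustrative sketch only; parameter roles are the editor's inference.
    #include "cutlass/gemm/volta884_gemm_traits.h"
    #include "cutlass/gemm/gemm.h"

    typedef cutlass::gemm::Volta884GemmTraits<
        cutlass::MatrixLayout::kColumnMajor,  // layout of operand A ("n" in the test names)
        cutlass::MatrixLayout::kRowMajor,     // layout of operand B ("t" in the test names)
        cutlass::Shape<32, 128, 128>,         // threadblock tile, ordered (K, N, M) in this file's usage
        cutlass::Shape<32, 64, 64>,           // per-warp tile, same ordering
        float,                                // accumulator element type
        float,                                // element type of source matrix C
        float,                                // element type of output matrix D
        2                                     // number of software-pipelined stages
    > GemmTraits;

    // The unit-test helper presumably instantiates cutlass::gemm::Gemm<GemmTraits>,
    // runs it on the device, and checks the result against a host or cuBLAS reference,
    // e.g. run_gemm<GemmTraits>(480, 280, 224);
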
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_tn, 128x64x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 64, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_tn, 128x128x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 128, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x64x32_tn, 128x64x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 64, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x64x32_tn, 128x64x128) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 64, 128); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_tn, 128x128x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 128, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_tn, 128x128x128) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 128, 128); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_tn, 384x256x64) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(384, 256, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_tn, 392x264x64) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(392, 264, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_tn, 392x264x192) { + + 
typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(392, 264, 192); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_tn, 480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Congruous-Crosswise +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_nn, 64x64x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(64, 64, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_nn, 64x64x64) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(64, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_nn, 64x64x24_residue) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(64, 64, 24); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_nn, 128x64x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 64, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_nn, 128x128x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 128, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x64x32_nn, 128x64x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 64, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x64x32_nn, 128x64x128) { + + typedef 
cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 64, 128); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nn, 128x128x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 128, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nn, 128x128x128) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 128, 128); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nn, 384x256x64) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(384, 256, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nn, 392x264x64) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(392, 264, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nn, 392x264x192) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(392, 264, 192); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_nn, 480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Crosswise-Congruous +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_tt, 64x64x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(64, 64, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_tt, 64x64x64) { + + typedef 
cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(64, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_tt, 64x64x24_residue) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(64, 64, 24); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_tt, 128x64x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 64, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_tt, 128x128x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 128, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x64x32_tt, 128x64x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 64, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x64x32_tt, 128x64x128) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 64, 128); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_tt, 128x128x32) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 128, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_tt, 128x128x128) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(128, 128, 128); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_tt, 384x256x64) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, 
+ 2 + > GemmTraits; + + run_gemm(384, 256, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_tt, 392x264x64) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(392, 264, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_tt, 392x264x192) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(392, 264, 192); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_128x128x32_tt, 480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + float, + float, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// FP32 accumulation, FP16 output +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_f16_128x128x32_nn, 480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_128x128x32_nt, 480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_128x128x32_tn, 480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_128x128x32_tt, 480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_64x64x32_nt, 480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_64x128x32_nt, 480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 64>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + 
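Editor's note (not part of the patch): the tests in this block mix precisions — accumulation stays in `float` while C and D are `half` — exercising the linear-scaling epilogue's conversion on store. The per-element computation these configurations imply matches the reference used by the epilogue testbed later in this patch (`Ref = alpha * Product + beta * C`); the sketch below is a minimal illustration of that step, with the function name and the use of CUDA's half-conversion intrinsics chosen by the editor.

    // Per-element behavior implied by (accumulator = float, C = half, D = half):
    // D = alpha * accumulator + beta * C, rounded to half precision for storage.
    #include <cuda_fp16.h>

    __device__ __half epilogue_element(float accum, __half c, float alpha, float beta) {
      float d = alpha * accum + beta * __half2float(c);
      return __float2half(d);  // round-to-nearest conversion on store
    }
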
+TEST(Volta884_f16_s884gemm_f16_128x64x32_tn, 480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_256x128x32_tn, 480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 256>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_128x256x32_tn, 480x280x224) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 256, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2 + > GemmTraits; + + run_gemm(480, 280, 224); +} +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) diff --git a/tools/test/unit/gemm/volta884_gemm_epilogue.cu b/tools/test/unit/gemm/volta884_gemm_epilogue.cu new file mode 100644 index 0000000000..9f01777467 --- /dev/null +++ b/tools/test/unit/gemm/volta884_gemm_epilogue.cu @@ -0,0 +1,453 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory + with + the computed matrix product. 
+*/ +// clang-format off + +#include +#include +#include "cutlass_unit_test.h" + +#include "tools/util/half.h" +#include "tools/util/host_matrix.h" +#include "tools/util/tensor_view_io.h" + +#include "cutlass/tile_traits_standard.h" +#include "cutlass/gemm/linear_scaling.h" + +#include "cutlass/gemm/volta884_multiplicand.h" +#include "cutlass/gemm/volta884_multiply_add.h" +#include "cutlass/gemm/mma_global_stream.h" +#include "cutlass/gemm/volta884_gemm_epilogue_traits.h" +#include "cutlass/gemm/volta884_shared_tile.h" +#include "cutlass/gemm/mma_shared_stream.h" +#include "cutlass/gemm/mma_epilogue.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel that verifies the Volta884 epilogue against the naive epilogue implementation +template +__global__ void test_volta884_epilogue( + typename EpilogueTraits::Params params, + AccumulatorType *ptr_Product, + int ldm, + cutlass::Coord<3> problem_size) { + + // Shared memoryallocation + __shared__ typename EpilogueTraits::SharedStorage shared_storage; + + // Construct the epilogue + cutlass::gemm::MMAEpilogue epilogue(params, shared_storage, problem_size); + + // Initialize accumulators + typedef typename EpilogueTraits::Accumulators Accumulators; + + typedef typename cutlass::gemm::Volta884NaiveEpilogue< + AccumulatorType, + typename EpilogueTraits::WarpDelta, + cutlass::Shape<2,2,2,2> > NaiveEpilogue; + + Accumulators accumulators; + + // Artificially load accumulators with some random matrix product + NaiveEpilogue naive(ptr_Product, ldm); + naive.load(accumulators); + + // Store the accumulators + epilogue.epilogue(accumulators); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ScalarC, + /// Specifies the delta between warp accesses along the outer dimension + typename WarpDelta +> +struct Volta884EpilogueTestbed { + + // + // Type definitions + // + + /// Warp-level tile + typedef cutlass::Shape<4, 64, 64> WarpGemmTile; + + /// Thread-block scoped tile + typedef typename cutlass::ShapeMul< + WarpGemmTile, + WarpDelta + >::Shape OutputTile; + + /// Multiply-add operation + typedef cutlass::gemm::Volta884MultiplyAdd< + WarpGemmTile, + cutlass::MatrixLayout::kColumnMajor, + half, + cutlass::MatrixLayout::kRowMajor, + half, + ScalarC + > MultiplyAdd; + + // + // Parameters for the epilogue + // + + /// Epilogue functor + typedef cutlass::gemm::LinearScaling Functor; + + /// Traits for global tile access + typedef cutlass::gemm::Volta884EpilogueGlobalTileTraits< + WarpGemmTile, + WarpDelta, + 1, + ScalarC + > EpilogueGlobalTileTraits; + + + /// Defines traits for an epilogue of a Volta884 GEMM + typedef cutlass::gemm::Volta884EpilogueTraits< + OutputTile, + WarpGemmTile, + WarpDelta, + typename MultiplyAdd::Accumulators, + cutlass::gemm::Volta884SelectAccumulators< + WarpGemmTile, + WarpDelta, + ScalarC + >, + cutlass::PredicatedTileLoadStream< + cutlass::TileLoadIterator< + EpilogueGlobalTileTraits, + ScalarC, + cutlass::IteratorAdvance::kH, + cutlass::MemorySpace::kGlobal + >, + cutlass::gemm::Volta884EpiloguePredicateFunctor + >, + cutlass::PredicatedTileStoreStream< + cutlass::TileStoreIterator< + EpilogueGlobalTileTraits, + ScalarC, + cutlass::IteratorAdvance::kH, + cutlass::MemorySpace::kGlobal + >, + cutlass::gemm::Volta884EpiloguePredicateFunctor + >, 
+ cutlass::TileStoreStream< + cutlass::gemm::Volta884EpilogueSharedStoreIterator< + WarpGemmTile, + WarpDelta, + ScalarC, + ScalarC + > + >, + cutlass::TileLoadStream< + cutlass::gemm::Volta884EpilogueSharedLoadIterator< + WarpGemmTile, + WarpDelta, + ScalarC, + 1, + ScalarC + > + >, + Functor + > EpilogueTraits; + + // + // + // + + /// Generates random elements + template + struct RandomGenerator { + RandomGenerator( + int seed = -1 + ) { srand(seed); } + + T operator()() { + int val = (rand() % 29) - 13; + return T(val); + } + }; + + typedef typename cutlass::TypeTraits::host_type ScalarCHost; + + // + // Data members + // + + /// Input accumulator matrix + cutlass::HostMatrix tensor_C; + + /// Matrix product + cutlass::HostMatrix tensor_Product; + + /// Reference output + cutlass::HostMatrix tensor_Ref; + + /// Computed output + cutlass::HostMatrix tensor_D; + + // + // Methods + // + + Volta884EpilogueTestbed() { + tensor_C.resize(OutputTile::kW, OutputTile::kH, cutlass::MatrixLayout::kColumnMajor); + tensor_Product.resize(OutputTile::kW, OutputTile::kH, cutlass::MatrixLayout::kColumnMajor); + tensor_Ref.resize(OutputTile::kW, OutputTile::kH, cutlass::MatrixLayout::kColumnMajor); + tensor_D.resize_matrix(OutputTile::kW, OutputTile::kH, cutlass::MatrixLayout::kColumnMajor); + } + + /// Runs a test case + bool run() { + + tensor_C.fill_sequential(); + tensor_Product.fill_random(RandomGenerator(17)); + + tensor_D.fill(ScalarCHost(0)); + tensor_Ref.fill(ScalarCHost(0)); + + tensor_C.sync_device(); + tensor_Product.sync_device(); + tensor_D.sync_device(); + + // run kernel + dim3 grid(1, 1); + dim3 block(32 * cutlass::ShapeCount::kCount, 1, 1); + + typename EpilogueTraits::Params params; + + params.load_stream_c.iterator.initialize( + tensor_C.device_data(), + tensor_C.leading_dim(), + tensor_C.leading_dim(), + 1); + + params.store_stream_d.iterator.initialize( + tensor_D.device_data(), + tensor_D.leading_dim(), + tensor_D.leading_dim(), + 1); + + ScalarCHost alpha = 2; + ScalarCHost beta = 1; + + params.functor.initialize(alpha, beta); + + cutlass::Coord<3> problem_size = cutlass::make_Coord( + 128, + 64 * EpilogueTraits::WarpDelta::kH - 7, + 64 * EpilogueTraits::WarpDelta::kW - 5); + + test_volta884_epilogue<<< grid, block >>>( + params, + tensor_Product.device_data(), + tensor_Product.leading_dim(), + problem_size + ); + + EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); + + // Copy to host + tensor_D.sync_host(); + + // Compute reference based on alpha, beta, and the problem dimensions + for (int j = 0; j < OutputTile::kH; ++j) { + for (int i = 0; i < OutputTile::kW; ++i) { + if (j < problem_size[1] && i < problem_size[2]) { + tensor_Ref.host_data()[i + j * tensor_Ref.leading_dim()] = + alpha * tensor_Product.host_data()[i + j * tensor_Product.leading_dim()] + + beta * tensor_C.host_data()[i + j * tensor_C.leading_dim()]; + } + } + } + + // Verify result + bool passed = tensor_D.bit_equals(tensor_Ref); + + if (!passed) { + std::cout << "Mismatch:\n" + << "Product = \n" << tensor_Product << "\n\n" + << "C =\n" << tensor_C << "\n\n" + << "Reference =\n" << tensor_Ref << "\n\n" + << "D =\n" << tensor_D << std::endl; + } + + return passed; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_epilogue_f32, 64x64x32) { + + Volta884EpilogueTestbed< + float, + cutlass::Shape<1, 1, 1, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + + +TEST(volta884_epilogue_f32, 64x128x32) { + + Volta884EpilogueTestbed< + float, + 
cutlass::Shape<1, 2, 1, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + + +TEST(volta884_epilogue_f32, 128x64x32) { + + Volta884EpilogueTestbed< + float, + cutlass::Shape<1, 1, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +TEST(volta884_epilogue_f32, 128x128x32) { + + Volta884EpilogueTestbed< + float, + cutlass::Shape<1, 2, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + + +TEST(volta884_epilogue_f32, 256x128x32) { + + Volta884EpilogueTestbed< + float, + cutlass::Shape<1, 2, 4, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +TEST(volta884_epilogue_f32, 128x256x32) { + + Volta884EpilogueTestbed< + float, + cutlass::Shape<1, 4, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_epilogue_f16, 64x64x32) { + + Volta884EpilogueTestbed< + half, + cutlass::Shape<1, 1, 1, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_epilogue_f16, 128x64x32) { + + Volta884EpilogueTestbed< + half, + cutlass::Shape<1, 1, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_epilogue_f16, 64x128x32) { + + Volta884EpilogueTestbed< + half, + cutlass::Shape<1, 2, 1, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_epilogue_f16, 128x128x32) { + + Volta884EpilogueTestbed< + half, + cutlass::Shape<1, 2, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_epilogue_f16, 256x128x32) { + + Volta884EpilogueTestbed< + half, + cutlass::Shape<1, 2, 4, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_epilogue_f16, 128x256x32) { + + Volta884EpilogueTestbed< + half, + cutlass::Shape<1, 4, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ENABLE_TENSOR_CORE_MMA + +// clang-format on diff --git a/tools/test/unit/gemm/volta884_gemm_threadblock_swizzle.cu b/tools/test/unit/gemm/volta884_gemm_threadblock_swizzle.cu new file mode 100644 index 0000000000..a0e45ad445 --- /dev/null +++ b/tools/test/unit/gemm/volta884_gemm_threadblock_swizzle.cu @@ -0,0 +1,496 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include "cutlass_unit_test.h" + +#include "tools/util/half.h" +#include "tools/util/host_tensor.h" +#include "tools/util/tensor_view_io.h" + +#include "cutlass/gemm/volta884_gemm_traits.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock_swizzle.h" +#include "cutlass/gemm/linear_scaling.h" + +#include "tools/test/unit/gemm/gemm_testbed.h" +#include "tools/test/unit/gemm/run_gemm.h" + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Very small warp sizes +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_nn_swizzle, short_480x280x224_rowMajorSwizzle) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_nn_swizzle, short_480x280x224_rowMajorSwizzle_groupCol2) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_nn_swizzle, short_480x280x224_rowMajorSwizzle_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_nn_swizzle, short_480x280x224_rowMajorSwizzle_groupCol2_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 
64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_nn_swizzle, short_480x280x224_columnMajorSwizzle) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_nn_swizzle, short_480x280x224_columnMajorSwizzle_groupCol2) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_nn_swizzle, short_480x280x224_columnMajorSwizzle_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_nn_swizzle, short_480x280x224_columnMajorSwizzle_groupCol2_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_tt_swizzle, short_480x280x224_rowMajorSwizzle) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_tt_swizzle, short_480x280x224_rowMajorSwizzle_groupCol2) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_tt_swizzle, 
short_480x280x224_rowMajorSwizzle_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_tt_swizzle, short_480x280x224_rowMajorSwizzle_groupCol2_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_tt_swizzle, short_480x280x224_columnMajorSwizzle) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_tt_swizzle, short_480x280x224_columnMajorSwizzle_groupCol2) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_tt_swizzle, short_480x280x224_columnMajorSwizzle_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_64x64x32_32x32x32_tt_swizzle, short_480x280x224_columnMajorSwizzle_groupCol2_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 32, 32>, + float, + float, + float, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// FP32 accumulation, FP16 output +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_f16_128x128x32_nn_swizzle, 480x280x224_rowMajorSwizzle) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 
128, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_128x128x32_nn_swizzle, 480x280x224_rowMajorSwizzle_groupCol2) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_128x128x32_nn_swizzle, 480x280x224_rowMajorSwizzle_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_128x128x32_nn_swizzle, 480x280x224_rowMajorSwizzle_groupCol2_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_f16_s884gemm_f16_128x128x32_nn_swizzle, 480x280x224_columnMajorSwizzle) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_128x128x32_nn_swizzle, 480x280x224_columnMajorSwizzle_groupCol2) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_128x128x32_nn_swizzle, 480x280x224_columnMajorSwizzle_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm(480, 280, 224); +} + +TEST(Volta884_f16_s884gemm_f16_128x128x32_nn_swizzle, 480x280x224_columnMajorSwizzle_groupCol2_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + 
cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + float, + half, + half, + 2, + cutlass::gemm::LinearScaling, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm(480, 280, 224); +} + + +#endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) diff --git a/tools/test/unit/gemm/volta884_h884gemm.cu b/tools/test/unit/gemm/volta884_h884gemm.cu new file mode 100644 index 0000000000..6a858aa543 --- /dev/null +++ b/tools/test/unit/gemm/volta884_h884gemm.cu @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#include +#include +#include "cutlass_unit_test.h" + +#include "tools/util/half.h" +#include "tools/util/host_tensor.h" +#include "tools/util/tensor_view_io.h" + +#include "cutlass/gemm/volta884_gemm_traits.h" +#include "cutlass/gemm/gemm.h" + +#include "tools/test/unit/gemm/gemm_testbed.h" +#include "tools/test/unit/gemm/run_gemm.h" + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Contiguous - h884gemm +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_64x64x32_nt, 520x264x136) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_gemm(520, 264, 136); +} + +TEST(Volta884_h884gemm_128x64x32_nt, 520x264x136) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_gemm(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x128x32_nt, 520x264x136) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_gemm(520, 264, 136); +} + +TEST(Volta884_h884gemm_128x128x32_nt, 520x264x136) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_gemm(520, 264, 136); +} + +TEST(Volta884_h884gemm_256x128x32_nt, 520x264x136) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 128, 256>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_gemm(520, 264, 136); +} + +TEST(Volta884_h884gemm_128x256x32_nt, 520x264x136) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 256, 128>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_gemm(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_tn, 520x264x136) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_gemm(520, 264, 136); +} + +TEST(Volta884_h884gemm_128x64x32_tn, 520x264x136) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_gemm(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x128x32_tn, 520x264x136) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + 
cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_128x128x32_tn, 520x264x136) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_256x128x32_tn, 520x264x136) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 128, 256>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_128x256x32_tn, 520x264x136) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 256, 128>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2 + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +#endif // #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + +#endif // defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) + diff --git a/tools/test/unit/gemm/volta884_h884gemm_threadblock_swizzle.cu b/tools/test/unit/gemm/volta884_h884gemm_threadblock_swizzle.cu new file mode 100644 index 0000000000..29babc8050 --- /dev/null +++ b/tools/test/unit/gemm/volta884_h884gemm_threadblock_swizzle.cu @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ * + **************************************************************************************************/ +#include +#include +#include "cutlass_unit_test.h" + +#include "tools/util/half.h" +#include "tools/util/host_tensor.h" +#include "tools/util/tensor_view_io.h" + +#include "cutlass/gemm/volta884_gemm_traits.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock_swizzle.h" +#include "cutlass/gemm/linear_scaling.h" + +#include "tools/test/unit/gemm/gemm_testbed.h" +#include "tools/test/unit/gemm/run_gemm.h" + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Contiguous - h884gemm +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_64x64x32_nt_swizzle, 520x264x136_RowMajorSwizzle) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_nt_swizzle, 520x264x136_RowMajorSwizzle_groupCol2) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_nt_swizzle, 520x264x136_RowMajorSwizzle_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_nt_swizzle, 520x264x136_RowMajorSwizzle_groupCol2_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_64x64x32_nt_swizzle, 520x264x136_ColumnMajorSwizzle) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_nt_swizzle, 520x264x136_ColumnMajorSwizzle_groupCol2) { + + 
typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_nt_swizzle, 520x264x136_ColumnMajorSwizzle_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_nt_swizzle, 520x264x136_ColumnMajorSwizzle_groupCol2_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_64x64x32_tn_swizzle, 520x264x136_RowMajorSwizzle) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_tn_swizzle, 520x264x136_RowMajorSwizzle_groupCol2) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_tn_swizzle, 520x264x136_RowMajorSwizzle_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::RowMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_tn_swizzle, 520x264x136_RowMajorSwizzle_groupCol2_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::RowMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Volta884_h884gemm_64x64x32_tn_swizzle, 
520x264x136_ColumnMajorSwizzle) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_tn_swizzle, 520x264x136_ColumnMajorSwizzle_groupCol2) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::OneDirection> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_tn_swizzle, 520x264x136_ColumnMajorSwizzle_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::ColumnMajorBlockSwizzle<1, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +TEST(Volta884_h884gemm_64x64x32_tn_swizzle, 520x264x136_ColumnMajorSwizzle_groupCol2_Boustrophedon) { + + typedef cutlass::gemm::Volta884GemmTraits< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<32, 64, 64>, + half, + half, + half, + 2, + cutlass::gemm::LinearScaling<half>, + typename cutlass::gemm::ColumnMajorBlockSwizzle<2, cutlass::gemm::swizzleDirection::Boustrophedon> + > GemmTraits; + + run_gemm<GemmTraits>(520, 264, 136); +} + +#endif // #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + +#endif // defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) + diff --git a/tools/test/unit/gemm/volta884_multiplicand.cu b/tools/test/unit/gemm/volta884_multiplicand.cu new file mode 100644 index 0000000000..94c09333a3 --- /dev/null +++ b/tools/test/unit/gemm/volta884_multiplicand.cu @@ -0,0 +1,692 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include +#include "cutlass_unit_test.h" + +#include "tools/util/half.h" +#include "tools/util/host_matrix.h" +#include "tools/util/tensor_view_io.h" + +#include "cutlass/gemm/volta884_multiplicand.h" +#include "cutlass/gemm/volta884_multiply_add.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if CUTLASS_ENABLE_TENSOR_CORE_MMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Simplified GEMM: computes one threadblock-scoped matrix product. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to verify a tile of data loaded from GMEM, stored to SMEM, and loaded into RF computes +/// the expected mma.sync product +template < + typename MultiplicandA, + typename MultiplicandB, + typename ScalarC +> +__global__ void test_volta884_matrix_product( + typename MultiplicandA::LoadIterator::Params load_A_params, + typename MultiplicandB::LoadIterator::Params load_B_params, + float *C, + int ldc, + int active_k_idx) { + + // Define thread-block scoped load iterators + typename MultiplicandA::LoadIterator load_A_iterator(load_A_params); + typename MultiplicandB::LoadIterator load_B_iterator(load_B_params); + + + // Define shared memory buffers + static int const kSmemAElements = + cutlass::ShapeCount::kCount; + + static int const kSmemBElements = + cutlass::ShapeCount::kCount; + + __shared__ uint16_t smem_A_buffer[kSmemAElements]; + __shared__ uint16_t smem_B_buffer[kSmemBElements]; + + + // Instantiate thread-block-scoped store iterators + typename MultiplicandA::StoreIterator::Params store_A_params(reinterpret_cast(&smem_A_buffer[0])); + typename MultiplicandB::StoreIterator::Params store_B_params(reinterpret_cast(&smem_B_buffer[0])); + + typename MultiplicandA::StoreIterator store_A_iterator(store_A_params); + typename MultiplicandB::StoreIterator store_B_iterator(store_B_params); + + + // Load thread-block scoped fragments + typename MultiplicandA::LoadIterator::Fragment threadblock_A_frag; + typename MultiplicandB::LoadIterator::Fragment threadblock_B_frag; + + __syncthreads(); + + // A operand + load_A_iterator.load(threadblock_A_frag); + store_A_iterator.store(threadblock_A_frag); + + // Barrier to enforce SMEM consistency + __syncthreads(); + + // B operand + load_B_iterator.load(threadblock_B_frag); + store_B_iterator.store(threadblock_B_frag); + + + // Barrier to enforce SMEM consistency + __syncthreads(); + + // Instantiate warp-scoped load iterators + typename MultiplicandA::WarpLoadIterator::Params warp_A_params(reinterpret_cast(&smem_A_buffer[0])); + typename MultiplicandB::WarpLoadIterator::Params warp_B_params(reinterpret_cast(&smem_B_buffer[0])); + + typename MultiplicandA::WarpLoadIterator warp_load_A(warp_A_params); + typename MultiplicandB::WarpLoadIterator 
warp_load_B(warp_B_params); + + // Instantiate a multiply-add object specialized for Volta mma.sync + typedef cutlass::gemm::Volta884MultiplyAdd< + typename MultiplicandA::WarpTile, + MultiplicandA::kLayout, + half, + MultiplicandB::kLayout, + half, + ScalarC + > MultiplyAdd; + + typedef cutlass::gemm::Volta884NaiveEpilogue< + ScalarC, + typename MultiplicandA::WarpDelta, + typename MultiplyAdd::Iterations + > NaiveEpilogue; + + MultiplyAdd multiply_add; + NaiveEpilogue epilogue(C, ldc); + + // Initialize accumulator fragment + typename MultiplyAdd::Accumulators accumulators; + + + for (int i = 0; i < MultiplyAdd::Accumulators::kElements; ++i) { + accumulators[i] = threadIdx.x; + } + + epilogue.clear(accumulators); + + // Iterate over the K dimension of the threadblock tile + #pragma unroll + for (int k_idx = 0; k_idx < MultiplicandA::Tile::kD / MultiplyAdd::WarpTile::kD; ++k_idx) { + + if (active_k_idx < 0 || active_k_idx == k_idx) { + typename MultiplicandA::WarpLoadIterator::Fragment warp_A_frag; + typename MultiplicandB::WarpLoadIterator::Fragment warp_B_frag; + + // Load warp-scoped fragments + warp_load_A.load(warp_A_frag, cutlass::make_Coord(k_idx, 0, 0, 0)); + warp_load_B.load(warp_B_frag, cutlass::make_Coord(k_idx, 0, 0, 0)); + + // Compute accumulated matrix product + multiply_add.multiply_add(warp_A_frag, warp_B_frag, accumulators, accumulators); + } + } + + // Store accumulator tile + epilogue.store(accumulators); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Identifies multiplicand of GEMM (A or B) + cutlass::MatrixLayout::Kind LayoutA, + /// Specifies layout of data in source memory + cutlass::MatrixLayout::Kind LayoutB, + /// Accumulator type + typename ScalarC, + /// Specifies threadblock tile shape + typename Tile, + /// Specifies the warp tile shape + typename WarpTile, + /// Specifies the number of participating warps + int WarpCount, + /// Specifies the delta between warp accesses along the outer dimension + typename WarpDelta +> +struct Volta884MatrixProductTestbed { + + // + // Type definitions + // + + typedef cutlass::gemm::Volta884Multiplicand< + cutlass::GemmOperand::kA, + LayoutA, + Tile, + WarpTile, + WarpCount, + WarpDelta> MultiplicandA; + + typedef cutlass::gemm::Volta884Multiplicand< + cutlass::GemmOperand::kB, + LayoutB, + Tile, + WarpTile, + WarpCount, + WarpDelta> MultiplicandB; + + /// Generates random elements + template + struct RandomGenerator { + RandomGenerator( + int seed = -1 + ) { srand(seed); } + + T operator()() { + int val = (rand() % 29) - 13; + return T(val); + } + }; + + /// Depth of an mma.sync instruction + static int const kWarpK = 4; + + // + // Data members + // + + cutlass::HostMatrix tensor_A; + cutlass::HostMatrix tensor_B; + cutlass::HostMatrix tensor_C; + cutlass::HostMatrix tensor_Ref; + + // + // Methods + // + + Volta884MatrixProductTestbed() { + + tensor_A.resize(cutlass::make_Coord(Tile::kW, Tile::kD), LayoutA); + tensor_B.resize(cutlass::make_Coord(Tile::kD, Tile::kH), LayoutB); + tensor_C.resize(cutlass::make_Coord(Tile::kW, Tile::kH), cutlass::MatrixLayout::kColumnMajor); + tensor_Ref.resize(cutlass::make_Coord(Tile::kW, Tile::kH), cutlass::MatrixLayout::kColumnMajor); + + } + + /// Runs a test case + bool run_once(int seed, int active_k_idx = -1) { + + #if 0 + // For debugging, it helps to see sequential elements + tensor_A.fill_sequential(); + tensor_B.fill_identity(); + #else + // Fill with random elements + 
tensor_A.fill_random(RandomGenerator(seed + 53)); + tensor_B.fill_random(RandomGenerator(seed + 97)); + #endif + + if (active_k_idx >= 0) { + // overwrite all but the active k index with zeros + int const m_stride = (LayoutA == cutlass::MatrixLayout::kRowMajor ? Tile::kD : 1); + int const a_k_stride = (LayoutA == cutlass::MatrixLayout::kRowMajor ? 1 : Tile::kW); + + int const n_stride = (LayoutB == cutlass::MatrixLayout::kRowMajor ? 1 : Tile::kD); + int const b_k_stride = (LayoutB == cutlass::MatrixLayout::kRowMajor ? Tile::kH : 1); + + for (int k_idx = 0; k_idx < Tile::kD / kWarpK; ++k_idx) { + if (active_k_idx != k_idx) { + + for (int k = 0; k < kWarpK; ++k) { + for (int m = 0; m < Tile::kW; ++m) { + tensor_A.host_data()[m_stride * m + a_k_stride * (k_idx * kWarpK + k)] = 0; + } + for (int n = 0; n < Tile::kH; ++n) { + tensor_B.host_data()[n_stride * n + b_k_stride * (k_idx * kWarpK + k)] = 0; + } + } + } + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + + tensor_C.fill(ScalarC(0)); + tensor_Ref.fill(ScalarC(0)); + tensor_C.sync_device(); + + // run kernel + dim3 grid(1, 1); + dim3 block(32 * WarpCount, 1, 1); + + typename MultiplicandA::LoadIterator::Params load_A_params( + tensor_A.device_data(), + tensor_A.leading_dim() * 8, + tensor_A.leading_dim(), + 8 + ); + + typename MultiplicandB::LoadIterator::Params load_B_params( + tensor_B.device_data(), + tensor_B.leading_dim() * 8, + tensor_B.leading_dim(), + 8 + ); + + test_volta884_matrix_product<<< grid, block >>>( + load_A_params, + load_B_params, + tensor_C.device_data(), + tensor_C.leading_dim(), + active_k_idx + ); + + EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); + + // Copy to host + tensor_C.sync_host(); + + // Compute reference + cutlass::reference::host::Gemm( + cutlass::gemm::GemmCoord( + tensor_A.size().column(), + tensor_Ref.size().column(), + tensor_Ref.size().row()), + ScalarC(1), + tensor_A, + tensor_B, + ScalarC(0), + tensor_Ref, + ScalarC(0)); + + // Assert bit-level equivalence + bool passed = tensor_Ref.bit_equals(tensor_C); + + EXPECT_TRUE(passed) + << "Incorrect matrix product\n" + << "A =\n" << tensor_A + << "\nB =\n" << tensor_B + << "\nRef =\n" << tensor_Ref + << "\nMMA=\n" << tensor_C; + + return passed; + } + + /// Executes a set of test cases containing unique, randomly chosen matrices and verifies + /// bit equivalence with the reference implementation. + bool run(int test_count = 16) { + + bool passed = true; + + #if 1 + // Run several tests with deterministic seeds + for (int i = 0; i < test_count && passed; ++i) { + passed = run_once(i * 41 + i * 17); + } + + #else + // For debugging, run the full matrix product with exactly one K-index non-zero + for (int k_idx = 0; passed && k_idx < Tile::kD / kWarpK; ++k_idx) { + passed = run_once(17, k_idx); + if (!passed) { + std::cout << "Failed on k_idx = " << k_idx + << " [" << k_idx * kWarpK << ".." 
<< (k_idx + 1) * kWarpK - 1 << "]" << std::endl; + } + } + #endif + + return passed; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 64x64x32, 128x64x32, 64x128x32, 128x128x32, 256x128x32, 128x256x32, 64x64x128 +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Congruous loading +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_nt, 64x64x32_32x32x4) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + float, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<4, 32, 32>, + 4, + cutlass::Shape<1, 2, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_nt, 128x64x32_64x32x4) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + float, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<4, 32, 64>, + 4, + cutlass::Shape<1, 2, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_nt, 64x128x32_32x64x4) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + float, + cutlass::Shape<32, 128, 64>, + cutlass::Shape<4, 64, 32>, + 4, + cutlass::Shape<1, 2, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_nt, 64x64x32) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + float, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<4, 64, 64>, + 1, + cutlass::Shape<1, 1, 1, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_nt, 64x64x128) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + float, + cutlass::Shape<128, 64, 64>, + cutlass::Shape<4, 64, 64>, + 1, + cutlass::Shape<1, 1, 1, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_nt, 128x64x32) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + float, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<4, 64, 64>, + 2, + cutlass::Shape<1, 1, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_nt, 64x128x32) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + float, + cutlass::Shape<32, 128, 64>, + cutlass::Shape<4, 64, 64>, + 2, + cutlass::Shape<1, 2, 1, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+TEST(volta884_matrix_product_nt, 128x128x32) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + float, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<4, 64, 64>, + 4, + cutlass::Shape<1, 2, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_nt, 256x128x32) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + float, + cutlass::Shape<32, 128, 256>, + cutlass::Shape<4, 64, 64>, + 8, + cutlass::Shape<1, 2, 4, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_nt, 128x256x32) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kColumnMajor, + cutlass::MatrixLayout::kRowMajor, + float, + cutlass::Shape<32, 256, 128>, + cutlass::Shape<4, 64, 64>, + 8, + cutlass::Shape<1, 4, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Crosswise loading +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +TEST(volta884_matrix_product_tn, 64x64x32_32x32x4) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + float, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<4, 32, 32>, + 4, + cutlass::Shape<1, 2, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_tn, 128x64x32_64x32x4) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + float, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<4, 32, 64>, + 4, + cutlass::Shape<1, 2, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_tn, 64x128x32_32x64x4) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + float, + cutlass::Shape<32, 128, 64>, + cutlass::Shape<4, 64, 32>, + 4, + cutlass::Shape<1, 2, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_tn, 64x64x32) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + float, + cutlass::Shape<32, 64, 64>, + cutlass::Shape<4, 64, 64>, + 1, + cutlass::Shape<1, 1, 1, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_tn, 128x64x32) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + float, + cutlass::Shape<32, 64, 128>, + cutlass::Shape<4, 64, 64>, + 2, + cutlass::Shape<1, 1, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_tn, 128x128x32) { + + Volta884MatrixProductTestbed< + 
cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + float, + cutlass::Shape<32, 128, 128>, + cutlass::Shape<4, 64, 64>, + 4, + cutlass::Shape<1, 2, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_tn, 256x128x32) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + float, + cutlass::Shape<32, 128, 256>, + cutlass::Shape<4, 64, 64>, + 8, + cutlass::Shape<1, 2, 4, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(volta884_matrix_product_tn, 128x256x32) { + + Volta884MatrixProductTestbed< + cutlass::MatrixLayout::kRowMajor, + cutlass::MatrixLayout::kColumnMajor, + float, + cutlass::Shape<32, 256, 128>, + cutlass::Shape<4, 64, 64>, + 8, + cutlass::Shape<1, 4, 2, 1> + > testbed; + + EXPECT_TRUE(testbed.run()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) diff --git a/tools/test/unit/gemm/wmma_gemm.cu b/tools/test/unit/gemm/wmma_gemm.cu index bb9412515d..a6f567dfd3 100644 --- a/tools/test/unit/gemm/wmma_gemm.cu +++ b/tools/test/unit/gemm/wmma_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/wmma_gemm_non_multiple16.cu b/tools/test/unit/gemm/wmma_gemm_non_multiple16.cu index 0dfa4107ef..c5d8c33005 100644 --- a/tools/test/unit/gemm/wmma_gemm_non_multiple16.cu +++ b/tools/test/unit/gemm/wmma_gemm_non_multiple16.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/gemm/wmma_integer_gemm.cu b/tools/test/unit/gemm/wmma_integer_gemm.cu index 857408c866..0baa6fd7ec 100644 --- a/tools/test/unit/gemm/wmma_integer_gemm.cu +++ b/tools/test/unit/gemm/wmma_integer_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -23,7 +23,7 @@ * **************************************************************************************************/ #include "cutlass/wmma_matrix.h" -#ifdef CUTLASS_USE_SUBBYTE_WMMA +#ifdef CUTLASS_USE_WMMA_API #include "cutlass_unit_test.h" #include "cutlass/gemm/gemm.h" @@ -44,6 +44,7 @@ - Shapes should be specified as MxNxK (opposite to the Shape<> definition which is KxNxM) */ +#ifdef CUTLASS_USE_SUBBYTE_WMMA //////////////////////////////////////////////////////////////////////////////////////////////////// // // S4 Integer GEMM Unit Tests @@ -112,6 +113,8 @@ TEST(WmmaInt4Gemm_32x32x64_8x8x32_u4, wmma_integer_gemm_32x32x64) { WmmaGemmTraits; run_integer_gemm(32, 32, 64); } +#endif //ifdef CUTLASS_USE_SUBBYTE_WMMA +#ifdef CUTLASS_USE_INT_WMMA //////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -627,4 +630,5 @@ TEST(WmmaInt8Gemm_32x32x32_8x32x16_u8_nn, wmma_integer_gemm_32x32x32) { //////////////////////////////////////////////////////////////////////////////////////////////////// -#endif // ifdef CUTLASS_USE_SUBBYTE_WMMA +#endif // ifdef CUTLASS_USE_INT_WMMA +#endif // ifdef CUTLASS_USE_WMMA_API diff --git a/tools/test/unit/reduction/batched_reduction.cu b/tools/test/unit/reduction/batched_reduction.cu index 4bed73d4cd..90cad919c4 100644 --- a/tools/test/unit/reduction/batched_reduction.cu +++ b/tools/test/unit/reduction/batched_reduction.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/reduction/batched_reduction_testbed.h b/tools/test/unit/reduction/batched_reduction_testbed.h index c5db28eeef..5478c96172 100644 --- a/tools/test/unit/reduction/batched_reduction_testbed.h +++ b/tools/test/unit/reduction/batched_reduction_testbed.h @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/reduction/mixed_batched_reduction.cu b/tools/test/unit/reduction/mixed_batched_reduction.cu index 3ea66a58c3..3956acde3a 100644 --- a/tools/test/unit/reduction/mixed_batched_reduction.cu +++ b/tools/test/unit/reduction/mixed_batched_reduction.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/reduction/test_batched_reduction.h b/tools/test/unit/reduction/test_batched_reduction.h index ffba8e1334..a967ef9118 100644 --- a/tools/test/unit/reduction/test_batched_reduction.h +++ b/tools/test/unit/reduction/test_batched_reduction.h @@ -1,5 +1,5 @@ /*************************************************************************************************** -* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/tile_iterator_test.cu b/tools/test/unit/tile_iterator_test.cu index 6782c18029..8d395f50c8 100644 --- a/tools/test/unit/tile_iterator_test.cu +++ b/tools/test/unit/tile_iterator_test.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -39,7 +39,6 @@ using ::cutlass::TileTraits; using ::testing::Test; -// TODO: Move the following to standard test helper infrastructure // Returns randomly initialized array // // Caller is responsible for deallocation. @@ -90,10 +89,6 @@ TEST(TileIteratorTest, BasicCpuSideIterateTile) { TileThreadOffset, /*AccessSize=*/1>, float, IteratorAdvance::kH, MemorySpace::kGlobal> GlobalTileLoader; typedef GlobalTileLoader::Fragment BufferType; - // - // TODO: The following loop should probably be refactored out into standard test helper code for - // tile iteration. - // // Iterate: gridDim(1, 1, kDimX / kDimXPerWarp), blockDim(1, kDimXPerWarp, kDimYPerWarp) for (int blockIdx_x = 0; blockIdx_x < kDimX / kDimXPerWarp; blockIdx_x++) { for (int threadIdx_x = 0; threadIdx_x < kDimXPerWarp; threadIdx_x++) { diff --git a/tools/test/unit/util/complex.cu b/tools/test/unit/util/complex.cu index 12d840fdbe..e4867e19e3 100644 --- a/tools/test/unit/util/complex.cu +++ b/tools/test/unit/util/complex.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/util/host_tensor.cu b/tools/test/unit/util/host_tensor.cu index ce3b22489d..921d5d0700 100644 --- a/tools/test/unit/util/host_tensor.cu +++ b/tools/test/unit/util/host_tensor.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/util/tensor_elementwise.cu b/tools/test/unit/util/tensor_elementwise.cu index a983a4f4c1..1330e0793c 100644 --- a/tools/test/unit/util/tensor_elementwise.cu +++ b/tools/test/unit/util/tensor_elementwise.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/test/unit/util/tensor_foreach.cu b/tools/test/unit/util/tensor_foreach.cu index dcb9659872..33b6ef7caf 100644 --- a/tools/test/unit/util/tensor_foreach.cu +++ b/tools/test/unit/util/tensor_foreach.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/command_line.h b/tools/util/command_line.h index d4bb96fea6..73db6f087a 100644 --- a/tools/util/command_line.h +++ b/tools/util/command_line.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are not permitted. diff --git a/tools/util/device_memory.h b/tools/util/device_memory.h index 0aa0532cba..813a8a3210 100644 --- a/tools/util/device_memory.h +++ b/tools/util/device_memory.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are not permitted. diff --git a/tools/util/distribution.h b/tools/util/distribution.h index 1c2701fc3b..42a531768d 100644 --- a/tools/util/distribution.h +++ b/tools/util/distribution.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/exceptions.h b/tools/util/exceptions.h index 3683fbf4fd..2d26309b5e 100644 --- a/tools/util/exceptions.h +++ b/tools/util/exceptions.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are not permitted. 
diff --git a/tools/util/half.h b/tools/util/half.h index 91e8b11301..4d5b574335 100644 --- a/tools/util/half.h +++ b/tools/util/half.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/host_matrix.h b/tools/util/host_matrix.h index 9812f757dc..23e7e5d4c0 100644 --- a/tools/util/host_matrix.h +++ b/tools/util/host_matrix.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/host_matrix_view.h b/tools/util/host_matrix_view.h index 84767878cb..1fea91be0d 100644 --- a/tools/util/host_matrix_view.h +++ b/tools/util/host_matrix_view.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: diff --git a/tools/util/host_tensor.h b/tools/util/host_tensor.h index fc042b0b7e..1fbb493a6e 100644 --- a/tools/util/host_tensor.h +++ b/tools/util/host_tensor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -114,29 +114,29 @@ class HostTensor : public HostTensorView< typedef typename DeviceTensorView::ConstTensorView ConstDeviceTensorView; /// Tensor reference to host memory - typedef typename Base::TensorRef TensorRef; + typedef typename Base::TensorRef_t TensorRef_t; /// Tensor view to host memory - typedef TensorView< + typedef HostTensorView< typename TypeTraits::host_type, Rank_, MapFunc_, StorageRank_, Index_, - LongIndex_> HostTensorView; + LongIndex_> HostTensorView_t; /// Tensor view to host memory - typedef typename HostTensorView::ConstTensorView ConstHostTensorView; + typedef typename HostTensorView_t::ConstTensorView ConstHostTensorView; /// Coordinate in logical tensor space - typedef typename TensorRef::TensorCoord TensorCoord; + typedef typename TensorRef_t::TensorCoord TensorCoord; /// Coordinate in storage n-D array - typedef typename TensorRef::StorageCoord StorageCoord; + typedef typename TensorRef_t::StorageCoord StorageCoord; /// Stride vector in storage coordinate space /// Least significant stride is = 1 and not stored - typedef typename TensorRef::StrideVector StrideVector; + typedef typename TensorRef_t::StrideVector StrideVector; /// Rank of internal storage. 
static int const kStorageRank = Base::kStorageRank; @@ -216,14 +216,14 @@ class HostTensor : public HostTensorView< host_.resize(_capacity); device_.reset(_device_memory, _capacity); - Base::reset(TensorRef(host_.data(), stride), size); + Base::reset(TensorRef_t(host_.data(), stride), size); } /// Accesses the tensor reference pointing to data - TensorRef host_ref() { return Base::ref(); } + TensorRef_t host_ref() { return Base::ref(); } /// Accesses the tensor reference pointing to data - TensorRef host_ref() const { return Base::ref(); } + TensorRef_t host_ref() const { return Base::ref(); } /// Accesses the tensor reference pointing to data DeviceTensorRef device_ref() const { @@ -231,13 +231,13 @@ class HostTensor : public HostTensorView< } /// Accesses the tensor reference pointing to data - HostTensorView host_view() { - return HostTensorView(host_data(), this->stride(), this->size()); + HostTensorView_t host_view() { + return HostTensorView_t(host_data(), this->stride(), this->size()); } /// Accesses the tensor reference pointing to data ConstHostTensorView host_view() const { - return HostTensorView(host_data(), this->stride(), this->size()); + return HostTensorView_t(host_data(), this->stride(), this->size()); } /// Accesses the tensor reference pointing to data diff --git a/tools/util/host_tensor_view.h b/tools/util/host_tensor_view.h index 4b7f90c744..b4c6aa7bc5 100644 --- a/tools/util/host_tensor_view.h +++ b/tools/util/host_tensor_view.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -61,21 +61,21 @@ class HostTensorView : /// Storage type typedef typename Base::Storage Storage; - /// Alias for underlying TensorRef - typedef typename Base::TensorRef TensorRef; + /// Alias for underlying TensorRef_t + typedef typename Base::TensorRef_t TensorRef_t; /// Index type typedef typename Base::Index Index; /// Coordinate in logical tensor space - typedef typename TensorRef::TensorCoord TensorCoord; + typedef typename TensorRef_t::TensorCoord TensorCoord; /// Coordinate in storage n-D array - typedef typename TensorRef::StorageCoord StorageCoord; + typedef typename TensorRef_t::StorageCoord StorageCoord; /// Stride vector in storage coordinate space /// Least significant stride is = 1 and not stored - typedef typename TensorRef::StrideVector StrideVector; + typedef typename TensorRef_t::StrideVector StrideVector; /// Long index type for pointer offsets typedef typename Base::LongIndex LongIndex; @@ -121,18 +121,18 @@ class HostTensorView : Storage_ *_ptr, StrideVector const &_stride, TensorCoord const& _size - ) : Base(TensorRef(_ptr, _stride), _size) {} + ) : Base(TensorRef_t(_ptr, _stride), _size) {} /// Helper to construct from pointer, stride, and size HostTensorView( Storage_ *_ptr, StorageCoord const &_stride, TensorCoord const& _size - ) : Base(TensorRef(_ptr, _stride), _size) {} + ) : Base(TensorRef_t(_ptr, _stride), _size) {} - /// Constructs a Tensor_view from a TensorRef and size assuming dense packing + /// Constructs a Tensor_view from a TensorRef_t and size assuming dense packing HostTensorView( - TensorRef const& _ref, + TensorRef_t const& _ref, TensorCoord const& _size) : Base(_ref, _size) {} /// Assigns a tensor 
view
@@ -149,22 +149,22 @@ class HostTensorView :
     return result;
   }
 
-  /// Returns a TensorRef offset by a given amount
+  /// Returns a TensorRef_t offset by a given amount
   CUTLASS_HOST_DEVICE
   HostTensorView& operator+=(TensorCoord const& b) {
     this->add_pointer_offset(this->offset(b));
     return *this;
   }
 
-  /// Returns a TensorRef offset by a given amount
+  /// Returns a TensorRef_t offset by a given amount
   CUTLASS_HOST_DEVICE
   HostTensorView operator-(TensorCoord const& b) const {
-    TensorRef result(*this);
+    TensorRef_t result(*this);
     result.add_pointer_offset(-this->offset(b));
     return result;
   }
 
-  /// Returns a TensorRef offset by a given amount
+  /// Returns a TensorRef_t offset by a given amount
   CUTLASS_HOST_DEVICE
   HostTensorView& operator-=(TensorCoord const& b) {
     this->add_pointer_offset(-this->offset(b));
@@ -474,7 +474,7 @@ class HostTensorView :
 
     void operator()(Storage const& element) {
       double value(element);
-      double conj(element); // TODO - conjugates for complex
+      double conj(element);
       sum += value * conj;
     }
 
diff --git a/tools/util/reference/detail/inner_product.h b/tools/util/reference/detail/inner_product.h
index c47cac1e5d..26ebe1e8c1 100644
--- a/tools/util/reference/detail/inner_product.h
+++ b/tools/util/reference/detail/inner_product.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/device/gemm.h b/tools/util/reference/device/gemm.h
index f9cbcab26e..126e8eae8d 100644
--- a/tools/util/reference/device/gemm.h
+++ b/tools/util/reference/device/gemm.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/device/kernel/gemm.h b/tools/util/reference/device/kernel/gemm.h
index 51630cf4c9..201668a9bf 100644
--- a/tools/util/reference/device/kernel/gemm.h
+++ b/tools/util/reference/device/kernel/gemm.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/device/kernel/split_complex_gemm.h b/tools/util/reference/device/kernel/split_complex_gemm.h
index eff2bac075..4be721dd4b 100644
--- a/tools/util/reference/device/kernel/split_complex_gemm.h
+++ b/tools/util/reference/device/kernel/split_complex_gemm.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/device/kernel/tensor_elementwise.h b/tools/util/reference/device/kernel/tensor_elementwise.h
index 31f7a2d8d1..cf47c9a4ea 100644
--- a/tools/util/reference/device/kernel/tensor_elementwise.h
+++ b/tools/util/reference/device/kernel/tensor_elementwise.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/device/kernel/tensor_foreach.h b/tools/util/reference/device/kernel/tensor_foreach.h
index 5396d56188..04d2e7ea66 100644
--- a/tools/util/reference/device/kernel/tensor_foreach.h
+++ b/tools/util/reference/device/kernel/tensor_foreach.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/device/split_complex_gemm.h b/tools/util/reference/device/split_complex_gemm.h
index dd2b817161..c204acffb8 100644
--- a/tools/util/reference/device/split_complex_gemm.h
+++ b/tools/util/reference/device/split_complex_gemm.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/device/tensor_elementwise.h b/tools/util/reference/device/tensor_elementwise.h
index 2b1eb2487a..64fdbf8eb3 100644
--- a/tools/util/reference/device/tensor_elementwise.h
+++ b/tools/util/reference/device/tensor_elementwise.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/device/tensor_foreach.h b/tools/util/reference/device/tensor_foreach.h
index 1c3a72a6cb..6eb7e1795e 100644
--- a/tools/util/reference/device/tensor_foreach.h
+++ b/tools/util/reference/device/tensor_foreach.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/device/thread/gemm.h b/tools/util/reference/device/thread/gemm.h
index 6a8a27952d..05e8262cb9 100644
--- a/tools/util/reference/device/thread/gemm.h
+++ b/tools/util/reference/device/thread/gemm.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/device/thread/split_complex_gemm.h b/tools/util/reference/device/thread/split_complex_gemm.h
index f0005d7264..cb564b199c 100644
--- a/tools/util/reference/device/thread/split_complex_gemm.h
+++ b/tools/util/reference/device/thread/split_complex_gemm.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/host/gemm.h b/tools/util/reference/host/gemm.h
index 31902ac3f3..1f66ce527c 100644
--- a/tools/util/reference/host/gemm.h
+++ b/tools/util/reference/host/gemm.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -219,6 +219,7 @@ void BatchedGemm(
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+
 } // namespace host
 } // namespace reference
 } // namespace cutlass
diff --git a/tools/util/reference/host/split_complex_gemm.h b/tools/util/reference/host/split_complex_gemm.h
index 149fad516c..616ef2ed80 100644
--- a/tools/util/reference/host/split_complex_gemm.h
+++ b/tools/util/reference/host/split_complex_gemm.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/host/tensor_elementwise.h b/tools/util/reference/host/tensor_elementwise.h
index 88f46bcdf8..259db00c87 100644
--- a/tools/util/reference/host/tensor_elementwise.h
+++ b/tools/util/reference/host/tensor_elementwise.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/reference/host/tensor_foreach.h b/tools/util/reference/host/tensor_foreach.h
index bd4455693a..9518aa7f6f 100644
--- a/tools/util/reference/host/tensor_foreach.h
+++ b/tools/util/reference/host/tensor_foreach.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/util/tensor_view_io.h b/tools/util/tensor_view_io.h
index c1b954eae9..51e6df932c 100644
--- a/tools/util/tensor_view_io.h
+++ b/tools/util/tensor_view_io.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
-* Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+* Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
diff --git a/tools/util/type_traits.h b/tools/util/type_traits.h
index f3b1377fb1..7a82e52b84 100644
--- a/tools/util/type_traits.h
+++ b/tools/util/type_traits.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met: