Skip to content

Commit

Permalink
Add couple configs into generator.py for mixed input MM
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsamardzic committed Mar 7, 2024
1 parent ffa34e7 commit e130f14
Show file tree
Hide file tree
Showing 17 changed files with 1,279 additions and 28 deletions.
29 changes: 22 additions & 7 deletions python/cutlass_library/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2205,17 +2205,17 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
math_instructions = [
MathInstruction( \
[16, 8, 16], \
DataType.s8, DataType.f16, DataType.f16, \
DataType.s8, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.s8, DataType.f16, DataType.f32, \
DataType.u8, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.u8, DataType.f16, DataType.f32, \
DataType.s8, DataType.bf16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
Expand All @@ -2225,9 +2225,14 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.s8, DataType.bf16, DataType.f32, \
DataType.s8, DataType.f16, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.u8, DataType.f16, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast), \
]

min_cc = 80
Expand Down Expand Up @@ -2267,7 +2272,7 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)

# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
if math_inst.element_b != math_inst.element_accumulator:

data_type_mixed = [
math_inst.element_a,
Expand Down Expand Up @@ -2302,19 +2307,29 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version):
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.bf16, DataType.s8, DataType.f32, \
DataType.f16, DataType.u8, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.f16, DataType.u8, DataType.f32, \
DataType.bf16, DataType.s8, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.bf16, DataType.u8, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.f16, DataType.s8, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.f16, DataType.u8, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast), \
]

min_cc = 80
Expand Down
36 changes: 32 additions & 4 deletions test/unit/gemm/device/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,42 @@ cutlass_test_unit_add_executable(
BATCH_SIZE 4

# Upcast on Operand A
gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_s8t_f16n_f32t_mixed_input_tensor_op_f32_sm80.cu
gemm_universal_u8t_f16n_f32t_mixed_input_tensor_op_f32_sm80.cu
gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu
gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu

gemm_universal_s8t_bf16n_f32t_mixed_input_tensor_op_f32_sm80.cu
gemm_universal_u8t_bf16n_f32t_mixed_input_tensor_op_f32_sm80.cu
gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu
# The following test could be created by making a copy of
# gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu,
# and then replacing occurrences of "half_t" in the code with
# "bfloat16_t". Such a test would fail, but with only a single value
# differing from the reference.
#gemm_universal_u8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu

gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu

# Upcast on Operand B
gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_f16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu
gemm_universal_f16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu
gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f32_sm80.cu
gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f32_sm80.cu

gemm_universal_bf16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu
gemm_universal_bf16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu
gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu
# The following test could be created by making a copy of
# gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f32_sm80.cu,
# and then replacing occurrences of "half_t" in the code with
# "bfloat16_t". Such a test would fail, but with only a single value
# differing from the reference.
# gemm_universal_bf16t_u8n_bf16t_mixed_input_tensor_op_f32_sm80.cu

gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
)

cutlass_test_unit_add_executable(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/

#include <iostream>

#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"

#include "cutlass/gemm/device/gemm_universal.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "testbed_universal.h"

////////////////////////////////////////////////////////////////////////////////

#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////


TEST(SM80_Device_GemmUniversal_bf16t_s8n_f32t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) {

  // Mixed-input GEMM configuration: bf16 (row-major) x s8 (column-major)
  // -> f32 (row-major) with f32 accumulation, using the mixed-input upcast
  // tensor-op math (the narrower operand is widened in the mainloop).
  using ElementA = cutlass::bfloat16_t;
  using ElementB = int8_t;
  using ElementOutput = float;
  using ElementAccumulator = float;

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;

  // Tile hierarchy: 128x128x64 threadblock, 64x64x64 warp, 16x8x16 instruction.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Plain linear-combination epilogue with 128-bit-wide output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::GemmUniversal<
      ElementA, LayoutA,
      ElementB, LayoutB,
      ElementOutput, LayoutC,
      ElementAccumulator,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4,   // Stages
      8,   // AlignmentA (8 bf16 elements == 16 bytes)
      16,  // AlignmentB (16 s8 elements == 16 bytes)
      cutlass::arch::OpMultiplyAddMixedInputUpcast,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone
  >;

  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////

#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/

#include <iostream>

#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"

#include "cutlass/gemm/device/gemm_universal.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "testbed_universal.h"

////////////////////////////////////////////////////////////////////////////////

#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////


TEST(SM80_Device_GemmUniversal_bf16t_u8n_f32t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) {

  // Mixed-input GEMM configuration: bf16 (row-major) x u8 (column-major)
  // -> f32 (row-major) with f32 accumulation, using the mixed-input upcast
  // tensor-op math (the narrower operand is widened in the mainloop).
  using ElementA = cutlass::bfloat16_t;
  using ElementB = uint8_t;
  using ElementOutput = float;
  using ElementAccumulator = float;

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;

  // Tile hierarchy: 128x128x64 threadblock, 64x64x64 warp, 16x8x16 instruction.
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Plain linear-combination epilogue with 128-bit-wide output accesses.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>;

  using Gemm = cutlass::gemm::device::GemmUniversal<
      ElementA, LayoutA,
      ElementB, LayoutB,
      ElementOutput, LayoutC,
      ElementAccumulator,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4,   // Stages
      8,   // AlignmentA (8 bf16 elements == 16 bytes)
      16,  // AlignmentB (16 u8 elements == 16 bytes)
      cutlass::arch::OpMultiplyAddMixedInputUpcast,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone
  >;

  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////

#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////
Loading

0 comments on commit e130f14

Please sign in to comment.