From a30806ab9cda936a7682223919537d4a281d415e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?=
Date: Tue, 19 Mar 2024 15:45:49 +0100
Subject: [PATCH 1/2] Add support for mixed 4-bit/8-bit data types GEMM

---
 .../gemm/device/default_gemm_configuration.h  |  54 ++++++
 .../gemm/warp/default_mma_tensor_op_sm80.h    |  71 +++++++-
 .../gemm/warp/mma_mixed_input_tensor_op.h     |  14 +-
 include/cutlass/numeric_conversion.h          |  80 +++++++++
 python/cutlass_library/generator.py           | 163 ++++++++++++++++++
 test/unit/core/fast_numeric_conversion.cu     |  24 ++-
 test/unit/gemm/device/CMakeLists.txt          |   6 +
 ...s8n_s32t_mixed_input_tensor_op_s32_sm80.cu |  95 ++++++++++
 ..._s8n_s8t_mixed_input_tensor_op_s32_sm80.cu |  95 ++++++++++
 ...s4n_s32t_mixed_input_tensor_op_s32_sm80.cu |  95 ++++++++++
 ..._s4n_s8t_mixed_input_tensor_op_s32_sm80.cu |  95 ++++++++++
 test/unit/gemm/warp/gemm_mixed_input_sm80.cu  |  48 ++++++
 tools/library/CMakeLists.txt                  |   1 +
 .../src/reference/gemm_int_mixed_input.cu     | 130 ++++++++++++++
 .../initialize_reference_operations.cu        |   3 +
 15 files changed, 960 insertions(+), 14 deletions(-)
 create mode 100644 test/unit/gemm/device/gemm_universal_s4t_s8n_s32t_mixed_input_tensor_op_s32_sm80.cu
 create mode 100644 test/unit/gemm/device/gemm_universal_s4t_s8n_s8t_mixed_input_tensor_op_s32_sm80.cu
 create mode 100644 test/unit/gemm/device/gemm_universal_s8t_s4n_s32t_mixed_input_tensor_op_s32_sm80.cu
 create mode 100644 test/unit/gemm/device/gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu
 create mode 100644 tools/library/src/reference/gemm_int_mixed_input.cu

diff --git a/include/cutlass/gemm/device/default_gemm_configuration.h b/include/cutlass/gemm/device/default_gemm_configuration.h
index 4197a6b080..c9e7cc76d1 100644
--- a/include/cutlass/gemm/device/default_gemm_configuration.h
+++ b/include/cutlass/gemm/device/default_gemm_configuration.h
@@ -793,6 +793,60 @@ struct DefaultGemmConfigurationSm89F8 {
   using Operator = arch::OpMultiplyAdd;
 };
 
+////////////////////////////////////////////////////////////////////////////////
+
+template <
+  typename ElementC>
+struct DefaultGemmConfiguration<
+  arch::OpClassTensorOp,
+  arch::Sm80,
+  int4b_t,
+  int8_t,
+  ElementC,
+  int32_t> {
+
+  static int const kAlignmentA = 128 / sizeof_bits<int4b_t>::value;
+  static int const kAlignmentB = 128 / sizeof_bits<int8_t>::value;
+
+  using ThreadblockShape = GemmShape<128, 256, 64>;
+  using WarpShape = GemmShape<64, 64, 64>;
+  using InstructionShape = GemmShape<16, 8, 32>;
+  static int const kStages = 3;
+
+  using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
+    ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
+
+  using Operator = arch::OpMultiplyAddSaturate;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+template <
+  typename ElementC>
+struct DefaultGemmConfiguration<
+  arch::OpClassTensorOp,
+  arch::Sm80,
+  int8_t,
+  int4b_t,
+  ElementC,
+  int32_t> {
+
+  static int const kAlignmentA = 128 / sizeof_bits<int8_t>::value;
+  static int const kAlignmentB = 128 / sizeof_bits<int4b_t>::value;
+
+  using ThreadblockShape = GemmShape<128, 256, 64>;
+  using WarpShape = GemmShape<64, 64, 64>;
+  using InstructionShape = GemmShape<16, 8, 32>;
+  static int const kStages = 3;
+
+  using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
+    ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
+
+  using Operator = arch::OpMultiplyAddSaturate;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
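Note on the two specializations above: the kAlignment defaults are expressed in elements per 128-bit memory access, so the narrower type gets the larger element count. A minimal standalone check of that arithmetic (illustrative only, not part of the patch):

    #include "cutlass/numeric_types.h"

    // A 128-bit vector access holds 32 s4 elements, 16 s8 elements, or 4 s32 elements.
    static_assert(128 / cutlass::sizeof_bits<cutlass::int4b_t>::value == 32, "32 x s4 per 128b access");
    static_assert(128 / cutlass::sizeof_bits<int8_t>::value == 16, "16 x s8 per 128b access");
    static_assert(128 / cutlass::sizeof_bits<int32_t>::value == 4, "4 x s32 per 128b access");

The same numbers reappear as the [alignA, alignB, alignC] lists in the generator changes further down.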
 /// Partial specialization for SM89 fe4m3 x fe4m3
 template <
   typename ElementC>
 struct DefaultGemmConfiguration<
diff --git a/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h b/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
index d7e3232c81..2c851f469a 100644
--- a/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
+++ b/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
@@ -268,7 +268,7 @@ struct DefaultMmaTensorOp<
     "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");
 
   // Data type used for internal computation - use the wider of the two data types for mma.sync operands
-  using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)),
+  using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
                                                         ElementA, ElementB>::type;
 
   // Operand datatypes in the internal MMA instruction - use the wider of the two data types
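The sizeof -> sizeof_bits change above is the key enabler for sub-byte operands: C++ sizeof reports whole bytes, so int4b_t and int8_t tie at one byte each and the wider-operand selection would be arbitrary for the new s4/s8 pairing. A small illustrative check (not part of the patch):

    #include "cutlass/numeric_types.h"

    // Byte-level sizes tie for the packed 4-bit type ...
    static_assert(sizeof(cutlass::int4b_t) == 1 && sizeof(int8_t) == 1, "sizeof cannot tell s4 from s8");
    // ... while bit-level sizes resolve the tie, so int8_t is picked as ElementOperand.
    static_assert(cutlass::sizeof_bits<cutlass::int4b_t>::value == 4 &&
                  cutlass::sizeof_bits<int8_t>::value == 8, "sizeof_bits distinguishes them");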
@@ -294,6 +294,75 @@ struct DefaultMmaTensorOp<
       Policy, PartitionsK, AccumulatorsInRowMajor>;
 };
 
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
+/// (e.g. S32 <= S4 x S8 + S32, S32 <= S8 x S4 + S32)
+template <
+    /// Shape of one matrix production operation (concept: GemmShape)
+    typename WarpShape_,
+    /// Element type of A matrix
+    typename ElementA,
+    /// Layout of A matrix (concept: MatrixLayout)
+    typename LayoutA,
+    /// Element type of B matrix
+    typename ElementB,
+    /// Layout of B matrix (concept: MatrixLayout)
+    typename LayoutB,
+    /// Element type of C matrix
+    typename ElementC,
+    /// Layout of C matrix (concept: MatrixLayout)
+    typename LayoutC,
+    /// Number of partitions along K dimension
+    int PartitionsK,
+    /// Store the accumulators in row major or column major. Row major is used
+    /// when output layout is interleaved.
+    bool AccumulatorsInRowMajor>
+struct DefaultMmaTensorOp<
+    WarpShape_,
+    GemmShape<16, 8, 32>,                 // InstructionShape
+    ElementA,                             // Element type of A matrix in Global Memory
+    LayoutA,                              // Layout of A matrix in Global Memory
+    ElementB,                             // Element type of B matrix in Global Memory
+    LayoutB,                              // Layout of B matrix in Global Memory
+    ElementC,                             // Element type of C matrix in Global Memory
+    LayoutC,                              // Layout of C matrix in Global Memory
+    arch::OpMultiplyAddMixedInputUpcast,  // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
+    PartitionsK, AccumulatorsInRowMajor> {
+
+  // Check if the ElementA and ElementB are of different data types
+  static_assert(!platform::is_same<ElementA, ElementB>::value,
+                "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");
+
+  // Data type used for internal computation - use the wider of the two data types for mma.sync operands
+  using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
+                                                        ElementA, ElementB>::type;
+
+  // Operand datatypes in the internal MMA instruction - use the wider of the two data types
+  using MmaElementA = ElementOperand;
+  using MmaElementB = ElementOperand;
+  using MmaElementC = ElementC;
+
+  // Uses the saturating integer mma.sync instruction (OpMultiplyAddSaturate) internally
+  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
+      cutlass::arch::Mma<
+          GemmShape<16, 8, 32>,
+          32,
+          MmaElementA, cutlass::layout::RowMajor,
+          MmaElementB, cutlass::layout::ColumnMajor,
+          MmaElementC, cutlass::layout::RowMajor,
+          arch::OpMultiplyAddSaturate
+      >,
+      cutlass::MatrixShape<1, 1> >;
+
+  // Define the warp-level tensor op
+  using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
+      WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
+      Policy, PartitionsK, AccumulatorsInRowMajor>;
+};
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace warp
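To make the new specialization concrete, here is a hedged sketch of how it is typically reached (the warp shape and plain global-memory layouts are assumptions for illustration; the warp-level unit tests added later in this patch pass TensorOpMultiplicandCrosswise shared-memory layouts instead):

    using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
        cutlass::gemm::GemmShape<64, 64, 64>,         // warp tile (assumed)
        cutlass::gemm::GemmShape<16, 8, 32>,          // instruction tile
        cutlass::int4b_t, cutlass::layout::RowMajor,  // narrow operand A
        int8_t, cutlass::layout::ColumnMajor,         // wide operand B
        int32_t, cutlass::layout::RowMajor,           // accumulator C
        cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

ElementOperand resolves to int8_t here, so the s4 fragments of A are upcast to s8 in registers and the instruction issued is the saturating s8 16x8x32 mma.sync.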
diff --git a/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h b/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
index 14d8f33455..0b37ad24c6 100644
--- a/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
+++ b/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
@@ -104,6 +104,7 @@ struct FragmentShuffler {
 ////////////////////////////////////////////////////////////////////////////////
 
 /// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
+/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
 /// for operand A multiplicand going through upcasting.
 template <
     /// Element type for the operand in registers for the mma.sync
@@ -122,8 +123,10 @@ struct FragmentShuffler {
 struct FragmentShuffler <ElementMma_, ElementLoad_,
                          NumMmaInstructions, NumElementsInWarpFragment,
                          NumElementsInMmaFragment, Operand::kA,
-                         typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
-                                                      (sizeof_bits<ElementLoad_>::value == 8)>::type> {
+                         typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
+                                                       (sizeof_bits<ElementLoad_>::value == 8)) ||
+                                                      ((sizeof_bits<ElementMma_>::value == 8) &&
+                                                       (sizeof_bits<ElementLoad_>::value == 4))>::type> {
 public:
   using ElementMma = ElementMma_;
   using ElementLoad = ElementLoad_;
@@ -187,6 +190,7 @@ struct FragmentShuffler {
 /// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
+/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
 /// for operand B multiplicand going through upcasting.
 struct FragmentShuffler <ElementMma_, ElementLoad_,
                          NumMmaInstructions, NumElementsInWarpFragment,
                          NumElementsInMmaFragment, Operand::kB,
-                         typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
-                                                      (sizeof_bits<ElementLoad_>::value == 8)>::type> {
+                         typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
+                                                       (sizeof_bits<ElementLoad_>::value == 8)) ||
+                                                      ((sizeof_bits<ElementMma_>::value == 8) &&
+                                                       (sizeof_bits<ElementLoad_>::value == 4))>::type> {
 public:
   using ElementMma = ElementMma_;
   using ElementLoad = ElementLoad_;
diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h
index 2e74afa8e4..1701b4ac8d 100644
--- a/include/cutlass/numeric_conversion.h
+++ b/include/cutlass/numeric_conversion.h
@@ -2771,6 +2771,86 @@ struct NumericArrayConverter {
   }
 };
 
+/// Partial specialization for Array<int8_t, 8> <= Array<int4b_t, 8>
+template <
+  FloatRoundStyle Round
+>
+struct NumericArrayConverter<int8_t, int4b_t, 8, Round> {
+
+  using result_type = Array<int8_t, 8>;
+  using source_type = Array<int4b_t, 8>;
+  static FloatRoundStyle const round_style = Round;
+
+  CUTLASS_HOST_DEVICE
+  static result_type convert(source_type const & source) {
+
+    unsigned const& storage = reinterpret_cast<unsigned const &>(source);
+    unsigned out[2];
+
+    asm volatile(
+        "{ .reg .u32 tmp0, tmp1, tmp2;"
+        "shl.b32 tmp0, %2, 4;"
+        "and.b32 tmp0, tmp0, 0xf0f0f0f0;"
+        "prmt.b32 tmp1, tmp0, tmp0, 0xba98;"
+        "and.b32 tmp1, tmp1, 0xf0f0f0f0;"
+        "shr.u32 tmp0, tmp0, 4;"
+        "or.b32 tmp2, tmp0, tmp1;"
+        "and.b32 tmp0, %2, 0xf0f0f0f0;"
+        "prmt.b32 tmp1, tmp0, tmp0, 0xba98;"
+        "and.b32 tmp1, tmp1, 0xf0f0f0f0;"
+        "shr.u32 tmp0, tmp0, 4;"
+        "or.b32 tmp0, tmp0, tmp1;"
+        "prmt.b32 %0, tmp2, tmp0, 0x5140;"
+        "prmt.b32 %1, tmp2, tmp0, 0x7362;"
+        "}"
+        : "=r"(out[0]), "=r"(out[1])
+        : "r"(storage));
+
+    return reinterpret_cast<result_type const &>(out);
+  }
+
+  CUTLASS_HOST_DEVICE
+  result_type operator()(source_type const &s) const {
+    return convert(s);
+  }
+};
+
+/// Partial specialization for Array<int8_t> <= Array<int4b_t>
+template <
+  int N,
+  FloatRoundStyle Round
+>
+struct NumericArrayConverter<int8_t, int4b_t, N, Round> {
+  static_assert(!(N % 8), "N must be multiple of 8.");
+
+  using result_type = Array<int8_t, N>;
+  using source_type = Array<int4b_t, N>;
+  static FloatRoundStyle const round_style = Round;
+
+  CUTLASS_HOST_DEVICE
+  static result_type convert(source_type const & source) {
+
+    NumericArrayConverter<int8_t, int4b_t, 8, Round> convert_vector_;
+
+    result_type result;
+
+    Array<int8_t, 8> *result_ptr = reinterpret_cast<Array<int8_t, 8> *>(&result);
+    Array<int4b_t, 8> const *source_ptr = reinterpret_cast<Array<int4b_t, 8> const *>(&source);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / 8; ++i) {
+      result_ptr[i] = convert_vector_(source_ptr[i]);
+    }
+
+    return result;
+  }
+
+  CUTLASS_HOST_DEVICE
+  result_type operator()(source_type const &s) const {
+    return convert(s);
+  }
+};
+
 #endif // Conditional guards to enable partial specialization for packed integers
 
 namespace detail {
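The PTX converter above unpacks eight packed s4 values held in one 32-bit register into eight s8 values (two output registers), using prmt byte permutes to keep the element order intact. Element-wise it is plain 4-bit sign extension, as in this scalar reference (illustrative only, not part of the patch):

    #include <cstdint>

    // Reference semantics: sign-extend each 4-bit nibble of 'packed' to int8_t.
    inline void unpack_s4x8_reference(uint32_t packed, int8_t out[8]) {
      for (int i = 0; i < 8; ++i) {
        // Shift nibble i to the top of a 32-bit word, then arithmetic-shift it
        // back down so its sign bit (bit 3 of the nibble) is replicated.
        out[i] = static_cast<int8_t>(static_cast<int32_t>(packed << (28 - 4 * i)) >> 28);
      }
    }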
diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py
index 7a5a47b196..c736551432 100644
--- a/python/cutlass_library/generator.py
+++ b/python/cutlass_library/generator.py
@@ -2855,6 +2855,167 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
     op.C.alignment = 8
 
 #
+def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version):
+
+  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
+    return
+
+  layouts = [
+    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+  ]
+
+  # Upcast on Operand A
+  math_instructions = [
+    MathInstruction(                                  \
+      [16, 8, 32],                                    \
+      DataType.s4, DataType.s8, DataType.s32,         \
+      OpcodeClass.TensorOp,                           \
+      MathOperation.multiply_add_mixed_input_upcast),
+  ]
+
+  min_cc = 80
+  max_cc = 1024
+
+  # For mixed-input, the alignment constraints are a list of lists, where the
+  # inner list contains the alignment constraints for the operands/matrices:
+  # [[alignA, alignB, alignC], ...]
+  alignment_constraints = [[32, 16, 4],]
+
+  for math_inst in math_instructions:
+    tile_descriptions = [
+      TileDescription([256, 128,  64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256,  64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64,  64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256,  64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256,  64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128,  64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128,  64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+    ]
+
+    data_type = [
+      math_inst.element_a,
+      math_inst.element_b,
+      math_inst.element_accumulator,
+      math_inst.element_accumulator,
+    ]
+
+    # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit.
+    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
+    if math_inst.element_a != math_inst.element_accumulator:
+      alignment_constraints = [[32, 16, 16],]
+
+      data_type_mixed = [
+        math_inst.element_a,
+        math_inst.element_b,
+        math_inst.element_b,
+        DataType.f32
+      ]
+
+      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
+        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8)
+
+      for op in operations:
+        if op.tile_description.threadblock_shape[1] >= 128:
+          if op.tile_description.threadblock_shape[0] == 32:
+            op.C.alignment = 8
+          else:
+            op.C.alignment = 16
+        else:
+          op.C.alignment = 8
+
+#
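The function above emits two kernel families: the first keeps the s32 accumulator as output through LinearCombination, while the guarded data_type_mixed set writes s8 output through LinearCombinationClamp with float scaling. Per element, the clamped epilogue behaves roughly like the sketch below (illustrative; the actual functor rounds to nearest rather than truncating):

    #include <algorithm>
    #include <cstdint>

    // Approximate per-element behavior of LinearCombinationClamp with s8 output.
    inline int8_t linear_combination_clamp_ref(int32_t accum, float alpha,
                                               float beta, int8_t source_c) {
      float intermediate = alpha * float(accum) + beta * float(source_c);
      // Saturate to the representable range of the 8-bit output element.
      intermediate = std::min(127.0f, std::max(-128.0f, intermediate));
      return static_cast<int8_t>(intermediate);
    }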
+def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version):
+
+  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
+    return
+
+  layouts = [
+    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+  ]
+
+  # Upcast on Operand B
+  math_instructions = [
+    MathInstruction(                                  \
+      [16, 8, 32],                                    \
+      DataType.s8, DataType.s4, DataType.s32,         \
+      OpcodeClass.TensorOp,                           \
+      MathOperation.multiply_add_mixed_input_upcast),
+  ]
+
+  min_cc = 80
+  max_cc = 1024
+
+  # For mixed-input, the alignment constraints are a list of lists, where the
+  # inner list contains the alignment constraints for the operands/matrices:
+  # [[alignA, alignB, alignC], ...]
+  alignment_constraints = [[16, 32, 4],]
+
+  for math_inst in math_instructions:
+    tile_descriptions = [
+      TileDescription([256, 128,  64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256,  64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64,  64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256,  64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32,  64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128,  64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128,  64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32,  64], 6, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([128,  32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
+    ]
+
+    data_type = [
+      math_inst.element_a,
+      math_inst.element_b,
+      math_inst.element_accumulator,
+      math_inst.element_accumulator,
+    ]
+
+    # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit.
+    operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+      data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8)
+
+    # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation)
+    if math_inst.element_a != math_inst.element_accumulator:
+      alignment_constraints = [[16, 32, 16],]
+
+      data_type_mixed = [
+        math_inst.element_a,
+        math_inst.element_b,
+        math_inst.element_a,
+        DataType.f32,
+      ]
+
+      operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
+        data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8)
+
+      for op in operations:
+        if op.tile_description.threadblock_shape[1] >= 128:
+          if op.tile_description.threadblock_shape[0] == 32:
+            op.C.alignment = 8
+          else:
+            op.C.alignment = 16
+        else:
+          op.C.alignment = 8
 
 #
 def GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version):
@@ -4699,6 +4860,8 @@ def GenerateSM80(manifest, cuda_version):
   GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version)
   GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version)
   GenerateSM80_TensorOp_16832_TN(manifest, cuda_version)
+  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version)
+  GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version)
   GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version)
   GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version)
   GenerateSM80_TensorOp_16864_TN(manifest, cuda_version)
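Context for the unit-test changes below: host_data()[i] indexes raw storage one element per byte, which breaks for packed sub-byte types such as int4b_t (two elements per byte). host_view().at(...) instead maps a logical coordinate through the tensor layout and yields a packed-safe element reference, so the same test template works for s4. A hedged usage sketch:

    #include "cutlass/util/host_tensor.h"

    void fill_packed_s4_example() {
      // Two int4b_t elements share each byte of backing storage.
      cutlass::HostTensor<cutlass::int4b_t, cutlass::layout::RowMajor> tensor({1, 16});
      for (int i = 0; i < 16; ++i) {
        tensor.host_view().at({0, i}) = cutlass::int4b_t(i % 8 - 4);  // packed-safe write
      }
      tensor.sync_device();
    }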
diff --git a/test/unit/core/fast_numeric_conversion.cu b/test/unit/core/fast_numeric_conversion.cu
index 0d6e2401a9..9612efe36e 100644
--- a/test/unit/core/fast_numeric_conversion.cu
+++ b/test/unit/core/fast_numeric_conversion.cu
@@ -69,7 +69,7 @@ void run_test_integer_range_limited() {
   cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});
 
   for (int i = 0; i < kN; ++i) {
-    source.host_data()[i] = Source(i % 4);
+    source.host_view().at({0, i}) = Source(i % 4);
   }
   source.sync_device();
 
@@ -82,7 +82,7 @@ void run_test_integer_range_limited() {
   destination.sync_host();
 
   for (int i = 0; i < kN; ++i) {
-    EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]));
+    EXPECT_TRUE(float(destination.host_view().at({0, i})) == float(source.host_view().at({0, i})));
   }
 }
 
@@ -97,13 +97,12 @@ void run_test_integer_range_all() {
   cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
   cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});
 
-  int const kIntSourceMin = std::numeric_limits<Source>::min();
-  int const kIntSourceMax = std::numeric_limits<Source>::max();
+  int const kIntSourceMin = cutlass::platform::numeric_limits<Source>::lowest();
+  int const kIntSourceMax = cutlass::platform::numeric_limits<Source>::max();
   int const kIntRange = kIntSourceMax - kIntSourceMin + 1;
 
   for (int i = 0; i < kN; ++i) {
-    source.host_data()[i] = Source(kIntSourceMin + (i % kIntRange));
-
+    source.host_view().at({0, i}) = Source(kIntSourceMin + (i % kIntRange));
   }
   source.sync_device();
 
@@ -118,7 +117,7 @@ void run_test_integer_range_all() {
   // Verify conversion
   bool passed = true;
   for (int i = 0; i < kN; ++i) {
-    if(!(float(destination.host_data()[i]) == float(source.host_data()[i]))) {
+    if(!(float(destination.host_view().at({0, i}) == float(source.host_view().at({0, i}))))) {
       passed = false;
       break;
     }
@@ -128,8 +127,8 @@ void run_test_integer_range_all() {
   // Print out results for the failed conversion.
   if (!passed) {
     for (int i = 0; i < kN; ++i) {
-      std::cout << "source(" << float(source.host_data()[i]) << ") -> "
-                << "destination ("<< float(destination.host_data()[i]) << ")" << std::endl;
+      std::cout << "source(" << float(source.host_view().at({0, i})) << ") -> "
+                << "destination ("<< float(destination.host_view().at({0, i})) << ")" << std::endl;
     }
   }
   std::flush(std::cout);
@@ -188,3 +187,10 @@ TEST(FastNumericConversion, s8_to_bf16_array) {
   using Destination = cutlass::bfloat16_t;
   test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
 }
+
+TEST(FastNumericConversion, s4_to_s8_array) {
+  int const kN = 16;
+  using Source = cutlass::int4b_t;
+  using Destination = int8_t;
+  test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
+}
diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt
index b5afa433e9..a70ce542d0 100644
--- a/test/unit/gemm/device/CMakeLists.txt
+++ b/test/unit/gemm/device/CMakeLists.txt
@@ -264,6 +264,9 @@ cutlass_test_unit_add_executable(
   gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
   gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
 
+  gemm_universal_s4t_s8n_s32t_mixed_input_tensor_op_s32_sm80.cu
+  gemm_universal_s4t_s8n_s8t_mixed_input_tensor_op_s32_sm80.cu
+
   # Upcast on Operand B
   gemm_universal_f16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu
   gemm_universal_f16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu
@@ -277,6 +280,9 @@ cutlass_test_unit_add_executable(
   gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
   gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
 
+
+  gemm_universal_s8t_s4n_s32t_mixed_input_tensor_op_s32_sm80.cu
+  gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu
 )
 
 cutlass_test_unit_add_executable(
diff --git a/test/unit/gemm/device/gemm_universal_s4t_s8n_s32t_mixed_input_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_universal_s4t_s8n_s32t_mixed_input_tensor_op_s32_sm80.cu
new file mode 100644
index 0000000000..421ea0c0b2
--- /dev/null
+++ b/test/unit/gemm/device/gemm_universal_s4t_s8n_s32t_mixed_input_tensor_op_s32_sm80.cu
@@ -0,0 +1,95 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for device-wide GEMM interface
+
+*/
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/device/gemm_universal.h"
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/host/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "testbed_universal.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
+
+
+TEST(SM80_Device_GemmUniversal_s4t_s8n_s32t_mixed_input_tensor_op_s32, 128x128x64_64x64x64) {
+
+  using ElementA = cutlass::int4b_t;
+  using ElementB = int8_t;
+  using ElementOutput = int32_t;
+  using ElementAccumulator = int32_t;
+
+  using Gemm = cutlass::gemm::device::GemmUniversal<
+      ElementA,
+      cutlass::layout::RowMajor,
+      ElementB,
+      cutlass::layout::ColumnMajor,
+      ElementOutput,
+      cutlass::layout::RowMajor,
+      ElementAccumulator,
+      cutlass::arch::OpClassTensorOp,
+      cutlass::arch::Sm80,
+      cutlass::gemm::GemmShape<128, 128, 64>,
+      cutlass::gemm::GemmShape<64, 64, 64>,
+      cutlass::gemm::GemmShape<16, 8, 32>,
+      cutlass::epilogue::thread::LinearCombination<
+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+          ElementAccumulator, ElementAccumulator>,
+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+      4,   // Stages
+      32,  // AlignmentA
+      16,  // AlignmentB
+      cutlass::arch::OpMultiplyAddMixedInputUpcast,
+      cutlass::ComplexTransform::kNone,
+      cutlass::ComplexTransform::kNone
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
+}
+////////////////////////////////////////////////////////////////////////////////
+
+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/test/unit/gemm/device/gemm_universal_s4t_s8n_s8t_mixed_input_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_universal_s4t_s8n_s8t_mixed_input_tensor_op_s32_sm80.cu
new file mode 100644
index 0000000000..685092fb84
--- /dev/null
+++ b/test/unit/gemm/device/gemm_universal_s4t_s8n_s8t_mixed_input_tensor_op_s32_sm80.cu
@@ -0,0 +1,95 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for device-wide GEMM interface
+
+*/
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/device/gemm_universal.h"
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/host/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "testbed_universal.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
+
+
+TEST(SM80_Device_GemmUniversal_s4t_s8n_s8t_mixed_input_tensor_op_s32, 128x128x64_64x64x64) {
+
+  using ElementA = cutlass::int4b_t;
+  using ElementB = int8_t;
+  using ElementOutput = int8_t;
+  using ElementAccumulator = int32_t;
+
+  using Gemm = cutlass::gemm::device::GemmUniversal<
+      ElementA,
+      cutlass::layout::RowMajor,
+      ElementB,
+      cutlass::layout::ColumnMajor,
+      ElementOutput,
+      cutlass::layout::RowMajor,
+      ElementAccumulator,
+      cutlass::arch::OpClassTensorOp,
+      cutlass::arch::Sm80,
+      cutlass::gemm::GemmShape<128, 128, 64>,
+      cutlass::gemm::GemmShape<64, 64, 64>,
+      cutlass::gemm::GemmShape<16, 8, 32>,
+      cutlass::epilogue::thread::LinearCombination<
+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+          ElementAccumulator, ElementAccumulator>,
+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+      4,   // Stages
+      32,  // AlignmentA
+      16,  // AlignmentB
+      cutlass::arch::OpMultiplyAddMixedInputUpcast,
+      cutlass::ComplexTransform::kNone,
+      cutlass::ComplexTransform::kNone
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
+}
+////////////////////////////////////////////////////////////////////////////////
+
+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/test/unit/gemm/device/gemm_universal_s8t_s4n_s32t_mixed_input_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_s4n_s32t_mixed_input_tensor_op_s32_sm80.cu
new file mode 100644
index 0000000000..b28cee62c0
--- /dev/null
+++ b/test/unit/gemm/device/gemm_universal_s8t_s4n_s32t_mixed_input_tensor_op_s32_sm80.cu
@@ -0,0 +1,95 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for device-wide GEMM interface
+
+*/
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/device/gemm_universal.h"
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/host/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "testbed_universal.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
+
+
+TEST(SM80_Device_GemmUniversal_s8t_s4n_s32t_mixed_input_tensor_op_s32, 128x128x64_64x64x64) {
+
+  using ElementA = int8_t;
+  using ElementB = cutlass::int4b_t;
+  using ElementOutput = int32_t;
+  using ElementAccumulator = int32_t;
+
+  using Gemm = cutlass::gemm::device::GemmUniversal<
+      ElementA,
+      cutlass::layout::RowMajor,
+      ElementB,
+      cutlass::layout::ColumnMajor,
+      ElementOutput,
+      cutlass::layout::RowMajor,
+      ElementAccumulator,
+      cutlass::arch::OpClassTensorOp,
+      cutlass::arch::Sm80,
+      cutlass::gemm::GemmShape<128, 128, 64>,
+      cutlass::gemm::GemmShape<64, 64, 64>,
+      cutlass::gemm::GemmShape<16, 8, 32>,
+      cutlass::epilogue::thread::LinearCombination<
+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+          ElementAccumulator, ElementAccumulator>,
+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+      4,   // Stages
+      16,  // AlignmentA
+      32,  // AlignmentB
+      cutlass::arch::OpMultiplyAddMixedInputUpcast,
+      cutlass::ComplexTransform::kNone,
+      cutlass::ComplexTransform::kNone
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
+}
+////////////////////////////////////////////////////////////////////////////////
+
+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/test/unit/gemm/device/gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu
new file mode 100644
index 0000000000..89a52b3e80
--- /dev/null
+++ b/test/unit/gemm/device/gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu
@@ -0,0 +1,95 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for device-wide GEMM interface
+
+*/
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/device/gemm_universal.h"
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/host/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "testbed_universal.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
+
+
+TEST(SM80_Device_GemmUniversal_s8t_s4n_s8t_mixed_input_tensor_op_s32, 128x128x64_64x64x64) {
+
+  using ElementA = int8_t;
+  using ElementB = cutlass::int4b_t;
+  using ElementOutput = int8_t;
+  using ElementAccumulator = int32_t;
+
+  using Gemm = cutlass::gemm::device::GemmUniversal<
+      ElementA,
+      cutlass::layout::RowMajor,
+      ElementB,
+      cutlass::layout::ColumnMajor,
+      ElementOutput,
+      cutlass::layout::RowMajor,
+      ElementAccumulator,
+      cutlass::arch::OpClassTensorOp,
+      cutlass::arch::Sm80,
+      cutlass::gemm::GemmShape<128, 128, 64>,
+      cutlass::gemm::GemmShape<64, 64, 64>,
+      cutlass::gemm::GemmShape<16, 8, 32>,
+      cutlass::epilogue::thread::LinearCombination<
+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+          ElementAccumulator, ElementAccumulator>,
+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+      4,   // Stages
+      16,  // AlignmentA
+      32,  // AlignmentB
+      cutlass::arch::OpMultiplyAddMixedInputUpcast,
+      cutlass::ComplexTransform::kNone,
+      cutlass::ComplexTransform::kNone
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
+}
+////////////////////////////////////////////////////////////////////////////////
+
+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/test/unit/gemm/warp/gemm_mixed_input_sm80.cu b/test/unit/gemm/warp/gemm_mixed_input_sm80.cu
index eb7d8023d0..db5b178f38 100644
--- a/test/unit/gemm/warp/gemm_mixed_input_sm80.cu
+++ b/test/unit/gemm/warp/gemm_mixed_input_sm80.cu
@@ -324,4 +324,52 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_16x8x16) {
       .run();
 }
 
+////////////////////////////////////////////////////////////////////////////////
+/// S32 <= I4 * I8 + S32 (Upcast on Operand A)
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i4_i8, 64x64x64_64x64x64_16x8x32) {
+  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+  using ElementA = cutlass::int4b_t;
+  using ElementB = int8_t;
+  using ElementC = int32_t;
+  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
+      cutlass::sizeof_bits<ElementA>::value, 64>;
+  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
+      cutlass::sizeof_bits<ElementB>::value, 64>;
+
+  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
+      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
+      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
+
+  test::gemm::warp::TransformTestbed<MmaTensorOp,
+                                     cutlass::gemm::GemmShape<64, 64, 64> >()
+      .run();
+}
+
+////////////////////////////////////////////////////////////////////////////////
+/// S32 <= I8 * I4 + S32 (Upcast on Operand B)
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_i4, 64x64x64_64x64x64_16x8x32) {
+  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+  using ElementA = int8_t;
+  using ElementB = cutlass::int4b_t;
+  using ElementC = int32_t;
+  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
+      cutlass::sizeof_bits<ElementA>::value, 64>;
+  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
+      cutlass::sizeof_bits<ElementB>::value, 64>;
+
+  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
+      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
+      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
+
+  test::gemm::warp::TransformTestbed<MmaTensorOp,
+                                     cutlass::gemm::GemmShape<64, 64, 64> >()
+      .run();
+}
+
 #endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt
index f8a28fe6b9..9b54f50817 100644
--- a/tools/library/CMakeLists.txt
+++ b/tools/library/CMakeLists.txt
@@ -234,6 +234,7 @@ cutlass_add_cutlass_library(
   src/reference/gemm_fp32out.cu
   src/reference/gemm_fp_other.cu
   src/reference/gemm_fp_mixed_input.cu
+  src/reference/gemm_int_mixed_input.cu
   src/reference/initialize_reference_operations.cu
 
   # cutlass reduction instances in cutlass library
diff --git a/tools/library/src/reference/gemm_int_mixed_input.cu b/tools/library/src/reference/gemm_int_mixed_input.cu
new file mode 100644
index 0000000000..8d6072e3ef
--- /dev/null
+++ b/tools/library/src/reference/gemm_int_mixed_input.cu
@@ -0,0 +1,130 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/* \file
+   \brief Instantiates GEMM reference implementations.
+*/
+
+#include "cutlass/cutlass.h"
+#include "cutlass/library/library.h"
+#include "cutlass/library/manifest.h"
+
+#include "gemm_reference_operation.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace library {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+void initialize_gemm_reference_operations_int_mixed_input(Manifest &manifest) {
+  // 4-bit integer mixed with 8-bit integer input
+  make_gemm_real_canonical_layouts<
+    int4b_t,
+    int8_t,
+    int32_t,
+    int32_t
+  >(manifest);
+
+  make_gemm_real_canonical_layouts<
+    int4b_t,
+    int8_t,
+    int8_t,
+    int32_t,
+    int32_t,
+    int8_t,
+    NumericConverterClamp<int8_t, int32_t>
+  >(manifest);
+
+  make_gemm_real_canonical_layouts<
+    int4b_t,
+    int8_t,
+    int32_t,
+    float,
+    int32_t,
+    int32_t,
+    NumericConverterClamp<int32_t, float>
+  >(manifest);
+
+  make_gemm_real_canonical_layouts<
+    int4b_t,
+    int8_t,
+    int8_t,
+    float,
+    int32_t,
+    int8_t,
+    NumericConverterClamp<int8_t, float>
+  >(manifest);
+
+  make_gemm_real_canonical_layouts<
+    int8_t,
+    int4b_t,
+    int32_t,
+    int32_t
+  >(manifest);
+
+  make_gemm_real_canonical_layouts<
+    int8_t,
+    int4b_t,
+    int8_t,
+    int32_t,
+    int32_t,
+    int8_t,
+    NumericConverterClamp<int8_t, int32_t>
+  >(manifest);
+
+  make_gemm_real_canonical_layouts<
+    int8_t,
+    int4b_t,
+    int32_t,
+    float,
+    int32_t,
+    int32_t,
+    NumericConverterClamp<int32_t, float>
+  >(manifest);
+
+  make_gemm_real_canonical_layouts<
+    int8_t,
+    int4b_t,
+    int8_t,
+    float,
+    int32_t,
+    int8_t,
+    NumericConverterClamp<int8_t, float>
+  >(manifest);
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace library
+} // namespace cutlass
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/tools/library/src/reference/initialize_reference_operations.cu b/tools/library/src/reference/initialize_reference_operations.cu
index 16679a27d8..59872b9742 100644
--- a/tools/library/src/reference/initialize_reference_operations.cu
+++ b/tools/library/src/reference/initialize_reference_operations.cu
@@ -57,6 +57,7 @@ void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest);
 void initialize_gemm_reference_operations_fp32out(Manifest &manifest);
 void initialize_gemm_reference_operations_fp_other(Manifest &manifest);
 void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest);
+void initialize_gemm_reference_operations_int_mixed_input(Manifest &manifest);
 
 void initialize_conv2d_reference_operations(Manifest &manifest);
 void initialize_conv3d_reference_operations(Manifest &manifest);
@@ -85,6 +86,8 @@ void initialize_reference_operations(Manifest &manifest) {
   initialize_gemm_reference_operations_fp_other(manifest);
   initialize_gemm_reference_operations_fp_mixed_input(manifest);
+  initialize_gemm_reference_operations_int_mixed_input(manifest);
+
 }
 
///////////////////////////////////////////////////////////////////////////////////////////////////

From 16ce3d721a3dd03864f5cb034aa60ee2facb819f Mon Sep 17 00:00:00 2001
From: Haicheng Wu
Date: Thu, 29 Aug 2024 08:42:23 -0700
Subject: [PATCH 2/2] fix ( and )

---
 test/unit/core/fast_numeric_conversion.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/unit/core/fast_numeric_conversion.cu b/test/unit/core/fast_numeric_conversion.cu
index 9612efe36e..99aab24581 100644
--- a/test/unit/core/fast_numeric_conversion.cu
+++ b/test/unit/core/fast_numeric_conversion.cu
@@ -117,7 +117,7 @@ void run_test_integer_range_all() {
   // Verify conversion
   bool passed = true;
   for (int i = 0; i < kN; ++i) {
-    if(!(float(destination.host_view().at({0, i}) == float(source.host_view().at({0, i}))))) {
+    if(!(float(destination.host_view().at({0, i})) == float(source.host_view().at({0, i})))) {
       passed = false;
       break;
     }
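For completeness, a hedged host-side sketch of driving one of the new kernels, mirroring the s4t_s8n_s32t unit test above. The Arguments layout follows the CUTLASS 2.x GemmUniversal convention (verify against your CUTLASS version); d_A/d_B/d_C/d_D are placeholder device pointers allocated elsewhere, and error handling is elided:

    #include "cutlass/cutlass.h"
    #include "cutlass/gemm/device/gemm_universal.h"

    using Gemm = cutlass::gemm::device::GemmUniversal<
        cutlass::int4b_t, cutlass::layout::RowMajor,   // A: s4, row-major
        int8_t, cutlass::layout::ColumnMajor,          // B: s8, column-major
        int32_t, cutlass::layout::RowMajor,            // C/D: s32, row-major
        int32_t,                                       // accumulator
        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 64>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        cutlass::gemm::GemmShape<16, 8, 32>,
        cutlass::epilogue::thread::LinearCombination<
            int32_t, 128 / cutlass::sizeof_bits<int32_t>::value,
            int32_t, int32_t>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
        4,                                             // stages
        32, 16,                                        // alignment of A (s4) and B (s8)
        cutlass::arch::OpMultiplyAddMixedInputUpcast>;

    cutlass::Status run_s4xs8_gemm(int M, int N, int K,
                                   cutlass::int4b_t const *d_A, int8_t const *d_B,
                                   int32_t const *d_C, int32_t *d_D) {
      typename Gemm::Arguments args(
          cutlass::gemm::GemmUniversalMode::kGemm,
          {M, N, K},
          1,                       // batch count
          {1, 0},                  // epilogue scalars: alpha = 1, beta = 0
          d_A, d_B, d_C, d_D,
          int64_t(M) * K, int64_t(N) * K,   // batch strides of A and B
          int64_t(M) * N, int64_t(M) * N,   // batch strides of C and D
          K, K, N, N);             // leading dimensions lda, ldb, ldc, ldd

      Gemm gemm_op;
      cutlass::Status status = gemm_op.initialize(args);
      if (status != cutlass::Status::kSuccess) {
        return status;
      }
      return gemm_op();            // launch on the default stream
    }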