
Commit 33f48a7
Add couple configs into generator.py for mixed input MM
alexsamardzic committed Feb 21, 2024
1 parent bbe579a commit 33f48a7
Showing 2 changed files with 90 additions and 11 deletions.
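
The two files move together: generator.py decides which mixed-input kernels are emitted into the CUTLASS library, while gemm_fp_mixed_input.cu registers host-side reference GEMMs the profiler can verify them against. "Mixed input" means one MMA operand is an 8-bit integer (s8/u8) that is upcast in-register to the 16-bit floating-point type (f16/bf16) of the other operand. For orientation, here are the (A, B, accumulator) combinations visible in the diff below, gathered into plain tuples (a sketch, not generator API; the field order follows the MathInstruction calls themselves):

# Each tuple is (element_a, element_b, element_accumulator) for a
# [16, 8, 16] TensorOp instruction using
# MathOperation.multiply_add_mixed_input_upcast.
upcast_a = [  # A is the 8-bit integer operand
    ("s8", "f16", "f32"), ("u8", "f16", "f32"), ("s8", "bf16", "f32"),
    ("s8", "f16", "f16"), ("u8", "f16", "f16"),
]
upcast_b = [  # B is the 8-bit integer operand
    ("f16", "u8", "f32"), ("bf16", "s8", "f32"), ("bf16", "u8", "f32"),
    ("f16", "s8", "f16"), ("f16", "u8", "f16"),
]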
python/cutlass_library/generator.py (21 additions, 6 deletions)
@@ -2205,17 +2205,17 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
math_instructions = [
MathInstruction( \
[16, 8, 16], \
DataType.s8, DataType.f16, DataType.f16, \
DataType.s8, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.s8, DataType.f16, DataType.f32, \
DataType.u8, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.u8, DataType.f16, DataType.f32, \
DataType.s8, DataType.bf16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
@@ -2225,7 +2225,12 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.s8, DataType.bf16, DataType.f32, \
DataType.s8, DataType.f16, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.u8, DataType.f16, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
]
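
In the upcast_a flavor, A is the integer operand: each s8/u8 value is converted to B's floating-point type before the multiply, and products are summed in the listed accumulator type, f32 for the earlier entries and f16 for the newer variants. A scalar model of one such fused multiply-add, assuming the conversion happens before the product (illustrative names, not CUTLASS API):

def mixed_input_fma(a_int: int, b_fp: float, acc: float) -> float:
    """Scalar sketch of multiply_add_mixed_input_upcast: upcast the
    8-bit integer operand to the other operand's fp type, form the
    product at that precision, and add it into the accumulator
    (f16 or f32 in the configs above; Python floats stand in)."""
    a_up = float(a_int)  # models the s8/u8 -> f16/bf16 upcast
    return acc + a_up * b_fp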
@@ -2302,19 +2307,29 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version):
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.bf16, DataType.s8, DataType.f32, \
DataType.f16, DataType.u8, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.f16, DataType.u8, DataType.f32, \
DataType.bf16, DataType.s8, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.bf16, DataType.u8, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.f16, DataType.s8, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.f16, DataType.u8, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
]

min_cc = 80
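
GenerateSM80_TensorOp_16816_mixed_input_upcast_b mirrors upcast_a with the operand roles swapped: B carries the 8-bit integers and is upcast to A's f16/bf16 type, and min_cc = 80 restricts all of these kernels to SM80 and newer. The invariant shared by both lists, that exactly one operand is an 8-bit integer and the other a 16-bit float, can be written as a small predicate (a hypothetical helper, not part of generator.py):

def is_mixed_input_pair(element_a: str, element_b: str) -> bool:
    """True when exactly one operand is an 8-bit integer and the
    other is a 16-bit floating-point type."""
    ints, fps = {"s8", "u8"}, {"f16", "bf16"}
    return (element_a in ints and element_b in fps) or (
        element_a in fps and element_b in ints)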
tools/library/src/reference/gemm_fp_mixed_input.cu (69 additions, 5 deletions)
@@ -48,15 +48,15 @@ namespace library {
void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) {
// half_t mixed with 8-bit integer input
make_gemm_real_canonical_layouts<
int8_t,
uint8_t,
half_t,
half_t,
half_t,
half_t
>(manifest);

make_gemm_real_canonical_layouts<
uint8_t,
int8_t,
half_t,
half_t,
half_t,
@@ -74,11 +74,43 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) {
make_gemm_real_canonical_layouts<
int8_t,
half_t,
float,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
uint8_t,
half_t,
float,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
int8_t,
half_t,
half_t,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
half_t,
uint8_t,
half_t,
half_t,
half_t
>(manifest);

make_gemm_real_canonical_layouts<
half_t,
int8_t,
half_t,
half_t,
half_t
>(manifest);

make_gemm_real_canonical_layouts<
half_t,
uint8_t,
@@ -95,6 +127,22 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) {
float
>(manifest);

make_gemm_real_canonical_layouts<
half_t,
uint8_t,
float,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
half_t,
int8_t,
float,
float,
float
>(manifest);
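
Each make_gemm_real_canonical_layouts instantiation pins down the element types of one reference GEMM; judging from the instantiations themselves, the five arguments are the A, B, C, scalar, and accumulator types (the template has further defaulted parameters). Numerically, these half_t/integer references come down to a triple loop with the upcast on the integer side; a minimal model of the <half_t, uint8_t, float, float, float> case, with Python floats standing in for half_t and float:

def reference_gemm_mixed(A, B):
    """Naive reference: A is an M x K matrix of fp values, B is a
    K x N matrix of 8-bit integers upcast before each multiply;
    accumulation and output are both f32 in this combination."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k in range(K):
                acc += A[i][k] * float(B[k][j])  # u8 -> fp upcast
            C[i][j] = acc
    return C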

// bfloat16_t mixed with 8-bit integer input
make_gemm_real_canonical_layouts<
uint8_t,
@@ -107,6 +155,14 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) {
make_gemm_real_canonical_layouts<
int8_t,
bfloat16_t,
bfloat16_t,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
uint8_t,
bfloat16_t,
float,
float,
float
@@ -115,31 +171,39 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) {
make_gemm_real_canonical_layouts<
int8_t,
bfloat16_t,
bfloat16_t,
float,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
bfloat16_t,
uint8_t,
bfloat16_t,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
bfloat16_t,
int8_t,
bfloat16_t,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
bfloat16_t,
uint8_t,
bfloat16_t,
float,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
bfloat16_t,
int8_t,
bfloat16_t,
float,
float,
float
>(manifest);
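
As with half_t, each bfloat16_t operand order is covered both with the output kept in bfloat16_t and with the output written as float, while accumulation stays in float throughout. The two variants differ only in the final rounding on store, modeled here as f32-to-bf16 round-to-nearest-even (a sketch; NaN handling omitted):

import struct

def f32_to_bf16(x: float) -> float:
    """Round an f32 to the nearest bf16 (ties to even) and return
    it widened back to f32 for easy inspection."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Example: f32_to_bf16(3.14159265) == 3.140625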
