[GEMM] [Tuning] Option to try different initialization strategies (triton-lang#486)

* Add a few different GEMM init strategies

* Minor fixes

* Fix order when transposed

* Fix transpose

* Update tune_gemm.py

Fix blank lines

* remove init_type from mnks

* Fix trig_float

---------

Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
vgokhale and zhanglx13 authored Jan 27, 2024
1 parent c631824 commit 49f6c3a
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions scripts/amd/gemm/tune_gemm.py
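
Before the hunks, here is a minimal standalone sketch of the four initialization strategies this commit introduces (hpl, trig_float, zeros, randn). It mirrors the logic added to init_by_size_and_type below, but runs on CPU in float32, so it is an illustration of the behavior rather than the committed implementation.

import torch

def init_sketch(size, init_type):
    M, N = size
    if init_type == 'hpl':
        # HPL-style: uniform values in [-0.5, 0.5)
        return torch.empty(size, dtype=torch.float32).uniform_(-0.5, 0.5)
    elif init_type == 'trig_float':
        # element (i, j) of the M x N matrix is sin(i*N + j)
        return torch.arange(0, M * N, dtype=torch.float32).reshape(M, N).sin()
    elif init_type == 'zeros':
        return torch.zeros(size, dtype=torch.float32)
    elif init_type == 'randn':
        # samples from a standard normal distribution
        return torch.randn(size, dtype=torch.float32)
    raise ValueError("Bad matrix initialization type.")

for kind in ('hpl', 'trig_float', 'zeros', 'randn'):
    t = init_sketch((4, 8), kind)
    print(kind, tuple(t.shape), round(float(t.mean()), 3))
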
@@ -1,3 +1,4 @@
# fp8
import argparse
import sys
import yaml
@@ -221,7 +222,7 @@ def generated_kernel_name(M, N, K, gpu_id):
# 4. test_gemm to invoke
# 4.1 run try_config in parallel
# 4.2 matmul in a loop of 10 iterations
def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, jobs, run_bench):
def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, run_bench):
filenames = []
for i in range(jobs):
filenames.append(generated_kernel_name(M, N, K, i))
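
The numbered comments above outline the generated harness: a test_gemm driver that runs try_config for each candidate config through a process pool, then times matmul in a loop of 10 iterations. A minimal sketch of that fan-out pattern, with a placeholder worker (the real generated code appears in the next hunk):

import multiprocessing

def try_config(config_id):
    # stand-in for the real per-config compile-and-check work
    return config_id * config_id

if __name__ == '__main__':
    with multiprocessing.Pool(processes=4) as pool:
        results = pool.map(try_config, range(8))
    print(results)
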
@@ -259,8 +260,8 @@ def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, j
# pre string
test_gemm_pre_str = f"""def test_gemm(M, N, K, num_threads):
thread_pool = multiprocessing.Pool(processes=num_threads)
a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, device='cuda')
b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, device='cuda')
a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, '{init_type}', device='cuda')
b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, '{init_type}', device='cuda')
c = torch.zeros((M, N), device=a.device, dtype={tl_to_torch_types[name_to_tl_types[dtype_c]]})
task_args = (M, N, K,
a.stride(0), a.stride(1),
@@ -359,9 +360,9 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
jobId += ngpus


def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, run_bench, jobs, verbose=0, num_threads=16, gpus=[0]):
def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, run_bench, jobs, verbose=0, num_threads=16, gpus=[0]):
# Generate kernel out of all configs
generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, jobs, run_bench)
generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, run_bench)

# remove any compiled kernel in the cache
run_bash_command("rm -rf ~/.triton/cache")
@@ -418,7 +419,7 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs,
print(f"post procesing time: {post_time}", flush=True)
return minTime, bestConfig, compile_time, profile_time, post_time

def gen_input(M, N, ty_name, needTrans, seed, device='cuda'):
def gen_input(M, N, ty_name, needTrans, seed, init_type, device='cuda'):
d_type = name_to_tl_types[ty_name]
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@@ -431,10 +432,24 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
output = input
tl.store(output_ptr + offsets, output, mask=mask)

def init_by_size_and_type(size, dtype, init_type):
if init_type == 'hpl':
return torch.empty(size, device='cuda', dtype=dtype).uniform_(-0.5, 0.5)
# This init type sets element i of row j to sin(i + j*N)
elif init_type == 'trig_float':
M, N = size
return torch.reshape(torch.arange(0, M*N), (M, N)).sin().to(dtype=dtype, device='cuda')
elif init_type == 'zeros':
return torch.zeros(size, dtype=dtype, device='cuda')
elif init_type == "randn":
temp = torch.randn(size, dtype=dtype, device='cuda')
return temp
else:
raise ValueError("Bad matrix initialization type.")

raw_data = init_by_size_and_type((N,M) if needTrans else (M,N), torch.float32, init_type)
if needTrans:
raw_data = torch.randn((N, M), dtype=torch.float32, device='cuda').T
else:
raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda')
raw_data = raw_data.T
if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \
(d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) or not d_type.is_fp8():
input = raw_data.to(tl_to_torch_types[d_type])
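
As a quick sanity check of the trig_float layout described in the hunk above (a standalone sketch on CPU in float32, not part of the diff): element (i, j) of the M x N matrix equals sin(i*N + j).

import math
import torch

M, N = 4, 3
a = torch.arange(0, M * N, dtype=torch.float32).reshape(M, N).sin()
i, j = 2, 1
# row 2, column 1 holds sin(2*3 + 1) = sin(7)
assert math.isclose(a[i, j].item(), math.sin(i * N + j), rel_tol=1e-6)
print(a[i, j].item())
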
@@ -481,14 +496,14 @@ def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_
return c


def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, config, verbose):
def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, verbose):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize = read_config(config)

torch.manual_seed(0)
#a = torch.randn((M, K), device='cuda', dtype=datatype)
#b = torch.randn((K, N), device='cuda', dtype=datatype)
a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, device='cuda')
b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, device='cuda')
a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, init_type, device='cuda')
b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, init_type, device='cuda')
# Allocates output.
c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
triton_output = matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize)
@@ -544,6 +559,7 @@ def parse_args():
parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages")
parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing")
parser.add_argument("--jobs", type=int, default=1, help="number of generated files")
parser.add_argument("--init_type", type=str, default='randn', help="Initialization type for input matrices (default uniform rand [0, 1.0)])")
args = parser.parse_args()

return args
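
A hedged sketch of how the new --init_type flag behaves in isolation. The choices= guard is an assumption, not part of the diff; it would reject a typo at parse time rather than deferring to the ValueError raised later in init_by_size_and_type.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--init_type", type=str, default='randn',
                    choices=['randn', 'hpl', 'trig_float', 'zeros'],
                    help="initialization strategy for input matrices")
args = parser.parse_args(['--init_type', 'hpl'])
print(args.init_type)  # hpl
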
@@ -643,6 +659,7 @@ def main():

mnks = []
# TODO: make it more robust to get user input
init_type = args.init_type
if matrix_size_file == "" or not os.path.isfile(matrix_size_file):
M = args.m
N = args.n
@@ -660,7 +677,7 @@ def main():
# Check correctness from given configs
if args.compare_wo_tuning:
for (M, N, K, col_a, col_b, myConfig) in mnks:
test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, item, True)
test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, item, True)
return

configs_full = get_full_tuning_space()
@@ -670,7 +687,7 @@ def main():
print(f"Benchmarking gemm with {dtype_a} inputs")
print("trans M N K TFLOPS")
else:
print(f"Tuning starts at: {start_time}", flush=True)
print(f"Tuning {len(mnks)} gemm sizes starts at: {start_time}", flush=True)
f_results = open(tuning_output_file, 'w')

for (M, N, K, col_a, col_b, myConfig) in mnks:
@@ -684,14 +701,20 @@ def main():
size_str = f'SIZE: {M} {N} {K} {row_a_str}{row_b_str}'
if not run_bench:
print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True)
else:
print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} ", end="")

# The main tuning function for one gemm size
verbose_level = 0
if args.time_breakdown:
verbose_level = 1
if args.verbose:
verbose_level = 2
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, pruned_configs, run_bench, jobs, num_threads=args.num_threads, gpus=gpus, verbose=verbose_level)
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(
M, N, K, col_a, col_b, dtype_a,
dtype_b, dtype_c, init_type, pruned_configs,
run_bench, jobs, num_threads=args.num_threads, gpus=gpus,
verbose=verbose_level)

# post processing the numbers
perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
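
For reference, a worked instance of the perf_tflops formula above; the matrix size and kernel time are illustrative, not measured values from the diff.

# 2*M*N*K floating-point operations; us is microseconds, so us * 1e-6 is seconds
M = N = K = 4096
us = 300.0
tflops = 2 * M * N * K * 1e-12 / (us * 1e-6)
print(f"{tflops:.1f} TFLOPS")  # about 458.1
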
@@ -707,7 +730,7 @@ def main():

# write best config to tuning_results.yaml
if run_bench:
print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} {formatted_tflops}")
print(f"{formatted_tflops}")

sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str}
sizeDict.update(bestConfig)
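
For context, a hypothetical example of one entry appended to the tuning results YAML. The M/N/K and rowMajor fields follow sizeDict above; the config keys and all values are assumptions standing in for bestConfig.

import yaml

sizeDict = {'M': 4096, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
# hypothetical best config; real keys come from the pruned tuning space
sizeDict.update({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64,
                 'SPLIT_K': 1, 'num_warps': 8})
print(yaml.dump([sizeDict], default_flow_style=None))
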
@@ -727,7 +750,7 @@ def main():
# Check correctness if asked to
if args.compare:
print("correctness: ", end=" ", flush=True)
test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, bestConfig, False)
test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, bestConfig, False)
elif not run_bench:
print("", flush=True)

