Skip to content

Commit

Permalink
fix glb tiling when not fully unrolled by channel dim
Browse files Browse the repository at this point in the history
  • Loading branch information
yuchen-mei committed Oct 11, 2024
1 parent 5307d24 commit 1726746
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 4 deletions.
41 changes: 41 additions & 0 deletions apps/hardware_benchmarks/hw_support/mat2raw_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import scipy.io
import numpy as np
import sys
import os

def mat_to_raw(input_name, output_name):
print(f"[mat2raw] Converting {input_name} to {output_name}...")

mat = scipy.io.loadmat(input_name)
# Assuming there's only one variable of interest and it's not one of the
# automatic variables added by MATLAB ('__version__', '__header__', and '__globals__')
data_keys = [key for key in mat.keys() if not key.startswith('__')]
if len(data_keys) != 1:
raise ValueError(f"The .mat file {input_name} contains none or more than one variable.")

array = np.array(mat[data_keys[0]], dtype=np.uint16)
# Transpose the array. Reverse the axes.
transposed_array = array.transpose(np.arange(array.ndim)[::-1])

transposed_array.byteswap().tofile(output_name)
print(f"[mat2raw] Conversion of {output_name} Complete.")

def convert_all_mat_files(directory):
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith(".mat"):
input_path = os.path.join(root, file)
output_path = os.path.join(root, os.path.splitext(file)[0] + ".raw")
mat_to_raw(input_path, output_path)

if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python mat2raw_batch.py <directory_path>")
sys.exit(1)

directory = sys.argv[1]
if not os.path.isdir(directory):
print(f"Error: {directory} is not a valid directory.")
sys.exit(1)

convert_all_mat_files(directory)
51 changes: 47 additions & 4 deletions apps/hardware_benchmarks/hw_support/parse_design_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,45 @@ def unflatten_extent(addr_dict, X_dim, pad_o_left=0, pad_o_right=0):
raise ValueError("write_data_stride or read_data_stride not found in addr_dict")
assert found_X_cnt == 2, "X_dim and Y_dim not found in addr_dict['extent']"

def unflatten_extent_glb_tiling(addr_dict, X_dim):

found_X_cnt = 0
# add dimension if X_dim or Y_dim are flattened
for i, ext in enumerate(addr_dict['extent']):
if ext == X_dim:
found_X_cnt += 1
if found_X_cnt == 2:
break
elif ext == X_dim * X_dim:
found_X_cnt += 2
addr_dict['dimensionality'] += 1
addr_dict['extent'][i] = X_dim
addr_dict['extent'].insert(i+1, X_dim)
addr_dict['cycle_stride'].insert(i+1, addr_dict['cycle_stride'][i] * X_dim)
if 'write_data_stride' in addr_dict:
addr_dict['write_data_stride'].insert(i+1, addr_dict['write_data_stride'][i] * X_dim)
elif 'read_data_stride' in addr_dict:
addr_dict['read_data_stride'].insert(i+1, addr_dict['read_data_stride'][i] * X_dim)
break
elif ext % (X_dim * X_dim) == 0:
assert addr_dict['dimensionality'] == 1, "Should be fully flattened in this case"
found_X_cnt += 2
addr_dict['dimensionality'] += 2
addr_dict['extent'][i] = ext // (X_dim * X_dim)
addr_dict['extent'].insert(i+1, X_dim)
addr_dict['extent'].insert(i+2, X_dim)
addr_dict['cycle_stride'].insert(i+1, addr_dict['cycle_stride'][i] * (ext // (X_dim * X_dim)))
addr_dict['cycle_stride'].insert(i+2, addr_dict['cycle_stride'][i+1] * addr_dict['extent'][i+1])
if 'write_data_stride' in addr_dict:
addr_dict['write_data_stride'].insert(i+1, addr_dict['write_data_stride'][i] * (ext // (X_dim * X_dim)))
addr_dict['write_data_stride'].insert(i+2, addr_dict['write_data_stride'][i+1] * addr_dict['extent'][i+1])
elif 'read_data_stride' in addr_dict:
addr_dict['read_data_stride'].insert(i+1, addr_dict['read_data_stride'][i] * (ext // (X_dim * X_dim)))
addr_dict['read_data_stride'].insert(i+2, addr_dict['read_data_stride'][i+1] * addr_dict['extent'][i+1])
break

assert found_X_cnt == 2, "X_dim and Y_dim not found in addr_dict['extent']"

def parseLoopExtentforPadding(meta, halide_gen_args):
# Get pad_o values
args = halide_gen_args.split()
Expand Down Expand Up @@ -259,8 +298,10 @@ def parseLoopExtentforTiling(meta, halide_gen_args):
X_dim = shape_list[-1]
for io_tile in io_tiles_list:
addr_dict = io_tile['addr']
unflatten_extent(addr_dict, X_dim)
assert addr_dict['dimensionality'] == 2, "Implement fully unrolling along channel first"
if X_dim != 1:
print("Unflattening extent of input for GLB tiling\n", addr_dict)
unflatten_extent_glb_tiling(addr_dict, X_dim)
# assert addr_dict['dimensionality'] == 2, "Implement fully unrolling along channel first"
for output in meta['IOs']['outputs']:
io_tile_list = output['io_tiles']
# Get X_dim
Expand All @@ -269,8 +310,10 @@ def parseLoopExtentforTiling(meta, halide_gen_args):
X_dim = shape_list[-1]
for io_tile in io_tile_list:
addr_dict = io_tile['addr']
unflatten_extent(addr_dict, X_dim)
assert addr_dict['dimensionality'] == 2, "Implement fully unrolling along channel first"
if X_dim != 1:
print("Unflattening extent of output for GLB tiling\n", addr_dict)
unflatten_extent_glb_tiling(addr_dict, X_dim)
# assert addr_dict['dimensionality'] == 2, "Implement fully unrolling along channel first"

# Add HALIDE_GEN_ARGS to meta file
if meta.get("HALIDE_GEN_ARGS") is None: meta["HALIDE_GEN_ARGS"] = args_dict
Expand Down
42 changes: 42 additions & 0 deletions apps/hardware_benchmarks/hw_support/raw2txt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import sys

def raw_to_txt(raw_file_path, txt_file_path):
try:
# Open the .raw file in binary mode
with open(raw_file_path, 'rb') as raw_file:
# Read the binary data
raw_data = raw_file.read()

# Open the .txt file in write mode
with open(txt_file_path, 'w') as txt_file:
# Ensure the data length is a multiple of 2 (for 16-bit words)
if len(raw_data) % 2 != 0:
raise ValueError("The .raw file contains incomplete 16-bit words.")

# Convert the binary data to 16-bit words (2 bytes per word)
word_data = [raw_data[i:i+2] for i in range(0, len(raw_data), 2)]

# Write 8 words per line, with a space separating each word
for i in range(0, len(word_data), 8):
line_words = word_data[i:i+8]
hex_words = [word.hex() for word in line_words]
txt_file.write(' '.join(hex_words) + '\n')

print(f"Successfully converted {raw_file_path} to {txt_file_path}")

except Exception as e:
print(f"An error occurred: {e}")

if __name__ == "__main__":
# Example usage:
if len(sys.argv) != 3:
print("Usage: python raw_to_txt.py <input.raw> <output.txt>")
else:
raw_file = sys.argv[1]
txt_file = sys.argv[2]

if not os.path.exists(raw_file):
print(f"Error: The file {raw_file} does not exist.")
else:
raw_to_txt(raw_file, txt_file)

0 comments on commit 1726746

Please sign in to comment.