Skip to content

Commit

Permalink
Rewrite rot (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Oct 4, 2023
1 parent c621da3 commit 0d69185
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 130 deletions.
38 changes: 15 additions & 23 deletions src/common/composed_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def stage_expr(proc, expr_cursors, new_name, precision="R", memory=DRAM, n_lifts
return proc


def stage_alloc(proc, alloc_cursor):
def stage_alloc(proc, alloc_cursor, n_lifts=1):
"""
for i in seq(0, hi):
B1;
Expand All @@ -149,7 +149,7 @@ def stage_alloc(proc, alloc_cursor):
proc = expand_dim(
proc, alloc_cursor, expr_to_string(enclosing_loop.hi()), enclosing_loop.name()
)
proc = lift_alloc(proc, alloc_cursor, n_lifts=1)
proc = lift_alloc(proc, alloc_cursor, n_lifts=n_lifts)
return proc


Expand Down Expand Up @@ -237,30 +237,26 @@ def vectorize_to_loops(proc, loop_cursor, vec_width, memory_type, precision):
else:
inner_loop_cursor = loop_cursor

inner_loop_stmts = list(inner_loop_cursor.body())

staged_allocs = []

for stmt in inner_loop_stmts:
if isinstance(stmt, AllocCursor):
proc = stage_alloc(proc, stmt)
staged_allocs.append(stmt)

inner_loop_cursor = proc.forward(inner_loop_cursor)

stmts = []

def fission_stmts(proc, body, depth=1):
body_list = list(body)
for stmt in body_list[:-1]:
forwarded_stmt = proc.forward(stmt)
stmts.append(stmt)
proc = fission(proc, forwarded_stmt.after(), n_lifts=depth)
forwarded_stmt = proc.forward(stmt)
if isinstance(forwarded_stmt, IfCursor):
proc = fission_stmts(proc, forwarded_stmt.body(), depth + 1)
elif isinstance(forwarded_stmt, ForSeqCursor):
raise BLAS_SchedulingError("This is an inner loop vectorizer")
if isinstance(stmt, AllocCursor):
proc = stage_alloc(proc, stmt, n_lifts=depth)
proc = set_memory(proc, stmt, memory_type)
proc = set_precision(proc, stmt, precision)
else:
forwarded_stmt = proc.forward(stmt)
stmts.append(stmt)
proc = fission(proc, forwarded_stmt.after(), n_lifts=depth)
forwarded_stmt = proc.forward(stmt)
if isinstance(forwarded_stmt, IfCursor):
proc = fission_stmts(proc, forwarded_stmt.body(), depth + 1)
elif isinstance(forwarded_stmt, ForSeqCursor):
raise BLAS_SchedulingError("This is an inner loop vectorizer")
forwarded_stmt = body_list[-1]
stmts.append(forwarded_stmt)
if isinstance(forwarded_stmt, IfCursor):
Expand Down Expand Up @@ -377,10 +373,6 @@ def vectorize_stmt(proc, stmt, depth=1):
assert len(inner_loop.body()) == 1
proc = vectorize_stmt(proc, inner_loop.body()[0])

for alloc in staged_allocs:
proc = set_memory(proc, alloc, memory_type)
proc = set_precision(proc, alloc, precision)

return proc


Expand Down
148 changes: 146 additions & 2 deletions src/common/machines/avx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def avx2_abs_pd(dst: [f64][8] @ AVX2, src: [f64][8] @ AVX2):
__m256i prefix = _mm256_set1_epi32({bound});
__m256i cmp = _mm256_cmpgt_epi32(prefix, indices);
__m256 src_abs = _mm256_and_ps({src_data}, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)));
{dst_data} = _mm256_blendv_ps ({src_data}, src_abs, _mm256_castsi256_ps(cmp));
{dst_data} = _mm256_blendv_ps ({dst_data}, src_abs, _mm256_castsi256_ps(cmp));
}}
"""
)
Expand All @@ -379,7 +379,7 @@ def avx2_prefix_abs_ps(dst: [f32][8] @ AVX2, src: [f32][8] @ AVX2, bound: size):
__m256i prefix = _mm256_set1_epi64x({bound});
__m256i cmp = _mm256_cmpgt_epi64(prefix, indices);
__m256d src_abs = _mm256_and_pd({src_data}, _mm256_castsi256_pd(_mm256_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)));
{dst_data} = _mm256_blendv_pd ({src_data}, src_abs, _mm256_castsi256_pd(cmp));
{dst_data} = _mm256_blendv_pd ({dst_data}, src_abs, _mm256_castsi256_pd(cmp));
}}
"""
)
Expand All @@ -392,6 +392,144 @@ def avx2_prefix_abs_pd(dst: [f64][4] @ AVX2, src: [f64][4] @ AVX2, bound: size):
dst[i] = select(0.0, src[i], src[i], -src[i])


# AVX2 prefix negation (f32x8): dst[i] = -src[i] for lanes i < bound;
# lanes i >= bound keep their previous dst contents (blendv selects the
# old {dst_data} where the cmp mask is zero).
# Fix: removed a stray second semicolon after the _mm256_mul_ps statement
# in the emitted C template.
@instr(
    """
    {{
    __m256i indices = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
    __m256i prefix = _mm256_set1_epi32({bound});
    __m256i cmp = _mm256_cmpgt_epi32(prefix, indices);
    __m256 src_sign = _mm256_mul_ps({src_data}, _mm256_set1_ps(-1.0f));
    {dst_data} = _mm256_blendv_ps ({dst_data}, src_sign, _mm256_castsi256_ps(cmp));
    }}
    """
)
def avx2_prefix_sign_ps(dst: [f32][8] @ AVX2, src: [f32][8] @ AVX2, bound: size):
    assert stride(dst, 0) == 1
    assert stride(src, 0) == 1
    assert bound <= 8

    # Semantics the scheduler pattern-matches: guarded negate of the first
    # `bound` lanes only.
    for i in seq(0, 8):
        if i < bound:
            dst[i] = -src[i]


# AVX2 prefix negation (f64x4): dst[i] = -src[i] for lanes i < bound;
# lanes i >= bound keep their previous dst contents.
# Fix: _mm256_set1_pd takes a double -- use -1.0, not the float literal
# -1.0f (the value converts exactly, but the suffix is wrong for a
# double-precision intrinsic and inconsistent with the rest of the file).
@instr(
    """
    {{
    __m256i indices = _mm256_set_epi64x(3, 2, 1, 0);
    __m256i prefix = _mm256_set1_epi64x({bound});
    __m256i cmp = _mm256_cmpgt_epi64(prefix, indices);
    __m256d src_sign = _mm256_mul_pd({src_data}, _mm256_set1_pd(-1.0));
    {dst_data} = _mm256_blendv_pd ({dst_data}, src_sign, _mm256_castsi256_pd(cmp));
    }}
    """
)
def avx2_prefix_sign_pd(dst: [f64][4] @ AVX2, src: [f64][4] @ AVX2, bound: size):
    assert stride(dst, 0) == 1
    assert stride(src, 0) == 1
    assert bound <= 4

    # Semantics the scheduler pattern-matches: guarded negate of the first
    # `bound` lanes only.
    for i in seq(0, 4):
        if i < bound:
            dst[i] = -src[i]


# AVX2 prefix multiply (f32x8): out[i] = x[i] * y[i] for lanes i < bound;
# lanes i >= bound keep their previous out contents (blendv selects the
# old {out_data} where the cmp mask is zero).
# NOTE: comments only, no docstring -- an @instr body is parsed and
# pattern-matched structurally by Exo, so the statements must stay exact.
@instr(
    """
    {{
    __m256i indices = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
    __m256i prefix = _mm256_set1_epi32({bound});
    __m256i cmp = _mm256_cmpgt_epi32(prefix, indices);
    __m256 mul = _mm256_mul_ps({x_data}, {y_data});
    {out_data} = _mm256_blendv_ps ({out_data}, mul, _mm256_castsi256_ps(cmp));
    }}
    """
)
def mm256_prefix_mul_ps(
    out: [f32][8] @ AVX2, x: [f32][8] @ AVX2, y: [f32][8] @ AVX2, bound: size
):
    assert stride(out, 0) == 1
    assert stride(x, 0) == 1
    assert stride(y, 0) == 1
    assert bound <= 8

    # Guarded elementwise product over the first `bound` lanes.
    for i in seq(0, 8):
        if i < bound:
            out[i] = x[i] * y[i]


# AVX2 prefix multiply (f64x4): out[i] = x[i] * y[i] for lanes i < bound;
# lanes i >= bound keep their previous out contents.
# NOTE: comments only, no docstring -- an @instr body is parsed and
# pattern-matched structurally by Exo, so the statements must stay exact.
@instr(
    """
    {{
    __m256i indices = _mm256_set_epi64x(3, 2, 1, 0);
    __m256i prefix = _mm256_set1_epi64x({bound});
    __m256i cmp = _mm256_cmpgt_epi64(prefix, indices);
    __m256d mul = _mm256_mul_pd({x_data}, {y_data});
    {out_data} = _mm256_blendv_pd ({out_data}, mul, _mm256_castsi256_pd(cmp));
    }}
    """
)
def mm256_prefix_mul_pd(
    out: [f64][4] @ AVX2, x: [f64][4] @ AVX2, y: [f64][4] @ AVX2, bound: size
):
    assert stride(out, 0) == 1
    assert stride(x, 0) == 1
    assert stride(y, 0) == 1
    assert bound <= 4

    # Guarded elementwise product over the first `bound` lanes.
    for i in seq(0, 4):
        if i < bound:
            out[i] = x[i] * y[i]


# AVX2 prefix add (f32x8): out[i] = x[i] + y[i] for lanes i < bound;
# lanes i >= bound keep their previous out contents (blendv selects the
# old {out_data} where the cmp mask is zero).
# NOTE: comments only, no docstring -- an @instr body is parsed and
# pattern-matched structurally by Exo, so the statements must stay exact.
@instr(
    """
    {{
    __m256i indices = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
    __m256i prefix = _mm256_set1_epi32({bound});
    __m256i cmp = _mm256_cmpgt_epi32(prefix, indices);
    __m256 add = _mm256_add_ps({x_data}, {y_data});
    {out_data} = _mm256_blendv_ps ({out_data}, add, _mm256_castsi256_ps(cmp));
    }}
    """
)
def mm256_prefix_add_ps(
    out: [f32][8] @ AVX2, x: [f32][8] @ AVX2, y: [f32][8] @ AVX2, bound: size
):
    assert stride(out, 0) == 1
    assert stride(x, 0) == 1
    assert stride(y, 0) == 1
    assert bound <= 8

    # Guarded elementwise sum over the first `bound` lanes.
    for i in seq(0, 8):
        if i < bound:
            out[i] = x[i] + y[i]


# AVX2 prefix add (f64x4): out[i] = x[i] + y[i] for lanes i < bound;
# lanes i >= bound keep their previous out contents.
# NOTE: comments only, no docstring -- an @instr body is parsed and
# pattern-matched structurally by Exo, so the statements must stay exact.
@instr(
    """
    {{
    __m256i indices = _mm256_set_epi64x(3, 2, 1, 0);
    __m256i prefix = _mm256_set1_epi64x({bound});
    __m256i cmp = _mm256_cmpgt_epi64(prefix, indices);
    __m256d add = _mm256_add_pd({x_data}, {y_data});
    {out_data} = _mm256_blendv_pd ({out_data}, add, _mm256_castsi256_pd(cmp));
    }}
    """
)
def mm256_prefix_add_pd(
    out: [f64][4] @ AVX2, x: [f64][4] @ AVX2, y: [f64][4] @ AVX2, bound: size
):
    assert stride(out, 0) == 1
    assert stride(x, 0) == 1
    assert stride(y, 0) == 1
    assert bound <= 4

    # Guarded elementwise sum over the first `bound` lanes.
    for i in seq(0, 4):
        if i < bound:
            out[i] = x[i] + y[i]


Machine = MachineParameters(
name="avx2",
mem_type=AVX2,
Expand All @@ -417,11 +555,14 @@ def avx2_prefix_abs_pd(dst: [f64][4] @ AVX2, src: [f64][4] @ AVX2, bound: size):
assoc_reduce_add_instr_f32=avx2_assoc_reduce_add_ps,
assoc_reduce_add_f32_buffer=avx2_assoc_reduce_add_ps_buffer,
mul_instr_f32=mm256_mul_ps,
prefix_mul_instr_f32=mm256_prefix_mul_ps,
add_instr_f32=mm256_add_ps,
prefix_add_instr_f32=mm256_prefix_add_ps,
reduce_add_wide_instr_f32=avx2_reduce_add_wide_ps,
prefix_reduce_add_wide_instr_f32=avx2_prefix_reduce_add_wide_ps,
reg_copy_instr_f32=avx2_reg_copy_ps,
sign_instr_f32=avx2_sign_ps,
prefix_sign_instr_f32=avx2_prefix_sign_ps,
select_instr_f32=avx2_select_ps,
abs_instr_f32=avx2_abs_ps,
prefix_abs_instr_f32=avx2_prefix_abs_ps,
Expand All @@ -440,12 +581,15 @@ def avx2_prefix_abs_pd(dst: [f64][4] @ AVX2, src: [f64][4] @ AVX2, bound: size):
set_zero_instr_f64=mm256_setzero_pd,
assoc_reduce_add_instr_f64=avx2_assoc_reduce_add_pd,
mul_instr_f64=mm256_mul_pd,
prefix_mul_instr_f64=mm256_prefix_mul_pd,
add_instr_f64=mm256_add_pd,
prefix_add_instr_f64=mm256_prefix_add_pd,
reduce_add_wide_instr_f64=avx2_reduce_add_wide_pd,
prefix_reduce_add_wide_instr_f64=avx2_prefix_reduce_add_wide_pd,
assoc_reduce_add_f64_buffer=avx2_assoc_reduce_add_pd_buffer,
reg_copy_instr_f64=avx2_reg_copy_pd,
sign_instr_f64=avx2_sign_pd,
prefix_sign_instr_f64=avx2_prefix_sign_pd,
select_instr_f64=avx2_select_pd,
abs_instr_f64=avx2_abs_pd,
prefix_abs_instr_f64=avx2_prefix_abs_pd,
Expand Down
6 changes: 6 additions & 0 deletions src/common/machines/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@ class MachineParameters:
set_zero_instr_f32: Any
assoc_reduce_add_instr_f32: Any
mul_instr_f32: Any
prefix_mul_instr_f32: Any
add_instr_f32: Any
prefix_add_instr_f32: Any
reduce_add_wide_instr_f32: Any
prefix_reduce_add_wide_instr_f32: Any
reg_copy_instr_f32: Any
sign_instr_f32: Any
prefix_sign_instr_f32: Any
select_instr_f32: Any
assoc_reduce_add_f32_buffer: Any
abs_instr_f32: Any
Expand All @@ -64,12 +67,15 @@ class MachineParameters:
set_zero_instr_f64: Any
assoc_reduce_add_instr_f64: Any
mul_instr_f64: Any
prefix_mul_instr_f64: Any
add_instr_f64: Any
prefix_add_instr_f64: Any
reduce_add_wide_instr_f64: Any
prefix_reduce_add_wide_instr_f64: Any
assoc_reduce_add_f64_buffer: Any
reg_copy_instr_f64: Any
sign_instr_f64: Any
prefix_sign_instr_f64: Any
select_instr_f64: Any
abs_instr_f64: Any
prefix_abs_instr_f64: Any
Expand Down
6 changes: 6 additions & 0 deletions src/common/machines/neon.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,14 @@ def neon_vst_2xf64_backwards(dst: [f64][2] @ DRAM, src: [f64][2] @ Neon):
set_zero_instr_f32=neon_zero_4xf32,
assoc_reduce_add_instr_f32=neon_assoc_reduce_add_instr_4xf32,
mul_instr_f32=neon_vmul_4xf32,
prefix_mul_instr_f32=None,
add_instr_f32=neon_vadd_4xf32,
prefix_add_instr_f32=None,
reduce_add_wide_instr_f32=neon_reduce_vadd_4xf32,
prefix_reduce_add_wide_instr_f32=None,
reg_copy_instr_f32=neon_reg_copy_4xf32,
sign_instr_f32=neon_vneg_4xf32,
prefix_sign_instr_f32=None,
select_instr_f32=None,
assoc_reduce_add_f32_buffer=neon_assoc_reduce_add_instr_4xf32_buffer,
abs_instr_f32=None,
Expand All @@ -128,11 +131,14 @@ def neon_vst_2xf64_backwards(dst: [f64][2] @ DRAM, src: [f64][2] @ Neon):
assoc_reduce_add_instr_f64=neon_assoc_reduce_add_instr_2xf64,
assoc_reduce_add_f64_buffer=neon_assoc_reduce_add_instr_2xf64_buffer,
mul_instr_f64=neon_vmul_2xf64,
prefix_mul_instr_f64=None,
add_instr_f64=neon_vadd_2xf64,
prefix_add_instr_f64=None,
reduce_add_wide_instr_f64=neon_reduce_vadd_2xf64,
prefix_reduce_add_wide_instr_f64=None,
reg_copy_instr_f64=neon_reg_copy_2xf64,
sign_instr_f64=neon_vneg_2xf64,
prefix_sign_instr_f64=None,
select_instr_f64=None,
abs_instr_f64=None,
prefix_abs_instr_f64=None,
Expand Down
Loading

0 comments on commit 0d69185

Please sign in to comment.