Skip to content

Commit

Permalink
[MetaSchedule][Hexagon] Improve vectorization for standalone elementw…
Browse files Browse the repository at this point in the history
…ise op (#14408)

[MetaSchedule][Hexagon] Improve vectorization for standalone elementwise ops

Motivation:
It was found that for standalone elementwise operations (add, sub, etc.)
MetaScheduler generates code with poor performance due to a lack of vector
code for some input tensor shapes. The current implementation is not able
to vectorize if the innermost loop's extent is not a multiple of the vector
length.

What was done:
Core changes: it checks the current loop nest; if all loops are "simple",
i.e. loops without annotations, bindings, or reduce axes, then it does the
following:
 1) Fuse all loops into single one.
 2) Split this new loop into 2 parts: inner and outer. Herewith split
    factor for the inner loop is equal to 'max_vectorize_extent'
    MetaScheduler parameter.
 3) Parallelize outer loop and vectorize inner loop.

Performance measurement:
Measurement was done on a Qualcomm Snapdragon 888. As expected, cases 1
and 2 got a significant performance boost, while 3 and 4 are unchanged.

N |    op   | Dtype |      Shape       | Before fix, ms | After fix, ms | speedup |
--|---------|-------|------------------|----------------|---------------|---------|
1 | add     | uint8 | 1, 8, 56, 56, 32 |      1.264     |     0.167     |  7.5x   |
2 | qnn.add | uint8 | 1, 8, 56, 56, 32 |      2.213     |     0.336     |  6.6x   |
3 | add     | int32 | 1, 8, 56, 56, 32 |      0.161     |     0.150     |  1.07x  |
4 | seq*    | uint8 | 1, 64, 56, 56    |      2.634     |     2.679     |  0.98x  |
----------------------------------------------------------------------------------|

seq* - test of the ops sequence: qnn.conv2d + bias_add + qnn.requantize,
       weights shape = [256, 64, 1, 1]
  • Loading branch information
ibsidorenko authored Mar 28, 2023
1 parent a0edf24 commit 14ddb37
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 11 deletions.
81 changes: 70 additions & 11 deletions src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,40 @@ void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedA
}
}

int CalculateNumRewritableLoops(const Array<StmtSRef>& loop_srefs,
const std::vector<int>& loop_types) {
int rw_loops_num = 0;
ICHECK_EQ(loop_srefs.size(), loop_types.size());
for (size_t i = 0; i < loop_srefs.size(); ++i) {
const StmtSRef& loop_sref = loop_srefs[i];
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
if (HasAnnOrBinding(loop)) {
continue;
}
// Cannot vectorize reduce axis
if (loop_types[i] != IterVarType::kDataPar) {
continue;
}
// Cannot fuse with a loop with multiple children
if (!IsSingleStmt(loop->body)) {
continue;
}
// Check if the loop extent is valid
if (GetLoopIntExtent(loop_sref) == nullptr) {
continue;
}
++rw_loops_num;
}
return rw_loops_num;
}

void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
const Array<LoopRV>& loop_rvs, ParsedAnnotation* parsed) {
StmtSRef block_sref = sch->GetSRef(block_rv);
if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) {
return;
}
int n_loops = loop_rvs.size();
const int n_loops = loop_rvs.size();
if (n_loops == 0) {
parsed->max_parallel_extent = -1;
parsed->max_vectorize_extent = -1;
Expand Down Expand Up @@ -226,6 +253,10 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
}
max_fusible = std::min(max_fusible, fusible);
}

// Calculate how many loops are rewritable, i.e. valid for vectorization and parallelization.
int max_rw_loops = CalculateNumRewritableLoops(loop_srefs, loop_types);

// Calculate the parallelize extent
if (parsed->max_parallel_extent != -1) {
int max_extent = parsed->max_parallel_extent;
Expand Down Expand Up @@ -290,10 +321,17 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
num_fusible = -1;
}
}
// Prefer num_vectorize to num_parallel

if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) {
parsed->num_parallel_loops = std::min(parsed->num_parallel_loops, //
n_loops - parsed->num_vectorize_loops);
if (max_rw_loops == n_loops && max_fusible == n_loops) {
// All loops can be fused, parallelized and vectorized
parsed->num_parallel_loops = n_loops;
parsed->num_vectorize_loops = n_loops;
} else {
// Prefer num_vectorize to num_parallel
parsed->num_parallel_loops =
std::min(parsed->num_parallel_loops, n_loops - parsed->num_vectorize_loops);
}
}
}

Expand All @@ -317,6 +355,21 @@ bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, Block
return false;
}

void RewriteFuseSplitParallelVectorize(const Schedule& sch, Array<LoopRV>* loop_rvs, int vec_len) {
size_t n_loops = loop_rvs->size();
LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->end()});
Array<LoopRV> split = sch->Split(fused, {NullOpt, Integer(vec_len)});
ICHECK_EQ(split.size(), 2);
const LoopRV& outer = split[0];
const LoopRV& inner = split[1];
sch->Parallel(outer);
sch->Vectorize(inner);
for (size_t i = 0; i < n_loops - 1; ++i) {
loop_rvs->Set(i, outer);
}
loop_rvs->Set(n_loops - 1, inner);
}

void RewriteParallel(const Schedule& sch, size_t n, Array<LoopRV>* loop_rvs) {
ICHECK_LE(n, loop_rvs->size());
LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n});
Expand Down Expand Up @@ -364,13 +417,19 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode {
}
tir::ParsedAnnotation parsed = parsed_root;
tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed);
// Parallel
if (parsed.num_parallel_loops > 0) {
tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs);
}
// Vectorize
if (parsed.num_vectorize_loops > 0) {
tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs);
const int loops_num = loop_rvs.size();
if (parsed.num_parallel_loops == loops_num && parsed.num_vectorize_loops == loops_num) {
// Fuse, split, vectorize and parallelize
tir::RewriteFuseSplitParallelVectorize(sch, &loop_rvs, parsed.max_vectorize_extent);
} else {
// Parallel
if (parsed.num_parallel_loops > 0) {
tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs);
}
// Vectorize
if (parsed.num_vectorize_loops > 0) {
tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs);
}
}
// AutoUnroll
if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,42 @@ def after_matmul_vectorize(
T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]


@T.prim_func
def before_postproc_add(
    lhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
    rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
    add_compute: T.Buffer((1, 8, 56, 56, 32), "uint8"),
) -> None:
    # Standalone elementwise add before postprocessing: a plain 5-level loop
    # nest, with the MetaSchedule parallel/vectorize limits annotated on the
    # root block (max_parallel_extent=64, max_vectorize_extent=128).
    with T.block("root"):
        T.block_attr({"meta_schedule.parallel":64, "meta_schedule.vectorize":128})
        for n, c0, h, w, c1 in T.grid(1, 8, 56, 56, 32):
            with T.block("add_compute"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [n, c0, h, w, c1])
                T.reads(lhs[v0, v1, v2, v3, v4], rhs[v0, v1, v2, v3, v4])
                T.writes(add_compute[v0, v1, v2, v3, v4])
                add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4]


@T.prim_func
def after_postproc_add(
    lhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
    rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"),
    add_compute: T.Buffer((1, 8, 56, 56, 32), "uint8"),
) -> None:
    # Expected IR after postprocessing: all five loops fused into one loop of
    # extent 1*8*56*56*32 = 802816 = 6272 * 128, then split into a parallel
    # outer loop (6272) and a vectorized inner loop (128). The original axes
    # are recovered from the fused index via div/mod arithmetic.
    with T.block("root"):
        for n_c0_h_w_c1_fused_0 in T.parallel(0, 6272):
            for n_c0_h_w_c1_fused_1 in T.vectorized(0, 128):
                with T.block("add_compute"):
                    v0 = T.axis.spatial(1, 0)
                    v1 = T.axis.spatial(8, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) // 100352)
                    v2 = T.axis.spatial(56, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 100352 // 1792)
                    v3 = T.axis.spatial(56, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 1792 // 32)
                    v4 = T.axis.spatial(32, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 32)
                    T.reads(lhs[v0, v1, v2, v3, v4], rhs[v0, v1, v2, v3, v4])
                    T.writes(add_compute[v0, v1, v2, v3, v4])
                    add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4]


# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable

Expand All @@ -161,6 +197,14 @@ def test_vectorize_inner_loop():
tvm.ir.assert_structural_equal(sch.mod["main"], after_matmul_vectorize)


def test_parallel_vectorize_add():
    # Apply the postproc to the standalone elementwise add and verify the
    # expected fuse + split + parallelize + vectorize rewrite is produced.
    schedule = Schedule(before_postproc_add)
    postproc = RewriteParallelVectorizeUnroll()
    assert postproc.apply(schedule)
    tvm.ir.assert_structural_equal(schedule.mod["main"], after_postproc_add)


# Run every test in this file when executed as a script.
if __name__ == "__main__":
    test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize()
    test_vectorize_inner_loop()
    test_parallel_vectorize_add()

0 comments on commit 14ddb37

Please sign in to comment.