Skip to content

Commit

Permalink
[Arith] Inverse affine map (apache#8384)
Browse files Browse the repository at this point in the history
* [Arith] Inverse affine map

* [Arith] Inverse affine map

* Update iter_affine_map.h

* Update iter_affine_map.h

* Update iter_affine_map.py

* Topology order visit

* doc

* fix

* address comments

* lint

* remove print
  • Loading branch information
vinx13 authored and ylc committed Sep 29, 2021
1 parent 79520f8 commit 20b9837
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 1 deletion.
21 changes: 21 additions & 0 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,27 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer);

/*!
* \brief Apply the inverse of the affine transformation to the outputs.
*
* Similar to the back-propagation, starting from the outputs, it visits the DAG of the expressions
* in reverse topology order and applies the inverse of the affine transformation until it reaches
* the input. The affine iter map is required to be bijective.
*
* For example, iter_map = [l0 // 16, l0 % 16], outputs = [output_0, output_1],
* the affine transformation specified by `iter_map` will be applied to `outputs` and the result
* will be {l0: ((output_0*16) + output_1)}.
*
* \sa DetectIterMap
*
* \param iter_map The bijective affine iter map.
* \param outputs The outputs of the affine transformation.
*
* \return The map from the input to the transformed result.
*/
Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
const Array<PrimExpr> outputs);

/*!
* \brief Detect if bindings can be written as
* [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,9 @@
from .pattern import detect_linear_equation, detect_clip_bound
from .int_solver import solve_linear_equations, solve_linear_inequalities
from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr, subspace_divide
from .iter_affine_map import (
detect_iter_map,
normalize_iter_map_to_expr,
subspace_divide,
inverse_affine_iter_map,
)
27 changes: 27 additions & 0 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,30 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi
Empty array if no match can be found.
"""
return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective)


def inverse_affine_iter_map(iter_map, outputs):
"""Apply the inverse of the affine transformation to the outputs.
Similar to the back-propagation, starting from the outputs, it visits the DAG of the expressions
in reverse topology order and applies the inverse of the affine transformation until it reaches
the input. The affine iter map is required to be bijective.
For example, iter_map = [l0 // 16, l0 % 16], outputs = [output_0, output_1],
the affine transformation specified by `iter_map` will be applied to `outputs` and the result
will be {l0: ((output_0*16) + output_1)}.
See also :any:`detect_iter_map`.
Parameters
----------
iter_map : List[IterSumExpr]
The bijective affine iter map.
outputs : List[PrimExpr]
The outputs of the affine transformation.
Returns
-------
results : Map[Var, PrimExpr]
The map from the input to the transformed result.
"""
return _ffi_api.InverseAffineIterMap(iter_map, outputs)
142 changes: 142 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1385,5 +1385,147 @@ TVM_REGISTER_GLOBAL("arith.SubspaceDivide")
return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana);
});

class InverseAffineIterMapTransformer {
public:
explicit InverseAffineIterMapTransformer(Analyzer* analyzer) : analyzer_(analyzer) {}

Map<Var, PrimExpr> operator()(const Array<IterSumExpr>& iter_map,
const Array<PrimExpr>& outputs) {
ICHECK(iter_map.size() == outputs.size());
std::vector<const IterMapExprNode*> post_dfs_order = ReverseTopologyOrder(iter_map);

// initialize back propagation accumulator
for (const IterMapExprNode* node : post_dfs_order) {
backprop_.Set(GetRef<IterMapExpr>(node), Integer(0));
}
for (size_t i = 0; i < iter_map.size(); i++) {
backprop_.Set(iter_map[i], outputs[i]);
}

// run back propagation
for (const IterMapExprNode* node : post_dfs_order) {
if (node->IsInstance<IterSumExprNode>()) {
Visit_(Downcast<IterSumExpr>(GetRef<IterMapExpr>(node)));
} else {
ICHECK(node->IsInstance<IterSplitExprNode>());
Visit_(Downcast<IterSplitExpr>(GetRef<IterMapExpr>(node)));
}
}
return std::move(inverse_);
}

private:
void Visit_(const IterSumExpr& iter_map_expr) {
PrimExpr input = backprop_.at(iter_map_expr) - iter_map_expr->base;

// Case 1: Propagate to the input node directly when the sum expression has only one components
if (iter_map_expr->args.size() == 1) {
const auto& source = iter_map_expr->args[0];
backprop_.Set(source, backprop_.at(source) + input);
return;
}

// Case 2: If the sum expression has multiple components, match the fuse pattern and then split
// the sum expression for each components.
// For example, consider the iterator i1[dom = (0, 16)], i2[dom = (0, 8)], fusing i1 and i2
// we will have i1_i2_fused[dom = (0, 64)]. During back propagation, we need to split the
// propagated value to get the corresponding components of i1 and i2, which are
// floordiv(i1_i2_fused, 8) and floormod(i1_i2_fused, 8), respectively.
Array<IterSplitExpr> splits = MatchFusePattern(iter_map_expr);
ICHECK(!splits.empty());

for (const IterSplitExpr& split : splits) {
backprop_.Set(split,
backprop_.at(split) + floormod(floordiv(input, split->scale), split->extent));
}
}

std::vector<const IterMapExprNode*> ReverseTopologyOrder(const Array<IterSumExpr>& iter_map) {
std::vector<const IterMapExprNode*> post_dfs_order;
std::unordered_map<IterMapExpr, bool, ObjectPtrHash, ObjectPtrEqual> visited;

std::function<void(const IterMapExpr&)> fvisit = [&](const IterMapExpr& expr) {
if (visited[expr]) {
return;
}
visited[expr] = true;
if (const auto* sum_expr = expr.as<IterSumExprNode>()) {
for (const IterSplitExpr& child : sum_expr->args) {
fvisit(child);
}
} else {
const auto* split_expr = expr.as<IterSplitExprNode>();
ICHECK(split_expr);
if (const auto* source = split_expr->source->source.as<IterMapExprNode>()) {
fvisit(GetRef<IterMapExpr>(source));
}
}
post_dfs_order.push_back(expr.get());
};
for (const IterSumExpr& expr : iter_map) {
fvisit(expr);
}
std::reverse(post_dfs_order.begin(), post_dfs_order.end());
return post_dfs_order;
}

void Visit_(const IterSplitExpr& iter_map_expr) {
PrimExpr input = backprop_.at(iter_map_expr) * iter_map_expr->lower_factor;
const IterMark& source = iter_map_expr->source;
if (source->source.as<IterSumExprNode>()) {
IterSumExpr source_expr = Downcast<IterSumExpr>(source->source);
backprop_.Set(source_expr, backprop_.at(source_expr) + input);
} else {
Var source_var = Downcast<Var>(source->source);
if (inverse_.count(source_var)) {
inverse_.Set(source_var, inverse_.at(source_var) + input);
} else {
inverse_.Set(source_var, input);
}
}
}

Array<IterSplitExpr> MatchFusePattern(const IterSumExpr sum_expr) {
IntImm base_scale(nullptr);
size_t base_index = 0;
for (size_t i = 0; i < sum_expr->args.size(); ++i) {
if (const auto* op = sum_expr->args[i]->scale.as<IntImmNode>()) {
if (!base_scale.defined() || op->value < base_scale->value) {
base_scale = GetRef<IntImm>(op);
base_index = i;
}
}
}
ICHECK(base_scale.defined());
std::vector<IterSplitExpr> iters;
std::vector<bool> visited(sum_expr->args.size(), false);
PrimExpr expected_scale = base_scale;
for (size_t i = 0; i < sum_expr->args.size(); i++) {
size_t j = i == 0 ? base_index : 0;
for (; j < sum_expr->args.size(); ++j) {
if (!visited[j] && analyzer_->CanProveEqual(sum_expr->args[j]->scale, expected_scale))
break;
}
ICHECK(j != sum_expr->args.size());
visited[j] = true;
iters.push_back(sum_expr->args[j]);
expected_scale *= sum_expr->args[j]->extent;
}
return iters;
}

Analyzer* analyzer_;
Map<IterMapExpr, PrimExpr> backprop_; // the accumulator of backpropgation
Map<Var, PrimExpr> inverse_; // the result of inverse transformation
};

Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
const Array<PrimExpr> outputs) {
Analyzer analyzer;
return InverseAffineIterMapTransformer(&analyzer)(iter_map, outputs);
}

TVM_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap);

} // namespace arith
} // namespace tvm
61 changes: 61 additions & 0 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,66 @@ def test_normalize_iter_map_to_expr():
tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5))


def test_inverse_affine_iter_map():
analyzer = tvm.arith.Analyzer()
l0 = create_iter("l0", 64)
l1 = create_iter("l1", 64)
l2 = create_iter("l2", 64)

# simple case
l0_0, l0_1 = isplit(l0, 16)
l1_0, l1_1 = isplit(l1, 4)
l0_1_l1_1_fused = ifuse([l0_1, l1_1])

iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1]))
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
assert len(res) == 2
l0_inverse = floormod(floordiv(outputs[0], 4), 16) + outputs[1] * 16
l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4
assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0

# compound case
l0_0, l0_1 = isplit(l0, 16)
l1_0, l1_1 = isplit(l1, 4)
l2_1, l2_2 = isplit(l2, 4)
l2_0, l2_1 = isplit(l2_1, 4)

l0_1_l2_1_l1_1_l2_0_fused = ifuse([l0_1, l2_1, l1_1, l2_0])

iter_map = tvm.arith.detect_iter_map(
[l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], var_dom([l0, l1, l2])
)
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
assert len(res) == 3
l0_inverse = floormod(floordiv(outputs[0], 64), 16) + outputs[1] * 16
l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4
l2_inverse = (
floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 + outputs[2]
)

assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0
assert analyzer.simplify(res[l2[0]] - l2_inverse) == 0

# diamond-shape DAG
l0_0, l0_1 = isplit(l0, 16)
l1 = ifuse([l0_1, l0_0])
l1_0, l1_1 = isplit(l1, 8)
l2 = ifuse([l1_1, l1_0])

iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0]))
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
assert len(res) == 1
l1_inverse = floormod(outputs[0], 8) * 8 + floormod(floordiv(outputs[0], 8), 8)
l0_inverse = floormod(l1_inverse, 4) * 16 + floormod(floordiv(l1_inverse, 4), 16)

assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0


if __name__ == "__main__":
test_split()
test_trivial()
Expand All @@ -652,3 +712,4 @@ def test_normalize_iter_map_to_expr():
test_normalize_iter_map_to_expr()
test_subspace_division()
test_complex()
test_inverse_affine_iter_map()

0 comments on commit 20b9837

Please sign in to comment.