[Arith] Merge surjective/non-surjective iter mapping detections #11287

Merged

wrongtest-intellif
Contributor

@wrongtest-intellif wrongtest-intellif commented May 12, 2022

Update a simplify rule for the case where c2 is nonzero; the original rule is already covered by constant folding.
floormod(x * c1, c2) =>
floormod(x * (floordiv(c1, c2) * c2 + floormod(c1, c2)), c2) =>
floormod(x * floormod(c1, c2), c2)
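
The rewrite can be sanity-checked numerically. The following is a minimal brute-force sketch, assuming Python's `%` (a floor-based modulo) as a stand-in for TVM's `floormod`; the helper name `rule_holds` is illustrative and not part of TVM.

```python
def rule_holds(x: int, c1: int, c2: int) -> bool:
    """Check floormod(x * c1, c2) == floormod(x * floormod(c1, c2), c2)."""
    # Python's % rounds toward negative infinity, matching floormod semantics.
    return (x * c1) % c2 == (x * (c1 % c2)) % c2


# Exhaustively test a small grid, skipping c2 == 0 where floormod is undefined.
all_ok = all(
    rule_holds(x, c1, c2)
    for x in range(-20, 21)
    for c1 in range(-40, 41)
    for c2 in range(-12, 13)
    if c2 != 0
)
print(all_ok)
```

This is only a spot check over small integers, not a proof; the identity itself follows from c1 and floormod(c1, c2) being congruent modulo c2.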

This is useful for certain non-perfect tiling cases, where there are dynamic loop ranges that are actually constant with respect to the outer loop domain.

For example, floordiv(floormod(x * 360, 16) + 359, 16) with x in [0, 2) eventually reduces to the constant 22, since the rule replaces the multiplication factor 360 with 360 % 16, which lets more of the existing rules apply.
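
The arithmetic in this example can be checked directly. The sketch below evaluates the expression before and after the factor 360 is replaced by 360 % 16 == 8, using plain Python integers rather than TVM expressions; the function names `original` and `reduced` are made up for illustration.

```python
def original(x: int) -> int:
    # floordiv(floormod(x * 360, 16) + 359, 16)
    return ((x * 360) % 16 + 359) // 16


def reduced(x: int) -> int:
    # Same expression after replacing the factor 360 with 360 % 16 == 8.
    return ((x * 8) % 16 + 359) // 16


values = [(original(x), reduced(x)) for x in range(2)]
print(values)  # [(22, 22), (22, 22)]
```

With the factor reduced to 8, the inner floormod can only be 0 or 8, and both 359 // 16 and 367 // 16 equal 22.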

Unfortunately, the working example on tiling runs into a region_cover related problem again.

@wrongtest-intellif
Contributor Author

Where should I write test cases for analyzer.simplify()'s behavior (or is it necessary at all)?

@tqchen
Member

tqchen commented May 12, 2022

@wrongtest yes we should cover simplifier's behavior, but the rewrite_simplifier testcase should be sufficient for now

@wrongtest-intellif
Contributor Author

The failing region cover check in compute_at could possibly be fixed by the iteration analysis improvements in #11235.

@vinx13
Member

vinx13 commented May 13, 2022

LGTM, let's have #11235 merged first

@wrongtest-intellif
Contributor Author

To enable the region cover proof in such cases, we need to lift DetectIterMapPadded into the standard implementation of DetectIterMap.

@Hzfengsy
Member

A gentle ping for @vinx13

@wrongtest-intellif wrongtest-intellif changed the title [Arith][Simplify] Extend simplify rule for floormod(x * c1 + y, c2) [Arith] Merge surjective/non-surjective iter mapping detections May 23, 2022
@vinx13
Member

vinx13 commented May 23, 2022

@wrongtest Can you elaborate on the usage of DetectIterMapPadded in our analysis? Do we need the padding information?

also cc @Lunderberg for DetectIterMap changes

@wrongtest-intellif
Contributor Author

wrongtest-intellif commented May 23, 2022

usage of DetectIterMapPadded in our analysis

I am trying to merge DetectIterMapPadded and DetectIterMap into the same interface, and to replace the option require_bijective with a new enum IterMapLevel with three alternatives:

  • Bijective
    for original behavior on require_bijective=true
  • Surjective
    for original behavior on require_bijective=false
  • Injective
    for behavior of DetectIterMapPadded
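
To make the three level names concrete, here is a small sketch of their mathematical meaning on finite integer mappings. It is an illustration only: the `IterMapLevel` class below is a Python stand-in that mirrors the names in this comment, and the `satisfies` helper is hypothetical. It does not reflect how DetectIterMap actually decides anything.

```python
from enum import Enum
from typing import Callable, Iterable


class IterMapLevel(Enum):
    """Illustrative stand-in for the proposed check levels (not TVM's definition)."""
    BIJECTIVE = "bijective"
    SURJECTIVE = "surjective"
    INJECTIVE = "injective"


def satisfies(level: IterMapLevel, f: Callable[[int], int],
              domain: Iterable[int], codomain: Iterable[int]) -> bool:
    """Brute-force check of a finite mapping against the requested level."""
    images = [f(x) for x in domain]
    injective = len(set(images)) == len(images)   # no two inputs collide
    surjective = set(images) == set(codomain)     # every output is reached
    if level is IterMapLevel.BIJECTIVE:
        return injective and surjective
    if level is IterMapLevel.SURJECTIVE:
        return surjective
    return injective


# x -> x on [0, 8) is bijective onto [0, 8).
print(satisfies(IterMapLevel.BIJECTIVE, lambda x: x, range(8), range(8)))
# x -> x // 2 reaches every value of [0, 4) but is not injective.
print(satisfies(IterMapLevel.SURJECTIVE, lambda x: x // 2, range(8), range(4)))
# x -> x * 2 is injective into [0, 16) but leaves gaps, similar to padding.
print(satisfies(IterMapLevel.INJECTIVE, lambda x: x * 2, range(8), range(16)))
```

In this picture, padding corresponds to codomain values that no input reaches, which is why a padded mapping fits the injective level most naturally.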

#11235 brings a great way to analyze iteration forms like (x + 7) // 16 with padding. The surjective checking of DetectIterMap is used in many places (like the region cover check after a schedule step); however, it cannot leverage this analysis now, because the surjective check currently requires that no padding predicate is introduced.

I think that, as an example, although (x + 7) // 16 is rewritten into a "padded" iteration form, we could still prove the mapping is surjective, since by construction the left and right padding are each no larger than the largest divisor. If we extend the CheckMapping rules carefully, we may be able to distinguish that

  • (x + 7) // 16 -> surjective
    • this is the access index form in my original failed case
  • (x + 7) % 16 -> surjective [0, 16) if x's extent is larger than 16
  • ((x + 7) // 16, (x + 7) % 16) -> non-surjective

So from my perspective it would be great to have a uniform interface that shares the same padding-based analysis. Ideally padding_predicate is not affected for IndexMap functionalities, and it should not introduce false positives in bijective/surjective checking. I'm still working through more unit test cases and adapting the padding analysis for when a surjective mapping is required.

Do we need the padding information

No, the original usages of DetectIterMap still do not require padding_predicate. But we need to prove surjectivity when padding is added for the new iteration forms supported by the original DetectIterMapPadded.

Contributor

@Lunderberg Lunderberg left a comment

I focused on the DetectIterMap changes, and especially like the merging and de-duplication. Mostly just some nitpicks here and there.

}

// Step0.1: Check each index to determine required padding
- bool allow_padding = !require_bijective;
+ bool allow_padding = check_level != IterMapLevel::Bijective;
Contributor

This would enable padding for IterMapLevel::Surjective, which I don't think is correct. Since padding is any output value for which no input value exists, any introduction of padding wouldn't be surjective.

Contributor Author

That is the claim. I try to attach the padding to the iter mark itself.

For example, (x + 7) with x in [0, 8) => IterMark(IterSplit(IterSum({x}, 7), lower_factor=1, extent=16, scale=1), extent=16) with left_pad=7, right_pad=1

Then (x + 7) // 8 is mapped to the range [0, extent // 8) == [0, 2). Although we have padding in the iter mark, the IterSplit's range can still be fully reached when we only iterate x over its original domain: (0 + 7) // 8 = 0, (7 + 7) // 8 = 1
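
The padding numbers in this example can be reproduced with a few lines of integer arithmetic. This is a hedged sketch: `pad_to_divisor` is a hypothetical helper, the left-padding formula `base % divisor` is an assumption, and the right-padding formula mirrors the `floormod(-right_edge, divisor)` expression that appears in this PR's discussion.

```python
def pad_to_divisor(base: int, extent: int, divisor: int):
    """Pad the range [base, base + extent) so both ends align to `divisor`."""
    left_pad = base % divisor                 # distance back to an aligned start
    right_edge = base + extent
    right_pad = (-right_edge) % divisor       # distance forward to an aligned end
    padded_extent = left_pad + extent + right_pad
    return left_pad, right_pad, padded_extent


# (x + 7) with x in [0, 8): base 7, extent 8, split by divisor 8.
left_pad, right_pad, padded_extent = pad_to_divisor(7, 8, 8)
print(left_pad, right_pad, padded_extent)  # 7 1 16

# Despite the padding, the quotient still reaches every value of [0, 16 // 8).
quotients = {(x + 7) // 8 for x in range(8)}
print(quotients == set(range(padded_extent // 8)))  # True
```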

Contributor

Good point, and that does maintain surjectivity for a single index. I'm not entirely sure for the case of two indices, though. For the same x ∈ [0,8), the indices [(x+7)//8, (x+7)%8] would have the same padding left_pad=7 and right_pad=1. Even though each individual index can take any value in the output ((x+7)//8 ∈[0,2) and (x+7)%8 ∈ [0,8)), there are some coordinate pairs that cannot be generated for any value of x (e.g. [0,0] and [1,7]).
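
The counterexample can be enumerated explicitly. The short sketch below lists every coordinate pair that [(x+7)//8, (x+7)%8] fails to produce for x ∈ [0,8), using plain Python integers; the variable names are illustrative.

```python
# All pairs actually produced by x in [0, 8).
reachable = {((x + 7) // 8, (x + 7) % 8) for x in range(8)}

# Each index on its own covers its full range...
print({q for q, _ in reachable} == set(range(2)))  # True
print({r for _, r in reachable} == set(range(8)))  # True

# ...but the joint space [0, 2) x [0, 8) is only half covered.
all_pairs = {(q, r) for q in range(2) for r in range(8)}
missing = sorted(all_pairs - reachable)
print(len(reachable), len(all_pairs))  # 8 16
print(missing)
```

The missing pairs are exactly the padded positions: seven on the left, (0,0) through (0,6), and one on the right, (1,7).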

Contributor Author

@wrongtest-intellif wrongtest-intellif May 25, 2022

I agree! This is where we should be careful. In CheckMapping in surjective mode, when padding exists, we check that padded // LCM and padded % LCM (or its sub-splits) do not both exist. The case below depicts this check:

sum = 80 + y
dom_map = var_dom([(y, 176)])

# (80 + y) // 32 itself could be surjective
assert_iter_sum_pattern(
    {fld(sum, 32): (6, 2, 1)},
    dom_map,
)

# ((80 + y) % 2, ((80 + y) // 2) % 16) could be surjective,
# since they can be seen as sub-splits of (80 + y) % 32
assert_iter_sum_pattern(
    {flm(fld(sum, 2), 16): (16, 0, 1), flm(sum, 2): (2, 0, 1)},
    dom_map,
)

# but (80 + y) // 32, (80 + y) % 32 are not surjective
assert_iter_sum_failure({fld(sum, 32), flm(sum, 32)}, dom_map)

Other kinds of negatives, like ((80 + y) // 32, (80 + y) // 4), would be rejected by the existing checking rules.
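
The three cases in the test above can be cross-checked by brute force, without TVM. The sketch below assumes y ranges over [0, 176) as in dom_map, and reads the tuples passed to assert_iter_sum_pattern as (extent, base, scale), which is an assumption about the test helper's convention.

```python
ys = range(176)  # y in [0, 176)

# (80 + y) // 32 alone: six consecutive values starting at 2, matching (6, 2, 1).
quotients = {(80 + y) // 32 for y in ys}
print(quotients == set(range(2, 8)))  # True

# Sub-splits of (80 + y) % 32: every (high, low) combination is reached.
sub_splits = {(((80 + y) // 2) % 16, (80 + y) % 2) for y in ys}
print(len(sub_splits) == 16 * 2)  # True

# Quotient and remainder together: 176 inputs cannot fill 6 * 32 = 192 slots.
pairs = {((80 + y) // 32, (80 + y) % 32) for y in ys}
print(len(pairs), len(quotients) * 32)  # 176 192
```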

requires_padding_ = requires_padding_ || (left_padding_introduced || right_padding_introduced);
padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate);
}
// ICHECK(CanProveDivisible(info.padded->extent, split->lower_factor));
Contributor

Should these // ICHECK lines be either uncommented or removed?

Contributor Author

I would like to check that the padding factor is divisible by split->lower_factor; then the commented-out check is guaranteed by context. Unfortunately, I found it may fail due to the simplifier's limitations when the padded extent contains complex flm/fld expressions.

Contributor

Got it. I noticed that there were also some simplification steps that needed to increase the number of iterations performed. Is the failure to prove divisibility related, since CanProveDivisible only uses the default of 2 steps?

(I'm also wondering if the default for Analyzer::Simplify should be to iterate until the simplification converges, rather than using a fixed number of steps.)
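
The difference between a fixed step count and iterating to convergence can be sketched generically. Everything below is a toy: the two driver functions and the `strip_one_factor_of_two` rewrite rule are hypothetical and are not related to the real Analyzer::Simplify implementation; only the default of 2 steps comes from the comment above.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def simplify_fixed_steps(expr: T, rewrite: Callable[[T], T], steps: int = 2) -> T:
    """Apply `rewrite` a fixed number of times."""
    for _ in range(steps):
        expr = rewrite(expr)
    return expr


def simplify_to_fixed_point(expr: T, rewrite: Callable[[T], T], max_steps: int = 64) -> T:
    """Apply `rewrite` until the result stops changing, with a safety cap."""
    for _ in range(max_steps):
        new_expr = rewrite(expr)
        if new_expr == expr:
            break
        expr = new_expr
    return expr


def strip_one_factor_of_two(n: int) -> int:
    # Toy "rewrite rule": each pass removes at most one factor of 2.
    return n // 2 if n % 2 == 0 and n != 0 else n


print(simplify_fixed_steps(40, strip_one_factor_of_two))     # 10
print(simplify_to_fixed_point(40, strip_one_factor_of_two))  # 5
```

The cap on iterations keeps the fixed-point variant safe if a rewrite rule ever oscillates instead of converging.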

@junrushao
Member

Quick note: #11235 is merged

@wrongtest-intellif wrongtest-intellif force-pushed the simplify_floormod_after_multiply branch 2 times, most recently from a1a2086 to 1c15f4d Compare May 25, 2022 09:45
@wrongtest-intellif wrongtest-intellif force-pushed the simplify_floormod_after_multiply branch 2 times, most recently from f4280f0 to 001ed50 Compare May 25, 2022 20:36
@@ -1659,7 +1676,7 @@ bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs
PrimExpr divisor = normalizer.Convert(rhs);

return analyzer_->CanProveEqual(dividend, divisor) ||
- analyzer_->CanProve(floormod(dividend, divisor) == 0);
+ analyzer_->CanProve(analyzer_->Simplify(floormod(dividend, divisor), 8) == 0);
Member

It would be great to have some explanation here of why it needs more simplification steps.

Contributor Author

Sorry, that is something I forgot to revert. There are some cases where divisibility cannot be proved, like floormod(0 + -x * 8, x) == 0 and floormod(8*c1*c2, c1) == 0, even if we increase the iteration number. They are worked around here and there, for example:

if (CanProveDivisible(right_edge, divisor)) {
  right_pad = 0;
} else {
  right_pad = analyzer_->Simplify(floormod(-right_edge, divisor));
}

@Lunderberg suggested that Simplify could be optimized to iterate until it reaches a fixed point. For now, the current approach suffices for the existing tests.

@vinx13
Member

vinx13 commented May 31, 2022

Could you also update this line https://github.com/apache/tvm/blob/main/src/tir/schedule/primitive/layout_transformation.cc#L395? There are some conflicts that CI didn't catch because of a concurrent merge.

- treat cases like x % 16, x in [0, 5) as non-surjective, since usages may mistakenly treat the region extent as 16
- skip the second round of rewriting when there is no padding
- fix some typos in comments
@vinx13 vinx13 merged commit c1b22ee into apache:main May 31, 2022
@junrushao
Member

One bug from my side is magically fixed by this PR!!
