[Arith] Simplifications for floormod(x, 2) #13936

Merged
merged 10 commits into from
Apr 4, 2023


Lunderberg
Contributor

Because floormod(x,2) has only two possible values, it can be simplified more aggressively than most FloorMod expressions. The additional simplifications are derived from floormod(x,2) + floormod(x+1,2) == 1, which is true for denominator 2, along with the usual floordiv(x,2)*2 + floormod(x,2) == x, which is true for all denominators.

This initially arose from an index expression floormod(x + 1, 2) * 8192, for x ∈ [0, 2). This commit allows the expression to be re-written as x * (-8192) + 8192 and recognized as a strided access.
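The two identities and the motivating rewrite can be checked by brute force. The following is a minimal sketch in plain Python rather than TVM code; it assumes only that Python's `%` and `//` on integers use floor semantics, which matches `floormod` and `floordiv`. The helper names `check_identities` and `strided_rewrite_holds` are made up for illustration.

```python
def floormod(a: int, b: int) -> int:
    # Python's % already rounds toward negative infinity for ints.
    return a % b


def floordiv(a: int, b: int) -> int:
    # Python's // is floor division for ints.
    return a // b


def check_identities(lo: int = -16, hi: int = 16) -> bool:
    """Brute-force check of both identities over a small integer range."""
    for x in range(lo, hi):
        # Specific to denominator 2.
        assert floormod(x, 2) + floormod(x + 1, 2) == 1
        # Holds for every positive denominator tried here.
        for d in (2, 3, 5, 8):
            assert floordiv(x, d) * d + floormod(x, d) == x
    return True


def strided_rewrite_holds() -> bool:
    """The rewrite is only claimed for x in [0, 2)."""
    return all(
        floormod(x + 1, 2) * 8192 == x * (-8192) + 8192
        for x in range(0, 2)
    )
```

The rewrite depends on the known bounds of `x`: outside `[0, 2)` the two expressions differ, for example at `x = 3`.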

@tvm-bot
Collaborator

tvm-bot commented Feb 9, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

Lunderberg and others added 4 commits February 21, 2023 08:23
@@ -898,6 +898,11 @@ class IterMapRewriter : public ExprMutator {
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);

static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
if (sign < 0 && is_const_int(rhs->extent, 2)) {
Contributor

Could we add a cover case for the codepath?

Contributor Author

Can do, and added! This case was caught by other unit tests, since the additional RewriteSimplifier rules prevented DetectIterMap from recognizing some patterns after they had been simplified, but a specific unit test for this case is better than needing to track it down later.

Member

NOTE: this actually is a bug, see #14571

This makes me wonder whether other rules also lead to regressions. Please do check.

@wrongtest-intellif
Contributor

BTW, for cancellations like f(x + c1) - f(x + c2), it seems that in general the pair may not end up next to each other among the summation terms. So we would fail to cancel things like f(x + c1) + g(x + c3) - f(x + c2) even though we do have the rules.

Out of curiosity about our tradeoffs in simplification: might we, in the future, allow O(n^2)-level rewriting over summation terms in order to fully utilize the rewrite rules? cc @tqchen @Lunderberg
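To make the idea concrete, here is a toy sketch of pairwise rewriting over flattened summation terms. It is not how TVM's RewriteSimplifier works; the term encoding `("floormod2", offset)`, standing for `floormod(x + offset, 2)`, and the functions `combine_pair` and `simplify_sum` are all hypothetical. Each pass inspects every pair of terms, which is where the O(n^2) cost comes from.

```python
from itertools import combinations


def combine_pair(a, b):
    """Hypothetical pair rule: floormod(x + k, 2) + floormod(x + k + 1, 2) -> 1.

    Returns the merged term, or None if the rule does not apply.
    """
    if (
        isinstance(a, tuple)
        and isinstance(b, tuple)
        and a[0] == b[0] == "floormod2"
        and abs(a[1] - b[1]) == 1
    ):
        return 1
    return None


def simplify_sum(terms):
    """Repeatedly try the pair rule on every pair of terms, adjacent or not."""
    terms = list(terms)
    changed = True
    while changed:
        changed = False
        for i, j in combinations(range(len(terms)), 2):
            merged = combine_pair(terms[i], terms[j])
            if merged is not None:
                # Drop the two matched terms and append their replacement.
                terms = [t for k, t in enumerate(terms) if k not in (i, j)]
                terms.append(merged)
                changed = True
                break
    return terms


# The matching pair is separated by an unrelated term, "g(x)".
result = simplify_sum([("floormod2", 0), "g(x)", ("floormod2", 1)])
```

A rule that only looks at the two direct operands of an addition would miss this pair; the all-pairs scan finds it at the price of quadratic work per pass.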

@Lunderberg
Contributor Author

Thank you for the review, @wrongtest-intellif , and everything should be updated now.

Regarding cases where it can fail to cancel, that definitely can occur. I have some work on a local branch where I've been trying to make the rewrite rules more general (and to reduce duplication as a benefit), but at the moment it slows down the RewriteSimplifier by about 10x.

One thing that I think could help would be to enforce an ordering between terms for any commutative and associative operator. For example, right now (x + y)//4 - (y + x)//4 doesn't cancel out, because the numerators are different. Sorting by the variable(s) used by each term would at least move related terms together, making it more likely that they'd cancel out.
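The ordering idea can be sketched on a toy expression tree. This is an illustration under stated assumptions, not TVM's data structures: expressions are nested tuples `(op, lhs, rhs)` with string variables and integer constants, and `normalize` and `simplify_sub` are hypothetical helpers. Only the two operands of each binary node are sorted; flattening associative chains is left out.

```python
COMMUTATIVE = {"+", "*"}


def normalize(expr):
    """Recursively sort the operands of commutative operators."""
    if not isinstance(expr, tuple):
        return expr  # variable name or integer constant
    op, *args = expr
    args = [normalize(a) for a in args]
    if op in COMMUTATIVE:
        # repr gives a total order across strings, ints, and nested tuples.
        args.sort(key=repr)
    return (op, *args)


def simplify_sub(lhs, rhs):
    """Return 0 when lhs - rhs cancels structurally after normalization."""
    lhs_n, rhs_n = normalize(lhs), normalize(rhs)
    if lhs_n == rhs_n:
        return 0
    return ("-", lhs_n, rhs_n)


a = ("//", ("+", "x", "y"), 4)  # (x + y) // 4
b = ("//", ("+", "y", "x"), 4)  # (y + x) // 4
```

Without normalization `a` and `b` compare unequal, so a purely structural match would not cancel them; after sorting operands they become identical.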

Contributor

@wrongtest-intellif wrongtest-intellif left a comment

LGTM!

@wrongtest-intellif
Contributor

Some of the expected TIR in the unit tests seems outdated.

@Lunderberg
Contributor Author

@wrongtest-intellif Thank you for pointing that out. I've merged main into the dev branch to resolve it.

@Hzfengsy
Member

Could you please fix the CI so that we can get it in, @Lunderberg?

@Lunderberg
Contributor Author

@Hzfengsy, thank you for the reminder. It would be good to get this PR finished and merged in.

@junrushao Could I get your assistance on the last CI error? The failing unit tests are in test_cuda_c2d, added in #12043, but I'm not familiar enough with the meta-schedule search space generation to tell if this is a spurious failure.

It looks like this simplification allowed the search space to go down an entirely different route than the unit test was expecting, as the generated sketch outputs the transpose (conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256))) with a completely different shape than the expected output defined in the unit test (conv2d_nhwc: T.Buffer((1, 112, 112, 64))).

@junrushao
Member

hey sorry for the delayed response - on pto right now :-)

🤯 wow, i mean any schedule is not supposed to make a conv become transposed conv...do you think it might be a parser issue that it kinda capture the wrong source code? if needed, i can help with debugging 🐛 once i'm back

@Lunderberg
Contributor Author

Thank you, and sorry for pinging you on PTO, and this is much lower priority than PTO rest/relaxation.

I wouldn't guess that it was a parsing issue, as the parsed TIR hasn't changed as a result of the additional simplifications. So, maybe it would be more likely to be a change in the output of mod = create_te_workload("C2D", 0).

@Lunderberg
Contributor Author

Lunderberg commented Mar 31, 2023

I think I got it. The first issue was that the failure was in test_cuda_t2d, but I had been comparing against the expected result in test_cuda_c2d. The test_cuda_t2d does start with the transpose, so it isn't an issue of a schedule primitive erroneously inserting a transposition. After that, I was able to track down and fix the test breakage. The main issue in understanding the test was that I assumed that it was comparing the module in actual[0].mod to the expected module. Instead, it was only using the trace from the sketch, re-applying it with the hard-coded decisions, and then comparing that result to the expected module.

After that, the only difference was resolved by replacing (var % 2 + 1) // 2 with (var % 2), exactly the simplification enabled by this PR.
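That equivalence is easy to confirm by enumeration. A minimal check in plain Python, whose `%` and `//` use floor semantics on integers; the function names `before` and `after` are illustrative only:

```python
def before(var: int) -> int:
    # Form that appeared in the old expected output.
    return (var % 2 + 1) // 2


def after(var: int) -> int:
    # Simplified form.
    return var % 2


# var % 2 is either 0 or 1: (0 + 1) // 2 == 0 and (1 + 1) // 2 == 1.
equivalent = all(before(v) == after(v) for v in range(-64, 64))
```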

@Lunderberg Lunderberg merged commit dba987c into apache:main Apr 4, 2023
@Lunderberg Lunderberg deleted the simplify_floormod_2 branch April 4, 2023 16:41
@tqchen
Member

tqchen commented Apr 11, 2023

Note that this PR introduced a bug; see #14571.

The buggy change was intended to fix a regression, which now comes back. This makes me think that we should revisit the other rules here, since other issues are possible.

We should likely remove some of the simplifications in favor of iter_map, because that is more important. @Lunderberg @junrushao let us quickly figure out the situation here.

@tqchen
Member

tqchen commented Apr 11, 2023

@wrongtest-intellif sorry for following up late. For term cancellations, we have the canonical simplifier, which should be able to do such cancellation, so we likely don't need to enhance the rewrite simplifier too much here.

@tqchen
Member

tqchen commented Apr 11, 2023

added more discussions in #14571
