Invalid vector floormod rewrite. #8616
Thanks @CallahanRL. Would you mind proposing a fix with a regression test? We might need to look into a similar simplification in the Mod case. We will likely need to iterate over all the possible ramp indices to decide whether they all fall in the same modulo bucket, instead of looking only at the two ends.
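A minimal Python sketch of such an all-indices check follows. It assumes the only facts known about b1 are its modular set (b1 = coeff*k + base for some integer k), and it assumes the rewrite's target is ramp(floormod(b1, c2), c1, lanes). `possible_residues` and `hoist_is_safe` are hypothetical helpers, not TVM functions. Because floormod(b1, c2) can only take values congruent to base modulo gcd(coeff, c2), the set of cases to check is finite.

```python
from math import gcd


def possible_residues(coeff, base, c2):
    """Hypothetical helper: every value floormod(b1, c2) can take when
    b1 = coeff*k + base for some integer k (c2 is assumed positive)."""
    g = gcd(coeff, c2)  # gcd(0, c2) == c2, i.e. b1 is the constant `base`
    return list(range(base % g, c2, g))


def hoist_is_safe(coeff, base, c1, lanes, c2):
    """Hypothetical helper: True iff
    floormod(ramp(b1, c1, lanes), c2) == ramp(floormod(b1, c2), c1, lanes)
    for every b1 consistent with the modular set (coeff, base)."""
    for r in possible_residues(coeff, base, c2):
        # floormod(b1 + i*c1, c2) == floormod(r + i*c1, c2), which equals
        # r + i*c1 only when that value stays inside [0, c2).
        for i in range(lanes):
            if not 0 <= r + i * c1 < c2:
                return False
    return True
```

With the values from the counterexample in this thread (coeff 2, base 0, lanes 8, c2 20, and an assumed stride of 1), `hoist_is_safe(2, 0, 1, 8, 20)` returns False because residue 18 pushes lane 2 to 20. Since r + i*c1 is monotone in i, the inner loop could be reduced to the two end lanes once every residue is enumerated; the sketch keeps the full loop for readability.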
Sorry, I have no fix to propose; I am too new to this code base. I am not sure the rewrite that hoists the ramp operation is ever safe given the information available (in my case, that b1 is even), since I don't see that we know anything about the range of b1 values, so we don't know the range of values returned by the ramp.
@CallahanRL Can you print out the value before and after the rewrite, then give a counterexample where it does not hold? This would help clarify the situation. I guess the additional info comes from the fact that
We also know coeff % c2 in certain cases.
cc @jcf94 @merrymercy I did a closer check, and it seems this rule is indeed not always correct: https://github.com/apache/tvm/blob/main/src/arith/rewrite_simplify.cc#L867 Can you double-check and propose a fix? It would be great to also cross-check https://github.com/apache/tvm/blob/main/src/arith/rewrite_simplify.cc#L739
@tqchen Thanks, I'll check this!
Could you provide a script or a compute/schedule example that reproduces this bug? @CallahanRL
@jcf94 Based on the information, I think we can simplify it to the following case, which may not necessarily hold (try out different values of x; they may not all reside in the same modulo bucket).
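A small brute-force sketch of "try out different values of x", assuming the parameters from the counterexample in this thread: x is known only to be even, the vectorized lanes step by 1, there are 8 lanes, and the modulus is 20. `straddling_bases` is a hypothetical helper that lists the x values whose lanes do not all share one floordiv bucket, i.e. the cases where hoisting the ramp out of floormod would be wrong.

```python
def straddling_bases(candidates, c1=1, lanes=8, c2=20):
    """Hypothetical helper: return the bases x whose lanes
    x, x + c1, ..., x + (lanes - 1)*c1 do not all fall in the same
    floordiv(., c2) bucket."""
    bad = []
    for x in candidates:
        buckets = {(x + i * c1) // c2 for i in range(lanes)}
        if len(buckets) > 1:
            bad.append(x)
    return bad


# x restricted to even values (modular set coeff 2, base 0)
print(straddling_bases(range(0, 40, 2)))  # [14, 16, 18, 34, 36, 38]
```

Every even x with floormod(x, 20) >= 14 straddles a bucket boundary, so knowing only that x is even is not enough to keep all 8 lanes in one bucket.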
ping @jcf94 |
Referring to tvm/src/arith/rewrite_simplify.cc, line 857 in 9f29e2a:
I have a counterexample to this rewrite. It takes this expression:
floormod((((threadIdx.x*50) + (b.i.fused.j.fused.inner.outer.inner*8)) + b.i.fused.j.fused.inner.inner.s_1), 20)
and rewrites it to:
ramp((floormod(((threadIdx.x*50) + (b.i.fused.j.fused.inner.outer.inner*8)), 20)*32), 32, 8)
The code determines the modular set to have base 0 and coeff 2, but when it computes ramp_max for lanes=8, it incorrectly concludes that ramp_min == ramp_max. Perhaps this is always a problem when coeff*lanes < c2val?
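A rough numeric sketch of the values before and after the rewrite for this case. It assumes the lane variable b.i.fused.j.fused.inner.inner.s_1 steps by 1 across 8 lanes and models the two-ends check as comparing floordiv(base, c2) with floordiv(base + (lanes - 1)*c1, c2) on the modular-set base only; both are assumptions for illustration, and the sketch does not reproduce the exact ramp printed above (which comes from a modified vectorize_loop.cc). The concrete b1 uses threadIdx.x = 1 and b.i.fused.j.fused.inner.outer.inner = 1, which is even and therefore consistent with coeff 2, base 0.

```python
# Values reported above: modular set of b1 is coeff 2, base 0
base, coeff = 0, 2
lanes, c2 = 8, 20
c1 = 1  # assumption: the lane variable steps by 1

# Assumed model of the two-ends check, evaluated on `base` only
ramp_min = base // c2
ramp_max = (base + (lanes - 1) * c1) // c2
print(ramp_min == ramp_max)  # True, so the rewrite fires

# A concrete even b1: threadIdx.x = 1, inner.outer.inner = 1 -> 50 + 8
b1 = 1 * 50 + 1 * 8
actual = [(b1 + i * c1) % c2 for i in range(lanes)]   # before the rewrite
hoisted = [b1 % c2 + i * c1 for i in range(lanes)]    # after the rewrite
print(actual)   # [18, 19, 0, 1, 2, 3, 4, 5]
print(hoisted)  # [18, 19, 20, 21, 22, 23, 24, 25]
```

Here coeff*lanes = 16 < 20, and the check passes on base 0 even though b1 = 58 leaves residue 18, so lanes 2 through 7 wrap around in the original expression but not in the rewritten one. Python's `%` matches floormod for a positive divisor.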
I can't provide a test case because it requires a modified version of vectorize_loop.cc, which I can't share yet.
cc @jcf94 @tqchen