-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add selective_scan compilable/exportable custom_op #651
base: main
Are you sure you want to change the base?
Conversation
@Hprairie Can you take a look at this? |
pytest -k
FAILED tests/ops/test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-128-itype0-wtype0-compiled] - AssertionError: expected size 2==2, stride 64==8192 at dim=0; expected size 4==4, stride 16==2048 at dim=1; expected size 1==128, stride 16==16 at dim=2
FAILED tests/ops/test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-256-itype0-wtype0-compiled] - AssertionError: expected size 2==2, stride 64==16384 at dim=0; expected size 4==4, stride 16==4096 at dim=1; expected size 1==256, stride 16==16 at dim=2
FAILED tests/ops/test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-512-itype0-wtype0-compiled] - AssertionError: expected size 2==2, stride 64==32768 at dim=0; expected size 4==4, stride 16==8192 at dim=1; expected size 1==512, stride 16==16 at dim=2
FAILED tests/ops/test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-1024-itype0-wtype0-compiled] - AssertionError: expected size 2==2, stride 64==65536 at dim=0; expected size 4==4, stride 16==16384 at dim=1; expected size 1==1024, stride 16==16 at dim=2
FAILED tests/ops/test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-2048-itype0-wtype0-compiled] - AssertionError: expected size 2==2, stride 64==131072 at dim=0; expected size 4==4, stride 16==32768 at dim=1; expected size 1==2048, stride 16==16 at dim=2
FAILED tests/ops/test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-4096-itype0-wtype0-compiled] - AssertionError: expected size 2==2, stride 128==262144 at dim=0; expected size 4==4, stride 32==65536 at dim=1; expected size 2==4096, stride 16==16 at dim=2
FAILED tests/ops/test_selective_scan.py::test_selective_scan[True-True-2-True-True-True-True-True-128-itype0-wtype0-compiled] - AssertionError: expected size 2==2, stride 64==8192 at dim=0; expected size 4==4, stride 16==2048 at dim=1; expected size 1==128, stride 16==16 at dim=2
FAILED tests/ops/test_selective_scan.py::test_selective_scan[True-True-2-True-True-True-True-True-256-itype0-wtype0-compiled] - AssertionError: expected size 2==2, stride 64==16384 at dim=0; expected size 4==4, stride 16==4096 at dim=1; expected size 1==256, stride 16==16 at dim=2
FAILED tests/ops/test_selective_scan.py::test_selective_scan[True-True-2-True-True-True-True-True-512-itype0-wtype0-compiled] - AssertionError: expected size 2==2, stride 64==32768 at dim=0; expected size 4==4, stride 16==8192 at dim=1; expected size 1==512, stride 16==16 at dim=2
FAILED tests/ops/test_selective_scan.py::test_selective_scan[True-True-2-True-True-True-True-True-1024-itype0-wtype0-compiled] - AssertionError: expected size 2==2, stride 64==65536 at dim=0; expected size 4==4, stride 16==16384 at dim=1; expected size 1==1024, stride 16==16 at dim=2
FAILED tests/ops/test_selective_scan.py::test_selective_scan[True-True-2-True-True-True-True-True-2048-itype0-wtype0-compiled] - AssertionError: expected size 2==2, stride 64==131072 at dim=0; expected size 4==4, stride 16==32768 at dim=1; expected size 1==2048, stride 16==16 at dim=2
FAILED tests/ops/test_selective_scan.py::test_selective_scan[True-True-2-True-True-True-True-True-4096-itype0-wtype0-compiled] - AssertionError: expected size 2==2, stride 128==262144 at dim=0; expected size 4==4, stride 32==65536 at dim=1; expected size 2==4096, stride 16==16 at dim=2 |
Yes I'll take a look later today 👍 |
|
||
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) | ||
has_z = z is not None | ||
final_out = rest[0].clone() if has_z else out.clone() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are you cloning the tensor right here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Without the extra clone we get (not in the test but on a real training session)
RuntimeError: selective_scan_fwd (with implementation in <module 'torch._library.custom_ops' from '/opt/conda/lib/python3.11/site-packages/torch/_library/custom_ops.py'>): The output of this custom operator (1) must not also be an input to this custom operator and (2) may not alias any inputs to this custom operator or other returns. The most common way to trigger this error is if we have y = custom_op(x) and y and x are the same Tensor. Please instead return a clone of the offending output tensor(s) (e.g. return x.clone()) or refactor the custom operator to not return y.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Huh that seems weird to me no? In the CPP code we are clearly creating a new tensor for out and out_z which are independent from any input tensor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is coming from here:
https://github.com/pytorch/pytorch/blob/main/torch/_library/utils.py#L349C5-L372
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
aliasing... other returns.
So I think that one candidate is that the same final_out
return is aliasing different buffers right?
Also, I am curious if you have used |
If you have a list of required opcheck sample inputs we could add an opcheck test to this test |
What do you think it is causing the failure of the compiled test at #651 (comment) ? |
1296f82
to
f9945b2
Compare
@Hprairie |
No description provided.