Forbid runtime broadcasting in Elemwise #372
Conversation
```python
out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)

def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
```
I think this could just use this function: https://github.com/pymc-devs/pytensor/blob/main/pytensor/tensor/extra_ops.py#L1465

The `make_node` method doesn't seem to properly take into account the `broadcastable` flag either, though; maybe that needs an update as well?
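For illustration, here is a minimal pure-Python sketch of the check-free shape combination being discussed (the function name and structure are hypothetical, not pytensor's actual helper): each output dim is taken from any input whose dim is not 1, with no runtime consistency checks emitted.

```python
def broadcast_shape_no_checks(*shapes):
    """Combine shapes, assuming they are broadcast-compatible (no checks)."""
    ndim = max(len(s) for s in shapes)
    # Left-pad with 1s so all shapes have the same rank, per NumPy rules.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        non_one = [d for d in dims if d != 1]
        # Assume validity: take any non-1 dim, or 1 if every dim is 1.
        out.append(non_one[0] if non_one else 1)
    return tuple(out)
```

The payoff is a simpler graph: no `Assert` nodes or dim comparisons, at the cost of undefined behavior when the assumption is violated.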
I didn't want to introduce checks or comparisons between shapes, which that function does. This allows it to return a more optimized graph, like Theano used to, by assuming no invalid shapes were provided.

The question then is whether we want to refactor that helper to do the same when `arrays_are_shapes=False`?
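For comparison, NumPy's own `np.broadcast_shapes` (available since NumPy 1.20) does perform the compatibility check that the comment above wants to avoid:

```python
import numpy as np

# np.broadcast_shapes applies full NumPy broadcasting rules,
# including the compatibility check between non-1 dims.
print(np.broadcast_shapes((3, 1), (1, 5)))  # (3, 5)

# Incompatible non-1 dims raise rather than being assumed valid.
try:
    np.broadcast_shapes((3, 4), (2, 4))
except ValueError as exc:
    print("incompatible:", exc)
```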
I think the `make_node` is correct insofar as it uses static shapes, and it's not possible to have `broadcastable=False` and `shape=1`.

That one still requires some thinking and would be tackled in a separate PR.
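A rough sketch of the static-shape rule being described, assuming the "no runtime broadcasting" semantics (hypothetical helper, not the actual `make_node` code): a dim is broadcastable only when it is statically 1, and unknown dims stay unknown.

```python
def static_out_dim(dims):
    """Combine one aligned dimension across inputs; None means unknown."""
    known = {d for d in dims if d is not None}
    non_one = {d for d in known if d != 1}
    if non_one:
        # Assume validity: all non-1 static dims agree, so take one of them.
        return non_one.pop()
    if dims and all(d == 1 for d in dims):
        # Output dim is statically 1 only when every input dim is statically 1.
        return 1
    # Some input dim is unknown and may be anything, so the output is unknown.
    return None
```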
> I didn't want to introduce checks or comparison between shapes, which that function does. This allows it to return a more optimized graph like Theano used to by assuming no invalid shapes were provided

So we allow undefined behavior in the shapes and in rewrites? I'm not sure I see that much downside to having that check here...

But at least I think we shouldn't have this logic in both places. Maybe the function should have a flag for whether it should return the shape with or without checks?
I am thinking we should add a `config.assume_shapes_correct` flag (defaulting to True) to toggle that behavior in both shape inference and rewrites that can return simplified cases.
Actually, that helper works differently in that it expects either shapes or arrays, but here we are combining information from both shapes and arrays, so it would require some refactoring. We don't want to simply pass `node.inputs`, since `infer_shape` wants us to return a graph from `ishapes`.
I don't know if that is the right place to implement this logic since it is a user facing function. WDYT?
Okay, I reverted to using the helper. Things are a bit weird in shape compilation because it will just use the static type shape of the node if that's available. Because the Elemwise `make_node` assumes valid shapes, the check introduced by `infer_shape` is only triggered when all dims are `None`.
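A toy sketch of the precedence just described (hypothetical, not the actual shape-compilation code): static dims from the node's type shadow the `infer_shape` graph, so the check it introduces survives only where the static dim is `None`.

```python
def compiled_shape(static_shape, inferred_shape_graph):
    """Prefer static type dims; fall back to the infer_shape graph per dim."""
    return tuple(
        static if static is not None else inferred
        for static, inferred in zip(static_shape, inferred_shape_graph)
    )
```

Here the second argument stands in for the per-dim symbolic expressions (with their checks) that `infer_shape` would return; any fully known dim short-circuits them.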
Not much we can do about that then I think without a major rewrite of the shape handling...
```python
@staticmethod
def check_runtime_broadcast_error(mode):
    """Check we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
```
I think I'd feel better if those tests were a bit more complete, i.e. inputs with different lengths, etc...
Do you mean different runtime shapes (3 vs 5)? I am sure there are old tests for that already.
There are tests for invalid static shapes.
This one test was added when we specifically allowed runtime broadcasting in Aesara. The other thing I considered doing was to just remove it.
I'll confirm other tests for invalid shapes exist and maybe combine with this if they are not too convoluted.
Added a test for incompatible non-broadcast shapes. Let me know if you meant something else.
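To illustrate what the tests distinguish, a NumPy-only sketch: NumPy silently broadcasts a runtime dim of 1 (the behavior this PR forbids in Elemwise), while incompatible non-1 dims already raise even in NumPy.

```python
import numpy as np

# Runtime broadcasting: a length-1 dim silently stretches to match.
a = np.zeros((1, 4))
b = np.zeros((5, 4))
print((a + b).shape)  # (5, 4)

# Incompatible non-broadcast shapes fail outright, in NumPy and pytensor alike.
try:
    np.zeros((3, 4)) + np.zeros((2, 4))
except ValueError as exc:
    print("error:", exc)
```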
```diff
@@ -35,15 +35,20 @@ def compute_itershape(
     with builder.if_then(
         builder.icmp_unsigned("!=", length, shape[i]), likely=False
     ):
-        with builder.if_else(builder.icmp_unsigned("==", length, one)) as (
+        with builder.if_else(
+            builder.or_(
```
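A pure-Python rendering of the condition the generated IR is meant to check (a sketch under my reading of the diff, not the actual llvmlite code): when an input dim disagrees with the iteration shape, distinguish a now-forbidden runtime broadcast from a plainly incompatible shape.

```python
def check_dim(length, expected):
    """Sketch of the per-dim check: error instead of broadcasting at runtime."""
    if length != expected:
        if length == 1:
            # Previously this branch would silently broadcast; now it errors.
            raise ValueError("Runtime broadcasting is not allowed in Elemwise")
        raise ValueError("Input shapes are incompatible")
```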
Weird, the changes cause a SegmentationFault on the BroadcastTo numba test, but only on Python 3.11? I couldn't replicate it locally on 3.8 either. https://github.com/pymc-devs/pytensor/actions/runs/5507826611/jobs/10039563156?pr=372

Did I do something obviously wrong @aseyboldt?
I don't see anything wrong. I can try locally with py311, and if I can reproduce it I can try to look at it in a debugger (with no debugging symbols, but well...)
If you can quickly try to reproduce it, that's already helpful (even if you don't dig down).
No luck so far, for me the tests run just fine...
It reliably segfaults here. I'll remove the numba changes for now and put the new test as an xfail.
Does it segfault during the `test_BroadcastTo` test?
Yes... `tests/link/numba/test_extra_ops.py::test_BroadcastTo[x0-shape0]`. https://github.com/pymc-devs/pytensor/actions/runs/5507826611/jobs/10039563156?pr=372#step:6:281
But I don't see how it could be a problem in those tests. There is nothing else in the compiled graph other than the `BroadcastTo`.
Codecov Report

```
@@            Coverage Diff             @@
##             main     #372      +/-   ##
==========================================
- Coverage   80.40%   80.40%   -0.01%
==========================================
  Files         156      156
  Lines       45401    45390      -11
  Branches    11106    11103       -3
==========================================
- Hits        36505    36496       -9
  Misses       6689     6689
+ Partials     2207     2205       -2
```
Looks good :-)
Related to #100
Related to #149
Related to #371