Skip to content

Commit

Permalink
Revert numba runtime broadcast check
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 11, 2023
1 parent 8a46960 commit 28b3b46
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
13 changes: 4 additions & 9 deletions pytensor/link/numba/dispatch/elemwise_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,15 @@ def compute_itershape(
with builder.if_then(
builder.icmp_unsigned("!=", length, shape[i]), likely=False
):
with builder.if_else(
builder.or_(
builder.icmp_unsigned("==", length, one),
builder.icmp_unsigned("==", shape[i], one),
)
) as (
with builder.if_else(builder.icmp_unsigned("==", length, one)) as (
then,
otherwise,
):
with then:
msg = (
"Runtime broadcasting not allowed. "
"One input had a distinct dimension length of 1, but was not marked as broadcastable.\n"
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
f"Incompatible shapes for input {j} and axis {i} of "
f"elemwise. Input {j} has shape 1, but is not statically "
"known to have shape 1, and thus not broadcastable."
)
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
with otherwise:
Expand Down
1 change: 1 addition & 0 deletions tests/link/numba/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
compare_numba_and_py(out_fg, input_vals)


@pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")
def test_elemwise_runtime_shape_error():
TestElemwise.check_runtime_shapes_error(get_mode("NUMBA"))

Expand Down

0 comments on commit 28b3b46

Please sign in to comment.