Skip to content

Commit

Permalink
Fix get_canonical_form_slice when lengths are numpy integers
Browse files Browse the repository at this point in the history
Introduced in f9dfe70
  • Loading branch information
ricardoV94 committed Jul 4, 2024
1 parent 781073b commit afc1a6c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pytensor/tensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def analyze(x):
and is_step_constant
and is_length_constant
):
assert isinstance(length, int)
assert isinstance(length, int | np.integer)
_start, _stop, _step = slice(start, stop, step).indices(length)
if _start <= _stop and _step >= 1:
return slice(_start, _stop, _step), 1
Expand Down
7 changes: 5 additions & 2 deletions tests/tensor/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,11 @@ def test_symbolic_tensor(self):
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
assert res[1] == 1

def test_all_integer(self):
res = get_canonical_form_slice(slice(1, 5, 2), 7)
@pytest.mark.parametrize("int_fn", [int, np.int64, as_tensor, as_scalar])
def test_all_integer(self, int_fn):
res = get_canonical_form_slice(
slice(int_fn(1), int_fn(5), int_fn(2)), int_fn(7)
)
assert isinstance(res[0], slice)
assert res[1] == 1

Expand Down

0 comments on commit afc1a6c

Please sign in to comment.