From afc1a6caf595cb1bf08c9cb1901e1b9e0bb98efc Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 4 Jul 2024 11:14:55 +0200 Subject: [PATCH] Fix get_canonical_form_slice when lengths are numpy integers Introduced in f9dfe702dac141be4f006e9bfa042be9ad64ce16 --- pytensor/tensor/subtensor.py | 2 +- tests/tensor/test_subtensor.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 2cec476c4a..59961e7c2f 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -325,7 +325,7 @@ def analyze(x): and is_step_constant and is_length_constant ): - assert isinstance(length, int) + assert isinstance(length, int | np.integer) _start, _stop, _step = slice(start, stop, step).indices(length) if _start <= _stop and _step >= 1: return slice(_start, _stop, _step), 1 diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index f4ba58e26a..9cb730aafd 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -154,8 +154,11 @@ def test_symbolic_tensor(self): assert isinstance(res[0].owner.op, ptb.ScalarFromTensor) assert res[1] == 1 - def test_all_integer(self): - res = get_canonical_form_slice(slice(1, 5, 2), 7) + @pytest.mark.parametrize("int_fn", [int, np.int64, as_tensor, as_scalar]) + def test_all_integer(self, int_fn): + res = get_canonical_form_slice( + slice(int_fn(1), int_fn(5), int_fn(2)), int_fn(7) + ) assert isinstance(res[0], slice) assert res[1] == 1