Skip to content

Commit

Permalink
fix(frontend-python): stop crashing on scalar squeeze
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Feb 2, 2024
1 parent 477a982 commit 966a1fa
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
3 changes: 3 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,9 @@ def squeeze(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversi
# if the output shape is (), it means (1, 1, ..., 1, 1) is squeezed
# and the result is a scalar, so we need to do indexing, not reshape
if node.output.shape == ():
if preds[0].shape == ():
return preds[0]

assert all(size == 1 for size in preds[0].shape)
index = (0,) * len(preds[0].shape)
return ctx.index_static(ctx.typeof(node), preds[0], index)
Expand Down
7 changes: 7 additions & 0 deletions frontends/concrete-python/tests/execution/test_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,13 @@ def copy_modify(x):
},
id="x ** 3",
),
pytest.param(
lambda x: np.squeeze(x),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": ()},
},
id="np.squeeze(x)",
),
pytest.param(
lambda x: np.squeeze(x),
{
Expand Down

0 comments on commit 966a1fa

Please sign in to comment.