Skip to content

Commit

Permalink
feat(frontend-python): fancy indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Dec 18, 2023
1 parent ddd85c4 commit e2b15e4
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 4 deletions.
1 change: 1 addition & 0 deletions frontends/concrete-python/.pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ disable=raw-checker-failed,
too-many-locals,
too-many-public-methods,
too-many-statements,
too-many-return-statements,
unnecessary-lambda-assignment,
use-implicit-booleaness-not-comparison,
wrong-import-order,
Expand Down
39 changes: 38 additions & 1 deletion frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2055,13 +2055,16 @@ def index_static(
self,
resulting_type: ConversionType,
x: Conversion,
index: Sequence[Union[int, np.integer, slice]],
index: Sequence[Union[int, np.integer, slice, np.ndarray, list]],
) -> Conversion:
assert self.is_bit_width_compatible(resulting_type, x)
assert resulting_type.is_encrypted == x.is_encrypted

x = self.to_signedness(x, of=resulting_type)

if any(isinstance(indexing_element, (list, np.ndarray)) for indexing_element in index):
return self.index_static_fancy(resulting_type, x, index)

index = list(index)
while len(index) < len(x.shape):
index.append(slice(None, None, None))
Expand Down Expand Up @@ -2182,6 +2185,40 @@ def index_static(
),
)

def index_static_fancy(
self,
resulting_type: ConversionType,
x: Conversion,
index: Sequence[Union[int, np.integer, slice, np.ndarray, list]],
) -> Conversion:
resulting_element_type = (self.eint if resulting_type.is_unsigned else self.esint)(
resulting_type.bit_width
)

result = self.zeros(resulting_type)
for destination_position in np.ndindex(resulting_type.shape):
source_position = []
for indexing_element in index:
if isinstance(indexing_element, (int, np.integer)):
source_position.append(indexing_element)

elif isinstance(indexing_element, (list, np.ndarray)):
position = indexing_element[destination_position[0]]
for n in range(1, len(destination_position)):
position = position[destination_position[n]]
source_position.append(position)

else: # pragma: no cover
message = f"invalid indexing element of type {type(indexing_element)}"
raise AssertionError(message)

source_position = tuple(source_position)

element = self.index_static(resulting_element_type, x, source_position)
result = self.assign_static(resulting_type, result, element, destination_position)

return result

def less(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion:
return self.comparison(resulting_type, x, y, accept={Comparison.LESS})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,9 @@ def format_indexing_element(indexing_element: Union[int, np.integer, slice, Any]
result += ":"
result += str(indexing_element.step)
else:
result += str(indexing_element)
result += (
str(indexing_element)
if not isinstance(indexing_element, np.ndarray)
else str(indexing_element.tolist())
)
return result.replace("\n", " ")
26 changes: 24 additions & 2 deletions frontends/concrete-python/concrete/fhe/tracing/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,13 @@ def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> "Tracer":
def __getitem__(
self,
index: Union[
int, np.integer, slice, "Tracer", Tuple[Union[int, np.integer, slice, "Tracer"], ...]
int,
np.integer,
slice,
np.ndarray,
list,
Tuple[Union[int, np.integer, slice, np.ndarray, list, "Tracer"], ...],
"Tracer",
],
) -> "Tracer":
if (
Expand All @@ -762,11 +768,27 @@ def __getitem__(
if not isinstance(index, tuple):
index = (index,)

is_fancy = False
has_slices = False

reject = False
for indexing_element in index:
if isinstance(indexing_element, list):
try:
indexing_element = np.array(indexing_element)
except Exception: # pylint: disable=broad-except
reject = True
break

if isinstance(indexing_element, np.ndarray):
is_fancy = True
reject = not np.issubdtype(indexing_element.dtype, np.integer)
continue

valid = isinstance(indexing_element, (int, np.integer, slice))

if isinstance(indexing_element, slice): # noqa: SIM102
has_slices = True
if (
not (
indexing_element.start is None
Expand All @@ -787,7 +809,7 @@ def __getitem__(
reject = True
break

if reject:
if reject or (is_fancy and has_slices):
indexing_elements = [
format_indexing_element(indexing_element) for indexing_element in index
]
Expand Down
57 changes: 57 additions & 0 deletions frontends/concrete-python/tests/execution/test_static_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,63 @@
lambda x: x[slice(np.int64(8), np.int64(2), np.int64(-2))],
id="x[8:2:-2] where x.shape == (10,)",
),
pytest.param(
(5,),
lambda x: x[[3, 1, 2]],
id="x[[3, 1, 2]] where x.shape == (5,)",
),
pytest.param(
(5,),
lambda x: x[
[
[3, 0],
[1, 2],
]
],
id="x[[[3, 0], [1, 2]]] where x.shape == (5,)",
),
pytest.param(
(5, 4),
lambda x: x[
[0, 0, 4, 4],
[0, 3, 0, 3],
],
id="x[[0, 0, 4, 4], [0, 3, 0, 3]] where x.shape == (5, 4)",
),
pytest.param(
(5, 4),
lambda x: x[
0,
[0, 3, 0, 3],
],
id="x[0, [0, 3, 0, 3]] where x.shape == (5, 4)",
),
pytest.param(
(5, 4),
lambda x: x[
[0, 0, 4, 4],
0,
],
id="x[[0, 0, 4, 4], 0] where x.shape == (5, 4)",
),
pytest.param(
(5, 4),
lambda x: x[
[[0, 0], [4, 4]],
[[0, 3], [0, 3]],
],
id="x[[[0, 0], [4, 4]], [[0, 3], [0, 3]]] where x.shape == (5, 4)",
),
pytest.param(
(5, 4),
lambda x: x[0, [[0, 3], [0, 3]]],
id="x[0, [[0, 3], [0, 3]]] where x.shape == (5, 4)",
),
pytest.param(
(5, 4),
lambda x: x[[[0, 3], [0, 3]], 0],
id="x[[[0, 3], [0, 3]], 0] where x.shape == (5, 4)",
),
],
)
def test_static_indexing(shape, function, helpers):
Expand Down

0 comments on commit e2b15e4

Please sign in to comment.