-
Notifications
You must be signed in to change notification settings - Fork 147
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(frontend): add support for np.min and np.max
- Loading branch information
1 parent
bea18cf
commit 7f6f90c
Showing
15 changed files
with
705 additions
and
274 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
210 changes: 210 additions & 0 deletions
210
frontends/concrete-python/concrete/fhe/mlir/operations/min_max.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
""" | ||
Conversion of min and max operations. | ||
""" | ||
|
||
from copy import deepcopy | ||
from typing import Sequence, Set, Tuple, Union | ||
|
||
import numpy as np | ||
|
||
from ..context import Context | ||
from ..conversion import Conversion, ConversionType | ||
|
||
|
||
def min_max(
    ctx: Context,
    resulting_type: ConversionType,
    x: Conversion,
    axes: Union[int, np.integer, Sequence[Union[int, np.integer]]] = (),
    keep_dims: bool = False,
    *,
    operation: str,
) -> Conversion:
    """
    Convert min or max operation.

    Args:
        ctx (Context):
            conversion context

        resulting_type (ConversionType):
            resulting type of the operation

        x (Conversion):
            input of the operation

        axes (Union[int, np.integer, Sequence[Union[int, np.integer]]], default = ()):
            axes to reduce over

        keep_dims (bool, default = False):
            whether to keep the reduced axes

        operation (str):
            "min" or "max"

    Returns:
        Conversion:
            np.min or np.max on x depending on operation
    """

    if x.is_clear:
        highlights = {
            x.origin: "value is clear",
            ctx.converting: f"but computing {operation} of clear values is not supported",
        }
        ctx.error(highlights)

    # min/max of a single scalar is the scalar itself
    if x.is_scalar:
        return x

    if axes is None:
        axes = []

    # normalize axes to a deduplicated list of ints
    # (duplicates surface later as a numpy error, matching np.min/np.max behavior)
    axes = list(
        {int(axes)} if isinstance(axes, (int, np.integer)) else {int(axis) for axis in axes}
    )

    # convert negative axes to their non-negative equivalents
    input_dimensions = len(x.shape)
    for i, axis in enumerate(axes):
        if axis < 0:
            axes[i] += input_dimensions

    # reducing over no axes (numpy semantics: all axes) or over every axis
    # is just a reduction of the flattened input
    if len(axes) == 0 or len(axes) == len(x.shape):
        x = ctx.flatten(x)

    resulting_element_type = ctx.element_typeof(resulting_type)
    if len(x.shape) == 1:
        result = reduce(ctx, resulting_element_type, x, operation=operation)
        if keep_dims:
            result = ctx.reshape(result, shape=resulting_type.shape)
        return result

    class Mock:
        # Tracks which input positions are folded into each output position
        # when the reduction is simulated over a mock input below.

        indices: Set[Tuple[int, ...]]

        def __init__(self, index: Tuple[int, ...]):
            self.indices = {index}

        def __repr__(self) -> str:
            return f"{self.indices}"

        def combine(self, other: "Mock") -> "Mock":
            # set union instead of deepcopy + per-element insertion:
            # the old form was O(len(indices)) per combine, making the
            # whole accumulation quadratic in the reduced-axis size
            result = Mock.__new__(Mock)
            result.indices = self.indices | other.indices
            return result

    # simulate the reduction on mock values to learn, for every output
    # position, the set of input positions that contribute to it
    mock_input = np.array([Mock(index) for index in np.ndindex(x.shape)]).reshape(x.shape)

    def accumulate(mock1, mock2):
        return mock1.combine(mock2)

    mock_output = np.frompyfunc(accumulate, 2, 1).reduce(
        mock_input,
        axis=tuple(axes),
        keepdims=keep_dims,
    )

    sample_mock = mock_output.flat[0]

    sample_mock_indices = sample_mock.indices
    sample_mock_index = next(iter(sample_mock_indices))

    # every output position folds the same number of input positions
    # for an axis reduction, so one sample is enough
    number_of_comparisons = len(sample_mock_indices)
    number_of_indices = len(sample_mock_index)
    index_shape = resulting_type.shape

    # to_compare[i] is a tuple of index arrays selecting, for each output
    # position, the i-th contributing input position
    to_compare = []
    for _ in range(number_of_comparisons):
        indices = []
        for _ in range(number_of_indices):
            index = np.zeros(index_shape, dtype=np.int64)  # type: ignore
            indices.append(index)
        to_compare.append(tuple(indices))

    for position in np.ndindex(mock_output.shape):
        mock_indices = list(mock_output[position].indices)
        for i in range(number_of_comparisons):
            for j in range(number_of_indices):
                to_compare[i][j][position] = mock_indices[i][j]  # type: ignore

    # fancy-index the input once per contributing position
    extracted = []
    extracted_type = ctx.tensor(ctx.element_typeof(x), shape=resulting_type.shape)

    for index in to_compare:  # type: ignore
        extracted.append(ctx.index(extracted_type, x, index))

    # pairwise tree reduction: pop two, compare, requeue the result
    while len(extracted) > 1:
        a = extracted.pop()
        b = extracted.pop()

        if operation == "min":
            c = ctx.minimum(resulting_type, a, b)
        else:
            c = ctx.maximum(resulting_type, a, b)

        c.set_original_bit_width(x.original_bit_width)
        extracted.insert(0, c)

    return extracted[0]
|
||
|
||
def reduce(
    ctx: Context,
    resulting_type: ConversionType,
    values: Conversion,
    *,
    operation: str,
) -> Conversion:
    """
    Reduce a one-dimensional encrypted vector to its min/max value.

    Splits the vector in two halves, compares them element-wise, and recurses
    on the halved result until a single scalar remains. An odd leftover
    element joins the final comparison at the end.
    """

    assert operation in {"min", "max"}

    assert values.is_tensor
    assert len(values.shape) == 1
    assert values.is_encrypted

    assert resulting_type.is_scalar
    assert resulting_type.is_encrypted

    # NOTE(review): a length-1 input yields half == 0 and empty slices below;
    # callers appear to never pass one — confirm upstream guarantees this.
    size = values.shape[0]
    half = size // 2
    element_type = ctx.element_typeof(values)

    # when the halves are single elements, work with scalar types directly
    half_type = element_type if half == 1 else ctx.tensor(element_type, shape=(half,))
    halved_type = resulting_type if half == 1 else ctx.tensor(resulting_type, shape=(half,))

    if half == 1:
        lhs = ctx.index(half_type, values, index=[0])
        rhs = ctx.index(half_type, values, index=[1])
    else:
        lhs = ctx.index(half_type, values, index=[slice(0, half)])
        rhs = ctx.index(half_type, values, index=[slice(half, 2 * half)])

    # dispatch once instead of branching at every comparison site
    compare = ctx.minimum if operation == "min" else ctx.maximum

    halved = compare(halved_type, lhs, rhs)
    halved.set_original_bit_width(values.original_bit_width)

    if halved.is_scalar:
        result = halved
    else:
        result = reduce(ctx, resulting_type, halved, operation=operation)

    # odd length: fold the leftover last element into the final result
    if size % 2 == 1:
        leftover = ctx.index(element_type, values, index=[-1])
        result = compare(resulting_type, result, leftover)
        result.set_original_bit_width(values.original_bit_width)

    return result
Oops, something went wrong.