Skip to content

Commit

Permalink
feat(frontend): add support for np.min and np.max
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Aug 28, 2024
1 parent bea18cf commit 7f6f90c
Show file tree
Hide file tree
Showing 15 changed files with 705 additions and 274 deletions.
24 changes: 0 additions & 24 deletions docs/core-features/workarounds.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,6 @@ This document introduces several common techniques for optimizing code to fit Fu
All code snippets provided here are temporary workarounds. In future versions of Concrete, some functions described here could be directly available in a more generic and efficient form. These code snippets are coming from support answers in our [community forum](https://community.zama.ai)
{% endhint %}

## Minimum/Maximum for multiple values

Concrete supports `np.minimum`/`np.maximum` natively, but not `np.min`/`np.max` yet. To achieve the functionality, you can do a series of `np.minimum`/`np.maximum`s:

```python
import numpy as np
from concrete import fhe

@fhe.compiler({"args": "encrypted"})
def fhe_min(args):
remaining = list(args)
while len(remaining) > 1:
a = remaining.pop()
b = remaining.pop()
remaining.insert(0, np.minimum(a, b))
return remaining[0]

inputset = [np.random.randint(0, 16, size=5) for _ in range(50)]
circuit = fhe_min.compile(inputset, min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED)

x1, x2, x3, x4, x5 = np.random.randint(0, 16, size=5)
assert circuit.encrypt_run_decrypt([x1, x2, x3, x4, x5]) == min(x1, x2, x3, x4, x5)
```

## Retrieving a value within an encrypted array with an encrypted index
This example demonstrates how to retrieve a value from an array using an encrypted index. The method creates a "selection" array filled with `0`s except for the requested index, which will be `1`. It then sums the products of all array values with this selection array:

Expand Down
27 changes: 25 additions & 2 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,7 +2197,11 @@ def constant(
if cached_conversion is None:
cached_conversion = Conversion(
self.converting,
arith.ConstantOp(resulting_type, attribute, loc=self.location()),
arith.ConstantOp( # pylint: disable=too-many-function-args
resulting_type,
attribute,
loc=self.location(),
),
)

try:
Expand Down Expand Up @@ -2813,6 +2817,25 @@ def maxpool2d(

return self.to_signedness(result, of=resulting_type)

def min_max(
self,
resulting_type: ConversionType,
x: Conversion,
axes: Union[int, np.integer, Sequence[Union[int, np.integer]]] = (),
keep_dims: bool = False,
*,
operation: str,
):
# This import needs to happen here to avoid circular imports.

# pylint: disable=import-outside-toplevel

from .operations.min_max import min_max

return min_max(self, resulting_type, x, axes, keep_dims, operation=operation)

# pylint: enable=import-outside-toplevel

def minimum(
self,
resulting_type: ConversionType,
Expand Down Expand Up @@ -3987,7 +4010,7 @@ def zeros(self, resulting_type: ConversionType) -> Conversion:
)

def get_partition_name(self, partition: tfhers.TFHERSParams) -> str:
if partition not in self.tfhers_partition.keys():
if partition not in self.tfhers_partition:
self.tfhers_partition[partition] = f"tfhers_{randint(0, 2**32)}" # noqa: S311
return self.tfhers_partition[partition]

Expand Down
4 changes: 2 additions & 2 deletions frontends/concrete-python/concrete/fhe/mlir/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Declaration of `ConversionType` and `Conversion` classes.
"""

# pylint: disable=import-error,
# pylint: disable=import-error,no-name-in-module

import re
from typing import Optional, Tuple
Expand All @@ -12,7 +12,7 @@

from ..representation import Node

# pylint: enable=import-error
# pylint: enable=import-error,no-name-in-module


SCALAR_INT_SEARCH_REGEX = re.compile(r"^i([0-9]+)$")
Expand Down
28 changes: 28 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,20 @@ def matmul(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversio
assert len(preds) == 2
return ctx.matmul(ctx.typeof(node), preds[0], preds[1])

def max(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1

if all(pred.is_encrypted for pred in preds):
return ctx.min_max(
ctx.typeof(node),
preds[0],
axes=node.properties["kwargs"].get("axis", ()),
keep_dims=node.properties["kwargs"].get("keepdims", False),
operation="max",
)

return self.tlu(ctx, node, preds)

def maximum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2

Expand All @@ -556,6 +570,20 @@ def maxpool3d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conver
ctx.error({node: "3-dimensional maxpooling is not supported at the moment"})
assert False, "unreachable" # pragma: no cover

def min(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1

if all(pred.is_encrypted for pred in preds):
return ctx.min_max(
ctx.typeof(node),
preds[0],
axes=node.properties["kwargs"].get("axis", ()),
keep_dims=node.properties["kwargs"].get("keepdims", False),
operation="min",
)

return self.tlu(ctx, node, preds)

def minimum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ def fancy_indexing(
resulting_type,
x.result,
indices.result,
original_bit_width=x.original_bit_width,
)


Expand Down Expand Up @@ -640,6 +641,7 @@ def indexing(
MlirDenseI64ArrayAttr.get(static_offsets),
MlirDenseI64ArrayAttr.get(static_sizes),
MlirDenseI64ArrayAttr.get(static_strides),
original_bit_width=x.original_bit_width,
)

reassociaton = []
Expand Down Expand Up @@ -669,4 +671,5 @@ def indexing(
for indices in reassociaton
],
),
original_bit_width=x.original_bit_width,
)
210 changes: 210 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/operations/min_max.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
"""
Conversion of min and max operations.
"""

from copy import deepcopy
from typing import Sequence, Set, Tuple, Union

import numpy as np

from ..context import Context
from ..conversion import Conversion, ConversionType


def min_max(
ctx: Context,
resulting_type: ConversionType,
x: Conversion,
axes: Union[int, np.integer, Sequence[Union[int, np.integer]]] = (),
keep_dims: bool = False,
*,
operation: str,
) -> Conversion:
"""
Convert min or max operation.
Args:
ctx (Context):
conversion context
resulting_type (ConversionType):
resulting type of the operation
x (Conversion):
input of the operation
axes (Union[int, np.integer, Sequence[Union[int, np.integer]]], default = ()):
axes to reduce over
keep_dims (bool, default = False):
whether to keep the reduced axes
operation (str):
"min" or "max"
Returns:
Conversion:
np.min or np.max on x depending on operation
"""

if x.is_clear:
highlights = {
x.origin: "value is clear",
ctx.converting: f"but computing {operation} of clear values is not supported",
}
ctx.error(highlights)

if x.is_scalar:
return x

if axes is None:
axes = []

axes = list(
set([int(axes)] if isinstance(axes, (int, np.integer)) else [int(axis) for axis in axes])
)

input_dimensions = len(x.shape)
for i, axis in enumerate(axes):
if axis < 0:
axes[i] += input_dimensions

if len(axes) == 0 or len(axes) == len(x.shape):
x = ctx.flatten(x)

resulting_element_type = ctx.element_typeof(resulting_type)
if len(x.shape) == 1:
result = reduce(ctx, resulting_element_type, x, operation=operation)
if keep_dims:
result = ctx.reshape(result, shape=resulting_type.shape)
return result

class Mock:
indices: Set[Tuple[int, ...]]

def __init__(self, index: Tuple[int, ...]):
self.indices = {index}

def __repr__(self) -> str:
return f"{self.indices}"

def combine(self, other: "Mock") -> "Mock":
result = deepcopy(self)
for index in other.indices:
result.indices.add(index)
return result

mock_input = []
for index in np.ndindex(x.shape):
mock_input.append(Mock(index))
mock_input = np.array(mock_input).reshape(x.shape)

def accumulate(mock1, mock2):
return mock1.combine(mock2)

mock_output = np.frompyfunc(accumulate, 2, 1).reduce(
mock_input,
axis=tuple(axes),
keepdims=keep_dims,
)

sample_mock = mock_output.flat[0]

sample_mock_indices = sample_mock.indices
sample_mock_index = next(iter(sample_mock_indices))

number_of_comparisons = len(sample_mock_indices)
number_of_indices = len(sample_mock_index)
index_shape = resulting_type.shape

to_compare = []
for _ in range(number_of_comparisons):
indices = []
for _ in range(number_of_indices):
index = np.zeros(index_shape, dtype=np.int64) # type: ignore
indices.append(index)
to_compare.append(tuple(indices))

for position in np.ndindex(mock_output.shape):
mock_indices = list(mock_output[position].indices)
for i in range(number_of_comparisons):
for j in range(number_of_indices):
to_compare[i][j][position] = mock_indices[i][j] # type: ignore

extracted = []
extracted_type = ctx.tensor(ctx.element_typeof(x), shape=resulting_type.shape)

for index in to_compare: # type: ignore
extracted.append(ctx.index(extracted_type, x, index))

while len(extracted) > 1:
a = extracted.pop()
b = extracted.pop()

if operation == "min":
c = ctx.minimum(resulting_type, a, b)
else:
c = ctx.maximum(resulting_type, a, b)

c.set_original_bit_width(x.original_bit_width)
extracted.insert(0, c)

return extracted[0]


def reduce(
ctx: Context,
resulting_type: ConversionType,
values: Conversion,
*,
operation: str,
) -> Conversion:
"""
Reduce a vector of values to its min/max value.
"""

assert operation in {"min", "max"}

assert values.is_tensor
assert len(values.shape) == 1
assert values.is_encrypted

assert resulting_type.is_scalar
assert resulting_type.is_encrypted

middle = values.shape[0] // 2
values_element_type = ctx.element_typeof(values)

half_type = (
ctx.tensor(values_element_type, shape=(middle,)) if middle != 1 else values_element_type
)
reduced_type = ctx.tensor(resulting_type, shape=(middle,)) if middle != 1 else resulting_type

if middle == 1:
first_half = ctx.index(half_type, values, index=[0])
second_half = ctx.index(half_type, values, index=[1])
else:
first_half = ctx.index(half_type, values, index=[slice(0, middle)])
second_half = ctx.index(half_type, values, index=[slice(middle, 2 * middle)])

if operation == "min":
reduced = ctx.minimum(reduced_type, first_half, second_half)
else:
reduced = ctx.maximum(reduced_type, first_half, second_half)

reduced.set_original_bit_width(values.original_bit_width)

result = (
reduced if reduced.is_scalar else reduce(ctx, resulting_type, reduced, operation=operation)
)

if values.shape[0] % 2 == 1:
last_value = ctx.index(values_element_type, values, index=[-1])
result = (
ctx.minimum(resulting_type, result, last_value)
if operation == "min"
else ctx.maximum(resulting_type, result, last_value)
)
result.set_original_bit_width(values.original_bit_width)

return result
Loading

0 comments on commit 7f6f90c

Please sign in to comment.