Skip to content

Commit

Permalink
feat(frontend): add support for np.min and np.max
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Sep 2, 2024
1 parent 5716912 commit 4b922d3
Show file tree
Hide file tree
Showing 16 changed files with 947 additions and 274 deletions.
24 changes: 0 additions & 24 deletions docs/core-features/workarounds.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,6 @@ This document introduces several common techniques for optimizing code to fit Fu
All code snippets provided here are temporary workarounds. In future versions of Concrete, some functions described here could be directly available in a more generic and efficient form. These code snippets are coming from support answers in our [community forum](https://community.zama.ai)
{% endhint %}

## Minimum/Maximum for multiple values

Concrete supports `np.minimum`/`np.maximum` natively, but not `np.min`/`np.max` yet. To achieve the functionality, you can do a series of `np.minimum`/`np.maximum`s:

```python
import numpy as np
from concrete import fhe

@fhe.compiler({"args": "encrypted"})
def fhe_min(args):
remaining = list(args)
while len(remaining) > 1:
a = remaining.pop()
b = remaining.pop()
remaining.insert(0, np.minimum(a, b))
return remaining[0]

inputset = [np.random.randint(0, 16, size=5) for _ in range(50)]
circuit = fhe_min.compile(inputset, min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED)

x1, x2, x3, x4, x5 = np.random.randint(0, 16, size=5)
assert circuit.encrypt_run_decrypt([x1, x2, x3, x4, x5]) == min(x1, x2, x3, x4, x5)
```

## Retrieving a value within an encrypted array with an encrypted index
This example demonstrates how to retrieve a value from an array using an encrypted index. The method creates a "selection" array filled with `0`s except for the requested index, which will be `1`. It then sums the products of all array values with this selection array:

Expand Down
2 changes: 2 additions & 0 deletions docs/dev/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ Some operations are not supported between two encrypted values. If attempted, a
* [np.logical\_or](https://numpy.org/doc/stable/reference/generated/numpy.logical\_or.html)
* [np.logical\_xor](https://numpy.org/doc/stable/reference/generated/numpy.logical\_xor.html)
* [np.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html)
* [np.max](https://numpy.org/doc/stable/reference/generated/numpy.max.html)
* [np.maximum](https://numpy.org/doc/stable/reference/generated/numpy.maximum.html)
* [np.min](https://numpy.org/doc/stable/reference/generated/numpy.min.html)
* [np.minimum](https://numpy.org/doc/stable/reference/generated/numpy.minimum.html)
* [np.multiply](https://numpy.org/doc/stable/reference/generated/numpy.multiply.html)
* [np.negative](https://numpy.org/doc/stable/reference/generated/numpy.negative.html)
Expand Down
27 changes: 25 additions & 2 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,7 +2197,11 @@ def constant(
if cached_conversion is None:
cached_conversion = Conversion(
self.converting,
arith.ConstantOp(resulting_type, attribute, loc=self.location()),
arith.ConstantOp( # pylint: disable=too-many-function-args
resulting_type,
attribute,
loc=self.location(),
),
)

try:
Expand Down Expand Up @@ -2813,6 +2817,25 @@ def maxpool2d(

return self.to_signedness(result, of=resulting_type)

def min_max(
self,
resulting_type: ConversionType,
x: Conversion,
axes: Union[int, np.integer, Sequence[Union[int, np.integer]]] = (),
keep_dims: bool = False,
*,
operation: str,
):
# This import needs to happen here to avoid circular imports.

# pylint: disable=import-outside-toplevel

from .operations.min_max import min_max

return min_max(self, resulting_type, x, axes, keep_dims, operation=operation)

# pylint: enable=import-outside-toplevel

def minimum(
self,
resulting_type: ConversionType,
Expand Down Expand Up @@ -3988,7 +4011,7 @@ def zeros(self, resulting_type: ConversionType) -> Conversion:

def get_partition_name(self, partition: tfhers.TFHERSParams) -> str:
if partition not in self.tfhers_partition.keys():
self.tfhers_partition[partition] = f"tfhers_{randint(0, 2**32)}" # noqa: S311
self.tfhers_partition[partition] = f"tfhers_{randint(0, 2 ** 32)}" # noqa: S311
return self.tfhers_partition[partition]

def change_partition(
Expand Down
4 changes: 2 additions & 2 deletions frontends/concrete-python/concrete/fhe/mlir/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Declaration of `ConversionType` and `Conversion` classes.
"""

# pylint: disable=import-error,
# pylint: disable=import-error,no-name-in-module

import re
from typing import Optional, Tuple
Expand All @@ -12,7 +12,7 @@

from ..representation import Node

# pylint: enable=import-error
# pylint: enable=import-error,no-name-in-module


SCALAR_INT_SEARCH_REGEX = re.compile(r"^i([0-9]+)$")
Expand Down
28 changes: 28 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,20 @@ def matmul(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversio
assert len(preds) == 2
return ctx.matmul(ctx.typeof(node), preds[0], preds[1])

def max(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1

if all(pred.is_encrypted for pred in preds):
return ctx.min_max(
ctx.typeof(node),
preds[0],
axes=node.properties["kwargs"].get("axis", ()),
keep_dims=node.properties["kwargs"].get("keepdims", False),
operation="max",
)

return self.tlu(ctx, node, preds)

def maximum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2

Expand All @@ -556,6 +570,20 @@ def maxpool3d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conver
ctx.error({node: "3-dimensional maxpooling is not supported at the moment"})
assert False, "unreachable" # pragma: no cover

def min(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1

if all(pred.is_encrypted for pred in preds):
return ctx.min_max(
ctx.typeof(node),
preds[0],
axes=node.properties["kwargs"].get("axis", ()),
keep_dims=node.properties["kwargs"].get("keepdims", False),
operation="min",
)

return self.tlu(ctx, node, preds)

def minimum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ def fancy_indexing(
resulting_type,
x.result,
indices.result,
original_bit_width=x.original_bit_width,
)


Expand Down Expand Up @@ -640,6 +641,7 @@ def indexing(
MlirDenseI64ArrayAttr.get(static_offsets),
MlirDenseI64ArrayAttr.get(static_sizes),
MlirDenseI64ArrayAttr.get(static_strides),
original_bit_width=x.original_bit_width,
)

reassociaton = []
Expand Down Expand Up @@ -669,4 +671,5 @@ def indexing(
for indices in reassociaton
],
),
original_bit_width=x.original_bit_width,
)
Loading

0 comments on commit 4b922d3

Please sign in to comment.