Skip to content

Commit

Permalink
feat(frontend): add support for np.min and np.max
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Sep 5, 2024
1 parent 532000f commit 3e44083
Show file tree
Hide file tree
Showing 18 changed files with 1,019 additions and 321 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/concrete_python_test_macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,11 @@ jobs:
find .testenv/lib/python3.10/site-packages -not \( -path .testenv/lib/python3.10/site-packages/concrete -prune \) -name 'lib*omp5.dylib' -or -name 'lib*omp.dylib' | xargs -n 1 ln -f -s $(pwd)/.testenv/lib/python3.10/site-packages/concrete/.dylibs/libomp.dylib
cp -R $GITHUB_WORKSPACE/frontends/concrete-python/examples ./examples
cp -R $GITHUB_WORKSPACE/frontends/concrete-python/tests ./tests
cp $GITHUB_WORKSPACE/frontends/concrete-python/Makefile .
KEY_CACHE_DIRECTORY=./KeySetCache PYTEST_MARKERS="not dataflow and not graphviz" make pytest
KEY_CACHE_DIRECTORY=./KeySetCache PYTEST_MARKERS="not dataflow and not graphviz" make pytest-macos
- name: Cleanup host
if: success() || failure()
Expand Down
24 changes: 0 additions & 24 deletions docs/core-features/workarounds.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,6 @@ This document introduces several common techniques for optimizing code to fit Fu
All code snippets provided here are temporary workarounds. In future versions of Concrete, some functions described here could be directly available in a more generic and efficient form. These code snippets are coming from support answers in our [community forum](https://community.zama.ai)
{% endhint %}

## Minimum/Maximum for multiple values

Concrete supports `np.minimum`/`np.maximum` natively, but not `np.min`/`np.max` yet. To achieve the functionality, you can do a series of `np.minimum`/`np.maximum`s:

```python
import numpy as np
from concrete import fhe

@fhe.compiler({"args": "encrypted"})
def fhe_min(args):
remaining = list(args)
while len(remaining) > 1:
a = remaining.pop()
b = remaining.pop()
remaining.insert(0, np.minimum(a, b))
return remaining[0]

inputset = [np.random.randint(0, 16, size=5) for _ in range(50)]
circuit = fhe_min.compile(inputset, min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED)

x1, x2, x3, x4, x5 = np.random.randint(0, 16, size=5)
assert circuit.encrypt_run_decrypt([x1, x2, x3, x4, x5]) == min(x1, x2, x3, x4, x5)
```

## Retrieving a value within an encrypted array with an encrypted index
This example demonstrates how to retrieve a value from an array using an encrypted index. The method creates a "selection" array filled with `0`s except for the requested index, which will be `1`. It then sums the products of all array values with this selection array:

Expand Down
2 changes: 2 additions & 0 deletions docs/dev/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ Some operations are not supported between two encrypted values. If attempted, a
* [np.logical\_or](https://numpy.org/doc/stable/reference/generated/numpy.logical\_or.html)
* [np.logical\_xor](https://numpy.org/doc/stable/reference/generated/numpy.logical\_xor.html)
* [np.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html)
* [np.max](https://numpy.org/doc/stable/reference/generated/numpy.max.html)
* [np.maximum](https://numpy.org/doc/stable/reference/generated/numpy.maximum.html)
* [np.min](https://numpy.org/doc/stable/reference/generated/numpy.min.html)
* [np.minimum](https://numpy.org/doc/stable/reference/generated/numpy.minimum.html)
* [np.multiply](https://numpy.org/doc/stable/reference/generated/numpy.multiply.html)
* [np.negative](https://numpy.org/doc/stable/reference/generated/numpy.negative.html)
Expand Down
2 changes: 1 addition & 1 deletion frontends/concrete-python/.ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ ignore = [
"**/__init__.py" = ["F401"]
"concrete/fhe/compilation/configuration.py" = ["ARG002"]
"concrete/fhe/mlir/processors/all.py" = ["F401"]
"concrete/fhe/mlir/processors/assign_bit_widths.py" = ["ARG002"]
"concrete/fhe/mlir/processors/assign_bit_widths.py" = ["ARG002", "RUF012"]
"concrete/fhe/mlir/converter.py" = ["ARG002", "B011", "F403", "F405"]
"concrete/**" = ["RUF010"]
"examples/**" = ["PLR2004", "RUF010"]
Expand Down
17 changes: 11 additions & 6 deletions frontends/concrete-python/Makefile
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
PYTHON=python
PIP=$(PYTHON) -m pip

COMPILER_BUILD_DIRECTORY ?= $(PWD)/../../compilers/concrete-compiler/compiler/build
BINDINGS_DIRECTORY=${COMPILER_BUILD_DIRECTORY}/tools/concretelang/python_packages/concretelang_core/
TFHERS_UTILS_DIRECTORY ?= $(PWD)/tests/tfhers-utils/

OS=undefined
COVERAGE_OPT=""
ifeq ($(shell uname), Linux)
OS=linux
COVERAGE_OPT="--cov=concrete.fhe --cov-fail-under=100 --cov-report=term-missing:skip-covered"
RUNTIME_LIBRARY?=${COMPILER_BUILD_DIRECTORY}/lib/libConcretelangRuntime.so
else ifeq ($(shell uname), Darwin)
OS=darwin
RUNTIME_LIBRARY?=${COMPILER_BUILD_DIRECTORY}/lib/libConcretelangRuntime.dylib
endif


COMPILER_BUILD_DIRECTORY ?= $(PWD)/../../compilers/concrete-compiler/compiler/build
BINDINGS_DIRECTORY=${COMPILER_BUILD_DIRECTORY}/tools/concretelang/python_packages/concretelang_core/
RUNTIME_LIBRARY?=${COMPILER_BUILD_DIRECTORY}/lib/libConcretelangRuntime.so
TFHERS_UTILS_DIRECTORY ?= $(PWD)/tests/tfhers-utils/

CONCRETE_VERSION?="" # empty mean latest
# E.g. to use a previous version: `make CONCRETE_VERSION="<2.7.0" venv`
# E.g. to use a nightly: `make CONCRETE_VERSION="==2.7.0dev20240801`
Expand Down Expand Up @@ -76,6 +76,11 @@ pytest-default: tfhers-utils
--key-cache "${KEY_CACHE_DIRECTORY}" \
-m "${PYTEST_MARKERS}"

pytest-macos:
pytest tests -svv -n auto \
--key-cache "${KEY_CACHE_DIRECTORY}" \
-m "${PYTEST_MARKERS}"

pytest-single: tfhers-utils
eval $(shell make silent_cp_activate)
# test single precision, mono params
Expand Down
27 changes: 25 additions & 2 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,7 +2197,11 @@ def constant(
if cached_conversion is None:
cached_conversion = Conversion(
self.converting,
arith.ConstantOp(resulting_type, attribute, loc=self.location()),
arith.ConstantOp( # pylint: disable=too-many-function-args
resulting_type,
attribute,
loc=self.location(),
),
)

try:
Expand Down Expand Up @@ -2813,6 +2817,25 @@ def maxpool2d(

return self.to_signedness(result, of=resulting_type)

def min_max(
self,
resulting_type: ConversionType,
x: Conversion,
axes: Union[int, np.integer, Sequence[Union[int, np.integer]]] = (),
keep_dims: bool = False,
*,
operation: str,
):
# This import needs to happen here to avoid circular imports.

# pylint: disable=import-outside-toplevel

from .operations.min_max import min_max

return min_max(self, resulting_type, x, axes, keep_dims, operation=operation)

# pylint: enable=import-outside-toplevel

def minimum(
self,
resulting_type: ConversionType,
Expand Down Expand Up @@ -3988,7 +4011,7 @@ def zeros(self, resulting_type: ConversionType) -> Conversion:

def get_partition_name(self, partition: tfhers.CryptoParams) -> str:
if partition not in self.tfhers_partition.keys():
self.tfhers_partition[partition] = f"tfhers_{randint(0, 2**32)}" # noqa: S311
self.tfhers_partition[partition] = f"tfhers_{randint(0, 2 ** 32)}" # noqa: S311
return self.tfhers_partition[partition]

def change_partition(
Expand Down
4 changes: 2 additions & 2 deletions frontends/concrete-python/concrete/fhe/mlir/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Declaration of `ConversionType` and `Conversion` classes.
"""

# pylint: disable=import-error,
# pylint: disable=import-error,no-name-in-module

import re
from typing import Optional, Tuple
Expand All @@ -12,7 +12,7 @@

from ..representation import Node

# pylint: enable=import-error
# pylint: enable=import-error,no-name-in-module


SCALAR_INT_SEARCH_REGEX = re.compile(r"^i([0-9]+)$")
Expand Down
34 changes: 34 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ def add(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2
return ctx.add(ctx.typeof(node), preds[0], preds[1])

def amax(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
return self.max(ctx, node, preds)

def amin(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
return self.min(ctx, node, preds)

def array(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) > 0
return ctx.array(ctx.typeof(node), elements=preds)
Expand Down Expand Up @@ -530,6 +536,20 @@ def matmul(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversio
assert len(preds) == 2
return ctx.matmul(ctx.typeof(node), preds[0], preds[1])

def max(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1

if all(pred.is_encrypted for pred in preds):
return ctx.min_max(
ctx.typeof(node),
preds[0],
axes=node.properties["kwargs"].get("axis", ()),
keep_dims=node.properties["kwargs"].get("keepdims", False),
operation="max",
)

return self.tlu(ctx, node, preds)

def maximum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2

Expand All @@ -556,6 +576,20 @@ def maxpool3d(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conver
ctx.error({node: "3-dimensional maxpooling is not supported at the moment"})
assert False, "unreachable" # pragma: no cover

def min(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1

if all(pred.is_encrypted for pred in preds):
return ctx.min_max(
ctx.typeof(node),
preds[0],
axes=node.properties["kwargs"].get("axis", ()),
keep_dims=node.properties["kwargs"].get("keepdims", False),
operation="min",
)

return self.tlu(ctx, node, preds)

def minimum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ def fancy_indexing(
resulting_type,
x.result,
indices.result,
original_bit_width=x.original_bit_width,
)


Expand Down Expand Up @@ -640,6 +641,7 @@ def indexing(
MlirDenseI64ArrayAttr.get(static_offsets),
MlirDenseI64ArrayAttr.get(static_sizes),
MlirDenseI64ArrayAttr.get(static_strides),
original_bit_width=x.original_bit_width,
)

reassociaton = []
Expand Down Expand Up @@ -669,4 +671,5 @@ def indexing(
for indices in reassociaton
],
),
original_bit_width=x.original_bit_width,
)
Loading

0 comments on commit 3e44083

Please sign in to comment.