diff --git a/docs/tutorial/extensions.md b/docs/tutorial/extensions.md index 00ca5c1d0f..43ae05a99f 100644 --- a/docs/tutorial/extensions.md +++ b/docs/tutorial/extensions.md @@ -286,3 +286,21 @@ return %5 ``` regardless of the bounds. + +Alternatively, you can use it to make sure a value can store certain integers: + +```python +@fhe.compiler({"x": "encrypted", "y": "encrypted"}) +def is_vectors_same(x, y): + assert x.ndim != 1 + assert y.ndim != 1 + + assert len(x) == len(y) + n = len(x) + + number_of_same_elements = np.sum(x == y) + fhe.hint(number_of_same_elements, can_store=n) # hint that number of same elements can go up to n + is_same = number_of_same_elements == n + + return is_same +``` diff --git a/frontends/concrete-python/concrete/fhe/extensions/hint.py b/frontends/concrete-python/concrete/fhe/extensions/hint.py index c447d95062..8609650ee0 100644 --- a/frontends/concrete-python/concrete/fhe/extensions/hint.py +++ b/frontends/concrete-python/concrete/fhe/extensions/hint.py @@ -4,10 +4,16 @@ from typing import Any, Optional, Union +from ..dtypes import Integer from ..tracing import Tracer -def hint(x: Union[Tracer, Any], *, bit_width: Optional[int] = None) -> Union[Tracer, Any]: +def hint( + x: Union[Tracer, Any], + *, + bit_width: Optional[int] = None, + can_store: Optional[Any] = None, +) -> Union[Tracer, Any]: """ Hint the compilation process about properties of a value. @@ -24,6 +30,9 @@ def hint(x: Union[Tracer, Any], *, bit_width: Optional[int] = None) -> Union[Tra bit_width (Optional[int], default = None): hint about bit width + can_store (Optional[Any], default = None): + hint that the value needs to be able to store the given value + Returns: Union[Tracer, Any]: hinted value @@ -32,7 +41,15 @@ def hint(x: Union[Tracer, Any], *, bit_width: Optional[int] = None) -> Union[Tra if not isinstance(x, Tracer): # pragma: no cover return x + bit_width_hint = 0 + if bit_width is not None: - x.computation.properties["bit_width_hint"] = bit_width + bit_width_hint = max(bit_width_hint, bit_width) + + if can_store is not None: + bit_width_hint = max(bit_width_hint, Integer.that_can_represent(can_store).bit_width) + + if bit_width_hint > 0: + x.computation.properties["bit_width_hint"] = bit_width_hint return x diff --git a/frontends/concrete-python/examples/key-value-database/dynamic-size.py b/frontends/concrete-python/examples/key-value-database/dynamic-size.py index f4c7a10f6d..c55de94d12 100644 --- a/frontends/concrete-python/examples/key-value-database/dynamic-size.py +++ b/frontends/concrete-python/examples/key-value-database/dynamic-size.py @@ -43,7 +43,9 @@ def decode(encoded_number): def _replace_impl(key, value, candidate_key, candidate_value): - match = np.sum((candidate_key - key) == 0) == NUMBER_OF_KEY_CHUNKS + number_of_matching_chunks = np.sum((candidate_key - key) == 0) + fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS) + match = number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS packed_match_and_value = (2**CHUNK_SIZE) * match + value value_if_match_else_zeros = keep_if_match_lut[packed_match_and_value] @@ -55,7 +57,9 @@ def _replace_impl(key, value, candidate_key, candidate_value): def _query_impl(key, candidate_key, candidate_value): - match = np.sum((candidate_key - key) == 0) == NUMBER_OF_KEY_CHUNKS + number_of_matching_chunks = np.sum((candidate_key - key) == 0) + fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS) + match = number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS packed_match_and_candidate_value = (2**CHUNK_SIZE) * match + candidate_value candidate_value_if_match_else_zeros = keep_if_match_lut[packed_match_and_candidate_value] @@ -182,9 +186,10 @@ def query(self, key): encoded_key = encode_key(key) - accumulation = np.zeros(1 + NUMBER_OF_VALUE_CHUNKS, dtype=np.uint64) + accumulation = np.zeros(1 + NUMBER_OF_VALUE_CHUNKS, dtype=np.int64) for entry in self._state: - accumulation += self._query_circuit.encrypt_run_decrypt(encoded_key, *entry) + contribution = self._query_circuit.encrypt_run_decrypt(encoded_key, *entry) + accumulation += contribution match_count = accumulation[0] if match_count > 1: diff --git a/frontends/concrete-python/examples/key-value-database/static-size.py b/frontends/concrete-python/examples/key-value-database/static-size.py index 434fcd627c..c0427e189d 100644 --- a/frontends/concrete-python/examples/key-value-database/static-size.py +++ b/frontends/concrete-python/examples/key-value-database/static-size.py @@ -83,7 +83,10 @@ def _replace_impl(state, key, value): keys = state[:, KEY] values = state[:, VALUE] - equal_rows = np.sum((keys - key) == 0, axis=1) == NUMBER_OF_KEY_CHUNKS + number_of_matching_chunks = np.sum((keys - key) == 0, axis=1) + fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS) + + equal_rows = number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS selection = (flags * 2 + equal_rows == 3).reshape((-1, 1)) packed_selection_and_value = selection * (2**CHUNK_SIZE) + value @@ -103,7 +106,10 @@ def _query_impl(state, key): keys = state[:, KEY] values = state[:, VALUE] - selection = (np.sum((keys - key) == 0, axis=1) == NUMBER_OF_KEY_CHUNKS).reshape((-1, 1)) + number_of_matching_chunks = np.sum((keys - key) == 0, axis=1) + fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS) + + selection = (number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS).reshape((-1, 1)) found = np.sum(selection) packed_selection_and_values = selection * (2**CHUNK_SIZE) + values diff --git a/frontends/concrete-python/tests/mlir/test_converter.py b/frontends/concrete-python/tests/mlir/test_converter.py index 7063b18283..76963ebfb2 100644 --- a/frontends/concrete-python/tests/mlir/test_converter.py +++ b/frontends/concrete-python/tests/mlir/test_converter.py @@ -976,6 +976,32 @@ def test_converter_bad_convert( } } + """, # noqa: E501 + ), + pytest.param( + lambda x, y: fhe.hint(np.sum((x - y) == 0), can_store=len(x)) == len(x), + { + "x": {"range": [0, 7], "status": "encrypted", "shape": (10,)}, + "y": {"range": [0, 7], "status": "encrypted", "shape": (10,)}, + }, + """ + +module { + func.func @main(%arg0: tensor<10x!FHE.eint<4>>, %arg1: tensor<10x!FHE.eint<4>>) -> !FHE.eint<1> { + %0 = "FHELinalg.to_signed"(%arg0) : (tensor<10x!FHE.eint<4>>) -> tensor<10x!FHE.esint<4>> + %1 = "FHELinalg.to_signed"(%arg1) : (tensor<10x!FHE.eint<4>>) -> tensor<10x!FHE.esint<4>> + %2 = "FHELinalg.sub_eint"(%0, %1) : (tensor<10x!FHE.esint<4>>, tensor<10x!FHE.esint<4>>) -> tensor<10x!FHE.esint<4>> + %c0_i2 = arith.constant 0 : i2 + %cst = arith.constant dense<[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<16xi64> + %3 = "FHELinalg.apply_lookup_table"(%2, %cst) : (tensor<10x!FHE.esint<4>>, tensor<16xi64>) -> tensor<10x!FHE.eint<4>> + %4 = "FHELinalg.sum"(%3) {axes = [], keep_dims = false} : (tensor<10x!FHE.eint<4>>) -> !FHE.eint<4> + %c10_i5 = arith.constant 10 : i5 + %cst_0 = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]> : tensor<16xi64> + %5 = "FHE.apply_lookup_table"(%4, %cst_0) : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<1> + return %5 : !FHE.eint<1> + } +} + """, # noqa: E501 ), ],