Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Key Value Database Example #599

Merged
merged 2 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/tutorial/extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,21 @@ return %5
```

regardless of the bounds.

Alternatively, you can use it to make sure a value can store certain integers:

```python
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def is_vectors_same(x, y):
assert x.ndim != 1
assert y.ndim != 1

assert len(x) == len(y)
n = len(x)

number_of_same_elements = np.sum(x == y)
fhe.hint(number_of_same_elements, can_store=n) # hint that number of same elements can go up to n
is_same = number_of_same_elements == n

return is_same
```
21 changes: 19 additions & 2 deletions frontends/concrete-python/concrete/fhe/extensions/hint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@

from typing import Any, Optional, Union

from ..dtypes import Integer
from ..tracing import Tracer


def hint(x: Union[Tracer, Any], *, bit_width: Optional[int] = None) -> Union[Tracer, Any]:
def hint(
x: Union[Tracer, Any],
*,
bit_width: Optional[int] = None,
can_store: Optional[Any] = None,
) -> Union[Tracer, Any]:
"""
Hint the compilation process about properties of a value.

Expand All @@ -24,6 +30,9 @@ def hint(x: Union[Tracer, Any], *, bit_width: Optional[int] = None) -> Union[Tra
bit_width (Optional[int], default = None):
hint about bit width

can_store (Optional[Any], default = None):
hint that the value needs to be able to store the given value

Returns:
Union[Tracer, Any]:
hinted value
Expand All @@ -32,7 +41,15 @@ def hint(x: Union[Tracer, Any], *, bit_width: Optional[int] = None) -> Union[Tra
if not isinstance(x, Tracer): # pragma: no cover
return x

bit_width_hint = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we reject invalid value ?
E.g. bit_width = 0 is ignored.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the semantically correct behavior. bit_width = 0 in this context means bit width needs to be at least 0 which will always be the case so we can ignore the constraint as it's already satisfied.


if bit_width is not None:
x.computation.properties["bit_width_hint"] = bit_width
bit_width_hint = max(bit_width_hint, bit_width)

if can_store is not None:
bit_width_hint = max(bit_width_hint, Integer.that_can_represent(can_store).bit_width)

if bit_width_hint > 0:
x.computation.properties["bit_width_hint"] = bit_width_hint

return x
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def decode(encoded_number):


def _replace_impl(key, value, candidate_key, candidate_value):
match = np.sum((candidate_key - key) == 0) == NUMBER_OF_KEY_CHUNKS
number_of_matching_chunks = np.sum((candidate_key - key) == 0)
fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS)
match = number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS

packed_match_and_value = (2**CHUNK_SIZE) * match + value
value_if_match_else_zeros = keep_if_match_lut[packed_match_and_value]
Expand All @@ -55,7 +57,9 @@ def _replace_impl(key, value, candidate_key, candidate_value):


def _query_impl(key, candidate_key, candidate_value):
match = np.sum((candidate_key - key) == 0) == NUMBER_OF_KEY_CHUNKS
number_of_matching_chunks = np.sum((candidate_key - key) == 0)
fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS)
match = number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS

packed_match_and_candidate_value = (2**CHUNK_SIZE) * match + candidate_value
candidate_value_if_match_else_zeros = keep_if_match_lut[packed_match_and_candidate_value]
Expand Down Expand Up @@ -182,9 +186,10 @@ def query(self, key):

encoded_key = encode_key(key)

accumulation = np.zeros(1 + NUMBER_OF_VALUE_CHUNKS, dtype=np.uint64)
accumulation = np.zeros(1 + NUMBER_OF_VALUE_CHUNKS, dtype=np.int64)
for entry in self._state:
accumulation += self._query_circuit.encrypt_run_decrypt(encoded_key, *entry)
contribution = self._query_circuit.encrypt_run_decrypt(encoded_key, *entry)
accumulation += contribution

match_count = accumulation[0]
if match_count > 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ def _replace_impl(state, key, value):
keys = state[:, KEY]
values = state[:, VALUE]

equal_rows = np.sum((keys - key) == 0, axis=1) == NUMBER_OF_KEY_CHUNKS
number_of_matching_chunks = np.sum((keys - key) == 0, axis=1)
fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS)

equal_rows = number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS
selection = (flags * 2 + equal_rows == 3).reshape((-1, 1))

packed_selection_and_value = selection * (2**CHUNK_SIZE) + value
Expand All @@ -103,7 +106,10 @@ def _query_impl(state, key):
keys = state[:, KEY]
values = state[:, VALUE]

selection = (np.sum((keys - key) == 0, axis=1) == NUMBER_OF_KEY_CHUNKS).reshape((-1, 1))
number_of_matching_chunks = np.sum((keys - key) == 0, axis=1)
fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS)

selection = (number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS).reshape((-1, 1))
found = np.sum(selection)

packed_selection_and_values = selection * (2**CHUNK_SIZE) + values
Expand Down
26 changes: 26 additions & 0 deletions frontends/concrete-python/tests/mlir/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,32 @@ def test_converter_bad_convert(
}
}

""", # noqa: E501
),
pytest.param(
lambda x, y: fhe.hint(np.sum((x - y) == 0), can_store=len(x)) == len(x),
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (10,)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (10,)},
},
"""

module {
func.func @main(%arg0: tensor<10x!FHE.eint<4>>, %arg1: tensor<10x!FHE.eint<4>>) -> !FHE.eint<1> {
%0 = "FHELinalg.to_signed"(%arg0) : (tensor<10x!FHE.eint<4>>) -> tensor<10x!FHE.esint<4>>
%1 = "FHELinalg.to_signed"(%arg1) : (tensor<10x!FHE.eint<4>>) -> tensor<10x!FHE.esint<4>>
%2 = "FHELinalg.sub_eint"(%0, %1) : (tensor<10x!FHE.esint<4>>, tensor<10x!FHE.esint<4>>) -> tensor<10x!FHE.esint<4>>
%c0_i2 = arith.constant 0 : i2
%cst = arith.constant dense<[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<16xi64>
%3 = "FHELinalg.apply_lookup_table"(%2, %cst) : (tensor<10x!FHE.esint<4>>, tensor<16xi64>) -> tensor<10x!FHE.eint<4>>
%4 = "FHELinalg.sum"(%3) {axes = [], keep_dims = false} : (tensor<10x!FHE.eint<4>>) -> !FHE.eint<4>
%c10_i5 = arith.constant 10 : i5
%cst_0 = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]> : tensor<16xi64>
%5 = "FHE.apply_lookup_table"(%4, %cst_0) : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<1>
return %5 : !FHE.eint<1>
}
}

""", # noqa: E501
),
],
Expand Down
Loading