Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Identity + Fix Scalar Squeeze #696

Merged
merged 2 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions docs/tutorial/extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,37 @@ assert circuit.encrypt_run_decrypt(0, 3, -5) == -5
{% hint style="info" %}
`fhe.if_then_else` is just an alias for [np.where](https://numpy.org/doc/stable/reference/generated/numpy.where.html).
{% endhint %}

## fhe.identity(value)

Allows you to copy the value:

```python
import numpy as np
from concrete import fhe

@fhe.compiler({"x": "encrypted"})
def f(x):
return fhe.identity(x)

inputset = [np.random.randint(-10, 10) for _ in range(10)]
circuit = f.compile(inputset)

assert circuit.encrypt_run_decrypt(0) == 0
assert circuit.encrypt_run_decrypt(1) == 1
assert circuit.encrypt_run_decrypt(-1) == -1
assert circuit.encrypt_run_decrypt(-3) == -3
assert circuit.encrypt_run_decrypt(5) == 5
```

{% hint style="info" %}
Identity extension can be used to clone an input while changing its bit-width. Imagine you
have `return x**2, x+100` where `x` is 2-bits. Because of `x+100`, `x` will be assigned 7-bits
and `x**2` would be more expensive than it needs to be. If `return x**2, fhe.identity(x)+100`
is used instead, `x` will be assigned 2-bits as it should and `fhe.identity(x)` will be assigned
7-bits as necessary.
{% endhint %}

{% hint style="warning" %}
Identity extension only works in `Native` encoding, which is usually selected when all table lookups in the circuit are below or equal to 8 bits.
{% endhint %}
1 change: 1 addition & 0 deletions frontends/concrete-python/concrete/fhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
bits,
conv,
hint,
identity,
if_then_else,
maxpool,
multivariate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .bits import bits
from .convolution import conv
from .hint import hint
from .identity import identity
from .maxpool import maxpool
from .multivariate import multivariate
from .ones import one, ones, ones_like
Expand Down
37 changes: 37 additions & 0 deletions frontends/concrete-python/concrete/fhe/extensions/identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
Declaration of `identity` extension.
"""

from copy import deepcopy
from typing import Any, Union

from ..representation import Node
from ..tracing import Tracer


def identity(x: Union[Tracer, Any]) -> Union[Tracer, Any]:
"""
Apply identity function to x.

Bit-width of the input and the output can be different.

Args:
x (Union[Tracer, Any]):
input to identity

Returns:
Union[Tracer, Any]:
identity tracer if called with a tracer
deepcopy of the input otherwise
"""

if not isinstance(x, Tracer):
return deepcopy(x)

computation = Node.generic(
"identity",
[deepcopy(x.output)],
x.output,
lambda x: deepcopy(x), # pylint: disable=unnecessary-lambda
)
return Tracer(computation, [x])
32 changes: 32 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2269,6 +2269,38 @@ def greater_equal(
) -> Conversion:
return self.comparison(resulting_type, x, y, accept={Comparison.GREATER, Comparison.EQUAL})

def identity(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
assert (
x.is_encrypted
and resulting_type.is_encrypted
and x.shape == resulting_type.shape
and x.is_signed == resulting_type.is_signed
)

if resulting_type.bit_width == x.bit_width:
rudy-6-4 marked this conversation as resolved.
Show resolved Hide resolved
return x

result = self.extract_bits(
self.tensor(self.eint(resulting_type.bit_width), shape=x.shape),
x,
bits=slice(0, x.original_bit_width),
)

if x.is_signed:
sign = self.extract_bits(
self.tensor(self.eint(resulting_type.bit_width), shape=x.shape),
x,
bits=(x.original_bit_width - 1),
)
base = self.mul(
resulting_type,
sign,
self.constant(self.i(sign.bit_width + 1), -(2**x.original_bit_width)),
)
result = self.add(resulting_type, base, result)

return result

def index_static(
self,
resulting_type: ConversionType,
Expand Down
7 changes: 7 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ def greater_equal(self, ctx: Context, node: Node, preds: List[Conversion]) -> Co

return self.tlu(ctx, node, preds)

def identity(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1
return ctx.identity(ctx.typeof(node), preds[0])

def index_static(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1
return ctx.index_static(
Expand Down Expand Up @@ -517,6 +521,9 @@ def squeeze(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversi
# if the output shape is (), it means (1, 1, ..., 1, 1) is squeezed
# and the result is a scalar, so we need to do indexing, not reshape
if node.output.shape == ():
if preds[0].shape == ():
return preds[0]

assert all(size == 1 for size in preds[0].shape)
index = (0,) * len(preds[0].shape)
return ctx.index_static(ctx.typeof(node), preds[0], index)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def initialize(identity: Optional[Node]) -> Node:
return identity

identity = Node.generic(
"identity",
"reinterpret",
rudy-6-4 marked this conversation as resolved.
Show resolved Hide resolved
[deepcopy(node.output)],
deepcopy(node.output),
lambda x: x,
Expand Down
71 changes: 71 additions & 0 deletions frontends/concrete-python/tests/execution/test_identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
Tests of execution of identity extension.
"""

import random

import numpy as np
import pytest

from concrete import fhe
from concrete.fhe.dtypes import Integer

# pylint: disable=redefined-outer-name


@pytest.mark.parametrize(
"sample,expected_output",
[
(0, 0),
(1, 1),
(-1, -1),
(10, 10),
(-10, -10),
],
)
def test_plain_identity(sample, expected_output):
"""
Test plain evaluation of identity extension.
"""
assert fhe.identity(sample) == expected_output


operations = [
lambda x: fhe.identity(x),
lambda x: fhe.identity(x) + 100,
]

cases = []
for function in operations:
for bit_width in [1, 2, 3, 4, 5, 8, 12]:
for is_signed in [False, True]:
for shape in [(), (3,), (2, 3)]:
cases += [
[
function,
bit_width,
is_signed,
shape,
]
]


@pytest.mark.parametrize(
"function,bit_width,is_signed,shape",
cases,
)
def test_identity(function, bit_width, is_signed, shape, helpers):
"""
Test encrypted evaluation of identity extension.
"""

dtype = Integer(is_signed, bit_width)

inputset = [np.random.randint(dtype.min(), dtype.max() + 1, size=shape) for _ in range(100)]
configuration = helpers.configuration()

compiler = fhe.Compiler(function, {"x": "encrypted"})
circuit = compiler.compile(inputset, configuration)

for value in random.sample(inputset, 8):
helpers.check_execution(circuit, function, value, retries=3)
7 changes: 7 additions & 0 deletions frontends/concrete-python/tests/execution/test_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,13 @@ def copy_modify(x):
},
id="x ** 3",
),
pytest.param(
lambda x: np.squeeze(x),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": ()},
},
id="np.squeeze(x)",
),
pytest.param(
lambda x: np.squeeze(x),
{
Expand Down
2 changes: 1 addition & 1 deletion frontends/concrete-python/tests/mlir/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def assign(x, y):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
%1 = round_bit_pattern(%0, lsbs_to_remove=2) # ClearScalar<uint6> ∈ [12, 32]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear round bit pattern is not supported
%2 = identity(%1) # ClearScalar<uint6>
%2 = reinterpret(%1) # ClearScalar<uint6>
return %2

""", # noqa: E501
Expand Down
Loading