Skip to content

Commit

Permalink
Merge pull request #14 from serengil/feat-task-1212-some-improvements
Browse files Browse the repository at this point in the history
Feat task 1212 some improvements
  • Loading branch information
serengil authored Dec 12, 2023
2 parents 79e6009 + ca9b98d commit 2e855c0
Show file tree
Hide file tree
Showing 33 changed files with 294 additions and 94 deletions.
3 changes: 1 addition & 2 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,7 @@ disable=raw-checker-failed,
consider-iterating-dictionary,
unexpected-keyword-arg,
arguments-differ,
line-too-long,
broad-exception-caught
line-too-long

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,26 @@ with pytest.raises(ValueError, match="Paillier is not homomorphic with respect t

However, if you tried to multiply ciphertexts with RSA, or xor ciphertexts with Goldwasser-Micali, these will be succeeded because those cryptosystems support those homomorphic operations.

# Encrypt & Decrypt Tensors

You can encrypt the output tensors of machine learning models with LightPHE.

```python
cs = LightPHE(algorithm_name="Paillier")

# define plain tensor
tensor = [1.005, 2.005, 3.005, -4.005, 5.005]

# encrypt tensor
encrypted_tensors = cs.encrypt(tensor)

# decrypt tensor
decrypted_tensors = cs.decrypt(encrypted_tensors)

for i, decrypted_tensor in enumerate(decrypted_tensors):
assert tensor[i] == decrypted_tensor
```

# Contributing

All PRs are more than welcome! If you are planning to contribute a large patch, please create an issue first to get any upfront questions or design decisions out of the way first.
Expand Down
114 changes: 102 additions & 12 deletions lightphe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
from typing import Optional, Union
from typing import Optional, Union, List

from lightphe.models.Homomorphic import Homomorphic
from lightphe.models.Ciphertext import Ciphertext
from lightphe.models.Algorithm import Algorithm
from lightphe.models.Tensor import EncryptedTensor, EncryptedTensors
from lightphe.cryptosystems.RSA import RSA
from lightphe.cryptosystems.ElGamal import ElGamal
from lightphe.cryptosystems.Paillier import Paillier
Expand All @@ -12,12 +14,12 @@
from lightphe.cryptosystems.NaccacheStern import NaccacheStern
from lightphe.cryptosystems.GoldwasserMicali import GoldwasserMicali
from lightphe.cryptosystems.EllipticCurveElGamal import EllipticCurveElGamal
from lightphe.commons import calculations
from lightphe.commons import phe_utils
from lightphe.commons.logger import Logger

# pylint: disable=eval-used
# pylint: disable=eval-used, simplifiable-if-expression

logger = Logger()
logger = Logger(module="lightphe/__init__.py")


class LightPHE:
Expand All @@ -43,11 +45,11 @@ def __init__(
if key_file is not None:
keys = self.restore_keys(target_file=key_file)

self.cs = self.build_cryptosystem(
self.cs: Homomorphic = self.__build_cryptosystem(
algorithm_name=algorithm_name, keys=keys, key_size=key_size
)

def build_cryptosystem(
def __build_cryptosystem(
self,
algorithm_name: str,
keys: Optional[dict] = None,
Expand Down Expand Up @@ -103,31 +105,116 @@ def build_cryptosystem(
raise ValueError(f"unimplemented algorithm - {algorithm_name}")
return cs

def encrypt(self, plaintext: Union[int, float]) -> Ciphertext:
def encrypt(self, plaintext: Union[int, float, list]) -> Union[Ciphertext, EncryptedTensors]:
"""
Encrypt a plaintext with a built cryptosystem
Args:
plaintext (int or float): message
plaintext (int, float or tensor): message
Returns
ciphertext (from lightphe.models.Ciphertext import Ciphertext): encrypted message
"""
if self.cs.keys.get("private_key") is None:
raise ValueError("You must have private key to perform encryption")

if isinstance(plaintext, list):
# then encrypt tensors
return self.__encrypt_tensors(tensor=plaintext)

ciphertext = self.cs.encrypt(
plaintext=calculations.parse_int(
value=plaintext, modulo=self.cs.modulo or self.cs.plaintext_modulo
)
plaintext=phe_utils.parse_int(value=plaintext, modulo=self.cs.plaintext_modulo)
)
return Ciphertext(algorithm_name=self.algorithm_name, keys=self.cs.keys, value=ciphertext)

def decrypt(self, ciphertext: Ciphertext) -> int:
def decrypt(
self, ciphertext: Union[Ciphertext, EncryptedTensors]
) -> Union[int, List[int], List[float]]:
"""
Decrypt a ciphertext with a buit cryptosystem
Args:
ciphertext (from lightphe.models.Ciphertext import Ciphertext): encrypted message
Returns:
plaintext (int): restored message
"""
if self.cs.keys.get("private_key") is None:
raise ValueError("You must have private key to perform decryption")

if isinstance(ciphertext, EncryptedTensors):
# then this is encrypted tensor
return self.__decrypt_tensors(encrypted_tensor=ciphertext)

return self.cs.decrypt(ciphertext=ciphertext.value)

def __encrypt_tensors(self, tensor: list) -> EncryptedTensors:
"""
Encrypt a given tensor
Args:
tensor (list of int or float)
Returns
encrypted tensor (list of encrypted tensor object)
"""
encrypted_tensor: List[EncryptedTensor] = []
for m in tensor:
sign = 1 if m >= 0 else -1
# get rid of sign anyway
m = m * sign
sign_encrypted = self.cs.encrypt(plaintext=sign)
if isinstance(m, int):
dividend_encrypted = self.cs.encrypt(plaintext=m)
divisor_encrypted = self.cs.encrypt(plaintext=1)
c = EncryptedTensor(
dividend=dividend_encrypted,
divisor=divisor_encrypted,
sign=sign_encrypted,
)
elif isinstance(m, float):
dividend, divisor = phe_utils.fractionize(value=m, modulo=self.cs.plaintext_modulo)
dividend_encrypted = self.cs.encrypt(plaintext=dividend)
divisor_encrypted = self.cs.encrypt(plaintext=divisor)
c = EncryptedTensor(
dividend=dividend_encrypted,
divisor=divisor_encrypted,
sign=sign_encrypted,
)
else:
raise ValueError(f"unimplemented type - {type(m)}")
encrypted_tensor.append(c)
return EncryptedTensors(encrypted_tensor=encrypted_tensor)

def __decrypt_tensors(
self, encrypted_tensor: EncryptedTensors
) -> Union[List[int], List[float]]:
"""
Decrypt a given encrypted tensor
Args:
encrypted_tensor (list of encrypted tensor)
Returns:
List of plain tensors
"""
plain_tensor = []
for c in encrypted_tensor.encrypted_tensor:
if isinstance(c, EncryptedTensor) is False:
raise ValueError("Ciphertext items must be EncryptedTensor")

encrypted_dividend = c.dividend
encrypted_divisor = c.divisor
encrypted_sign = c.sign

dividend = self.cs.decrypt(ciphertext=encrypted_dividend)
divisor = self.cs.decrypt(ciphertext=encrypted_divisor)
sign = self.cs.decrypt(ciphertext=encrypted_sign)

if sign == self.cs.plaintext_modulo - 1:
sign = -1
elif sign == 1:
sign = 1
else:
raise ValueError("this cannot be true!")

m = sign * (dividend / divisor)

plain_tensor.append(m)
return plain_tensor

def regenerate_ciphertext(self, ciphertext: Ciphertext) -> Ciphertext:
"""
Generate a different ciphertext belonging to same plaintext
Expand All @@ -136,6 +223,9 @@ def regenerate_ciphertext(self, ciphertext: Ciphertext) -> Ciphertext:
Returns:
ciphertext (from lightphe.models.Ciphertext import Ciphertext): encrypted message
"""
if self.cs.keys.get("private_key") is None:
raise ValueError("You must have private key to perform decryption")

ciphertext_new = self.cs.reencrypt(ciphertext=ciphertext.value)
return Ciphertext(
algorithm_name=self.algorithm_name, keys=self.cs.keys, value=ciphertext_new
Expand Down
23 changes: 0 additions & 23 deletions lightphe/commons/calculations.py

This file was deleted.

3 changes: 2 additions & 1 deletion lightphe/commons/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@


class Logger:
def __init__(self):
def __init__(self, module):
self.module = module
log_level = os.environ.get("LIGHTPHE_LOG_LEVEL", str(logging.INFO))
try:
self.log_level = int(log_level)
Expand Down
35 changes: 35 additions & 0 deletions lightphe/commons/phe_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Union, Tuple
from lightphe.commons.logger import Logger

logger = Logger(module="lightphe/commons/phe_utils.py")

# pylint: disable=no-else-return


def parse_int(value: Union[int, float], modulo: int) -> int:
if isinstance(value, int):
result = value % modulo
elif isinstance(value, float) and value >= 0:
dividend, divisor = fractionize(value=value, modulo=modulo)
logger.debug(f"{dividend}*{divisor}^-1 mod {modulo}")
result = (dividend * pow(divisor, -1, modulo)) % modulo
elif isinstance(value, float) and value < 0:
# TODO: think and implement this later
raise ValueError("Case constant float and negative not implemented yet")
else:
raise ValueError(f"Unimplemented case for constant type {type(value)}")

return result


def fractionize(value: float, modulo: int) -> Tuple[int, int]:
decimal_places = len(str(value).split(".")[1])
scaling_factor = 10**decimal_places
integer_value = int(value * scaling_factor) % modulo
logger.debug(f"{integer_value}*{scaling_factor}^-1 mod {modulo}")
return integer_value, scaling_factor


def solve_dlp():
# TODO: implement this later
pass
2 changes: 1 addition & 1 deletion lightphe/cryptosystems/Benaloh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from lightphe.models.Homomorphic import Homomorphic
from lightphe.commons.logger import Logger

logger = Logger()
logger = Logger(module="lightphe/cryptosystems/Benaloh.py")


class Benaloh(Homomorphic):
Expand Down
2 changes: 1 addition & 1 deletion lightphe/cryptosystems/DamgardJurik.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from lightphe.models.Homomorphic import Homomorphic
from lightphe.commons.logger import Logger

logger = Logger()
logger = Logger(module="lightphe/cryptosystems/DamgardJurik.py")


class DamgardJurik(Homomorphic):
Expand Down
5 changes: 3 additions & 2 deletions lightphe/cryptosystems/ElGamal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from lightphe.models.Homomorphic import Homomorphic
from lightphe.commons.logger import Logger

logger = Logger()
logger = Logger(module="lightphe/cryptosystems/ElGamal.py")


class ElGamal(Homomorphic):
Expand All @@ -26,7 +26,8 @@ def __init__(self, keys: Optional[dict] = None, exponential=False, key_size: int
"""
self.exponential = exponential
self.keys = keys or self.generate_keys(key_size)
self.modulo = self.keys["public_key"]["p"]
self.plaintext_modulo = self.keys["public_key"]["p"]
self.ciphertext_modulo = self.keys["public_key"]["p"]

def generate_keys(self, key_size: int):
"""
Expand Down
5 changes: 3 additions & 2 deletions lightphe/cryptosystems/EllipticCurveElGamal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from lightphe.elliptic.Weierstrass import Weierstrass
from lightphe.commons.logger import Logger

logger = Logger()
logger = Logger(module="lightphe/cryptosystems/EllipticCurveElGamal.py")


class EllipticCurveElGamal(Homomorphic):
Expand All @@ -27,7 +27,8 @@ def __init__(self, keys: Optional[dict] = None, key_size: int = 160):
# TODO: add different forms and curves. e.g. Koblitz, Edwards (Ed25519)
self.curve = Weierstrass()
self.keys = keys or self.generate_keys(key_size)
self.modulo = self.curve.p
self.plaintext_modulo = self.curve.p
self.ciphertext_modulo = self.curve.p

def generate_keys(self, key_size: int):
"""
Expand Down
4 changes: 3 additions & 1 deletion lightphe/cryptosystems/GoldwasserMicali.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from lightphe.models.Homomorphic import Homomorphic
from lightphe.commons.logger import Logger

logger = Logger()
logger = Logger(module="lightphe/cryptosystems/GoldwasserMicali.py")

# pylint:disable=consider-using-enumerate

Expand All @@ -26,6 +26,8 @@ def __init__(self, keys: Optional[dict] = None, key_size=100):
"""
self.keys = keys or self.generate_keys(key_size)
self.ciphertext_modulo = self.keys["public_key"]["n"]
# TODO: not sure about the plaintext modulo
self.plaintext_modulo = self.keys["public_key"]["n"]

def generate_keys(self, key_size: int) -> dict:
"""
Expand Down
2 changes: 1 addition & 1 deletion lightphe/cryptosystems/NaccacheStern.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from lightphe.models.Homomorphic import Homomorphic
from lightphe.commons.logger import Logger

logger = Logger()
logger = Logger(module="lightphe/cryptosystems/NaccacheStern.py")

# pylint: disable=simplifiable-if-expression, consider-using-enumerate

Expand Down
2 changes: 1 addition & 1 deletion lightphe/cryptosystems/OkamotoUchiyama.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from lightphe.models.Homomorphic import Homomorphic
from lightphe.commons.logger import Logger

logger = Logger()
logger = Logger(module="lightphe/cryptosystems/OkamotoUchiyama.py")


class OkamotoUchiyama(Homomorphic):
Expand Down
2 changes: 1 addition & 1 deletion lightphe/cryptosystems/Paillier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from lightphe.models.Homomorphic import Homomorphic
from lightphe.commons.logger import Logger

logger = Logger()
logger = Logger(module="lightphe/cryptosystems/Paillier.py")


class Paillier(Homomorphic):
Expand Down
Loading

0 comments on commit 2e855c0

Please sign in to comment.