Skip to content

Commit

Permalink
Merge pull request #10 from serengil/feat-task-0712-encrypting-floats
Browse files Browse the repository at this point in the history
float encryption support added
  • Loading branch information
serengil authored Dec 7, 2023
2 parents 7394b9e + 6ceb47a commit dc2049e
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 41 deletions.
11 changes: 8 additions & 3 deletions lightphe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lightphe.cryptosystems.NaccacheStern import NaccacheStern
from lightphe.cryptosystems.GoldwasserMicali import GoldwasserMicali
from lightphe.cryptosystems.EllipticCurveElGamal import EllipticCurveElGamal
from lightphe.commons import calculations
from lightphe.commons.logger import Logger

# pylint: disable=eval-used
Expand Down Expand Up @@ -102,15 +103,19 @@ def build_cryptosystem(
raise ValueError(f"unimplemented algorithm - {algorithm_name}")
return cs

def encrypt(self, plaintext: int) -> Ciphertext:
def encrypt(self, plaintext: Union[int, float]) -> Ciphertext:
"""
Encrypt a plaintext with a built cryptosystem
Args:
plaintext (int): message
plaintext (int or float): message
Returns
ciphertext (from lightphe.models.Ciphertext import Ciphertext): encrypted message
"""
ciphertext = self.cs.encrypt(plaintext=plaintext)
ciphertext = self.cs.encrypt(
plaintext=calculations.parse_int(
value=plaintext, modulo=self.cs.modulo or self.cs.plaintext_modulo
)
)
return Ciphertext(algorithm_name=self.algorithm_name, keys=self.cs.keys, value=ciphertext)

def decrypt(self, ciphertext: Ciphertext) -> int:
Expand Down
23 changes: 23 additions & 0 deletions lightphe/commons/calculations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from lightphe.commons.logger import Logger

logger = Logger()

# pylint: disable=no-else-return


def parse_int(value, modulo) -> int:
if isinstance(value, int) and value >= 0:
return value
elif isinstance(value, int) and value < 0:
return value % modulo
elif isinstance(value, float) and value >= 0:
decimal_places = len(str(value).split(".")[1])
scaling_factor = 10**decimal_places
integer_value = int(value * scaling_factor)
logger.debug(f"{integer_value}*{scaling_factor}^-1 mod {modulo}")
return integer_value * pow(scaling_factor, -1, modulo)
elif isinstance(value, float) and value < 0:
# TODO: think and implement this later
raise ValueError("Case constant float and negative not implemented yet")
else:
raise ValueError(f"Unimplemented case for constant type {type(value)}")
16 changes: 9 additions & 7 deletions lightphe/commons/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
from datetime import datetime

# pylint: disable=broad-except


class Logger:
def __init__(self):
Expand All @@ -15,25 +17,25 @@ def __init__(self):
)
self.log_level = logging.INFO

def debug(self, message):
if self.log_level <= logging.DEBUG:
self.dump_log(message)

def info(self, message):
if self.log_level <= logging.INFO:
self.dump_log(message)

def debug(self, message):
if self.log_level <= logging.DEBUG:
self.dump_log(f"🕷️ {message}")

def warn(self, message):
if self.log_level <= logging.WARNING:
self.dump_log(message)
self.dump_log(f"⚠️ {message}")

def error(self, message):
if self.log_level <= logging.ERROR:
self.dump_log(message)
self.dump_log(f"🔴 {message}")

def critical(self, message):
if self.log_level <= logging.CRITICAL:
self.dump_log(message)
self.dump_log(f"💥 {message}")

def dump_log(self, message):
print(f"{str(datetime.now())[2:-7]} - {message}")
36 changes: 7 additions & 29 deletions lightphe/models/Ciphertext.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from lightphe.cryptosystems.NaccacheStern import NaccacheStern
from lightphe.cryptosystems.GoldwasserMicali import GoldwasserMicali
from lightphe.cryptosystems.EllipticCurveElGamal import EllipticCurveElGamal
from lightphe.commons import calculations
from lightphe.commons.logger import Logger

logger = Logger()
Expand Down Expand Up @@ -79,7 +80,9 @@ def __mul__(self, other: Union["Ciphertext", int, float]) -> "Ciphertext":
elif isinstance(other, int):
result = self.cs.multiply_by_contant(ciphertext=self.value, constant=other)
elif isinstance(other, float):
constant = self.__convert_to_int(constant=other)
constant = calculations.parse_int(
value=other, modulo=self.cs.modulo or self.cs.plaintext_modulo
)
result = self.cs.multiply_by_contant(ciphertext=self.value, constant=constant)
else:
raise ValueError(
Expand All @@ -96,7 +99,9 @@ def __rmul__(self, constant: Union[int, float]) -> "Ciphertext":
scalar multiplication of ciphertext
"""
if isinstance(constant, float):
constant = self.__convert_to_int(constant=constant)
constant = calculations.parse_int(
value=constant, modulo=self.cs.modulo or self.cs.plaintext_modulo
)

# Handle multiplication with a constant on the right
result = self.cs.multiply_by_contant(ciphertext=self.value, constant=constant)
Expand All @@ -112,30 +117,3 @@ def __xor__(self, other: "Ciphertext") -> "Ciphertext":
"""
result = self.cs.xor(ciphertext1=self.value, ciphertext2=other.value)
return Ciphertext(algorithm_name=self.algorithm_name, keys=self.keys, value=result)

def __convert_to_int(self, constant: Union[int, float]) -> int:
"""
Convert a constant to integer if it is float or negative
"""
if hasattr(self.cs, "modulo") and self.cs.modulo:
modulo = self.cs.modulo
elif hasattr(self.cs, "plaintext_modulo") and self.cs.plaintext_modulo:
modulo = self.cs.plaintext_modulo
else:
raise ValueError("Cryptosystem must have either modulo or plaintext_modulo")

if isinstance(constant, int) and constant >= 0:
return constant
elif isinstance(constant, int) and constant < 0:
return constant % modulo
elif isinstance(constant, float) and constant >= 0:
decimal_places = len(str(constant).split(".")[1])
scaling_factor = 10**decimal_places
integer_value = int(constant * scaling_factor)
logger.debug(f"{integer_value}*{scaling_factor}^-1 mod {modulo}")
return integer_value * pow(scaling_factor, -1, modulo)
elif isinstance(constant, float) and constant < 0:
# TODO: think and implement this later
raise ValueError("Case constant float and negative not implemented yet")
else:
raise ValueError(f"Unimplemented case for constant type {type(constant)}")
18 changes: 16 additions & 2 deletions tests/test_rsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from lightphe.cryptosystems.RSA import RSA
from lightphe.commons.logger import Logger
from lightphe import LightPHE

logger = Logger()

Expand Down Expand Up @@ -36,8 +37,6 @@ def test_rsa():


def test_api():
from lightphe import LightPHE

cs = LightPHE(algorithm_name="RSA")

m1 = 17
Expand All @@ -63,3 +62,18 @@ def test_api():
_ = 5 * c1

logger.info("✅ RSA api test succeeded")


def test_float_multiplication():
cs = LightPHE(algorithm_name="RSA")

m1 = 10000
m2 = 1.05

c1 = cs.encrypt(plaintext=m1)
c2 = cs.encrypt(plaintext=m2)

# homomorphic addition
assert cs.decrypt(c1 * c2) == m1 * m2

logger.info("✅ RSA float multiplication test succeeded")

0 comments on commit dc2049e

Please sign in to comment.