From 3691d804b502aed5248424bd5bb7ac5884fc9cee Mon Sep 17 00:00:00 2001 From: George Muraru Date: Sat, 12 Dec 2020 15:32:47 +0200 Subject: [PATCH] Fix decode --- crypten/encoder.py | 2 +- test/test_common.py | 22 +++++++++++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/crypten/encoder.py b/crypten/encoder.py index c12842ec..b5f23ccc 100644 --- a/crypten/encoder.py +++ b/crypten/encoder.py @@ -64,7 +64,7 @@ def decode(self, tensor): assert is_int_tensor(tensor), "input must be a LongTensor" if self._scale > 1: correction = (tensor < 0).long() - dividend = tensor / self._scale - correction + dividend = tensor // self._scale - correction remainder = tensor % self._scale remainder += (remainder == 0).long() * self._scale * correction diff --git a/test/test_common.py b/test/test_common.py index 119d870b..572473f2 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -14,11 +14,15 @@ from crypten.encoder import FixedPointEncoder, nearest_integer_division -def get_test_tensor(max_value=10, float=False): +def get_test_tensor(max_value=10, is_float=False): """Create simple test tensor.""" - tensor = torch.LongTensor(list(range(max_value))) - if float: - tensor = tensor.float() + dtype = torch.long + step = 1 + + if is_float: + dtype = torch.float + step = 0.2 + tensor = torch.arange(start=0, end=max_value, step=step, dtype=dtype) return tensor @@ -28,7 +32,11 @@ class TestCommon(unittest.TestCase): """ def _check(self, tensor, reference, msg): - test_passed = (tensor == reference).all().item() == 1 + if tensor.dtype == torch.long: + tensor = tensor.float() + if reference.dtype == torch.long: + reference = reference.float() + test_passed = torch.allclose(tensor, reference, rtol=1/2 ** 15) self.assertTrue(test_passed, msg=msg) def test_encode_decode(self): @@ -38,7 +46,7 @@ def test_encode_decode(self): fpe = FixedPointEncoder(precision_bits=16) else: fpe = FixedPointEncoder(precision_bits=0) - tensor = get_test_tensor(float=float) + tensor = get_test_tensor(is_float=float) decoded = fpe.decode(fpe.encode(tensor)) self._check( decoded, @@ -50,7 +58,7 @@ def test_encode_decode(self): crypten.mpc.set_default_provider(crypten.mpc.provider.TrustedFirstParty) crypten.init() - tensor = get_test_tensor(float=True) + tensor = get_test_tensor(is_float=True) encrypted_tensor = crypten.cryptensor(tensor) encrypted_tensor = fpe.encode(encrypted_tensor) self._check(