diff --git a/test/test_common.py b/test/test_common.py index d5ade626..627f126d 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -14,11 +14,15 @@ from crypten.encoder import FixedPointEncoder, nearest_integer_division -def get_test_tensor(max_value=10, float=False): +def get_test_tensor(max_value=10, is_float=False): """Create simple test tensor.""" - tensor = torch.LongTensor(list(range(max_value))) - if float: - tensor = tensor.float() + dtype = torch.long + step = 1 + + if is_float: + dtype = torch.float + step = 0.2 + tensor = torch.arange(start=0, end=max_value, step=step, dtype=dtype) return tensor @@ -28,7 +32,11 @@ class TestCommon(unittest.TestCase): """ def _check(self, tensor, reference, msg): - test_passed = (tensor == reference).all().item() == 1 + if tensor.dtype == torch.long: + tensor = tensor.float() + if reference.dtype == torch.long: + reference = reference.float() + test_passed = torch.allclose(tensor, reference, rtol=1/2 ** 15) self.assertTrue(test_passed, msg=msg) def test_encode_decode(self): @@ -38,7 +46,7 @@ def test_encode_decode(self): fpe = FixedPointEncoder(precision_bits=16) else: fpe = FixedPointEncoder(precision_bits=0) - tensor = get_test_tensor(float=float) + tensor = get_test_tensor(is_float=float) decoded = fpe.decode(fpe.encode(tensor)) self._check( decoded, @@ -50,7 +58,7 @@ def test_encode_decode(self): crypten.mpc.set_default_provider(crypten.mpc.provider.TrustedFirstParty) crypten.init() - tensor = get_test_tensor(float=True) + tensor = get_test_tensor(is_float=True) encrypted_tensor = crypten.cryptensor(tensor) encrypted_tensor = fpe.encode(encrypted_tensor) self._check(