Skip to content

Commit

Permalink
test: set more lossy precision requirements (#3726)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Tests**
- Updated test cases to specify precision digits directly, enhancing the
accuracy of model compression tests.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
nahso authored Apr 30, 2024
1 parent 8fb7e91 commit ee47e75
Showing 1 changed file with 19 additions and 23 deletions.
42 changes: 19 additions & 23 deletions source/tests/tf/test_model_compression_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,36 +28,36 @@ def _file_delete(file):
os.remove(file)


# 4 tests:
# - type embedding FP64, se_atten FP64
# - type embedding FP64, se_atten FP32
# - type embedding FP32, se_atten FP64
# - type embedding FP32, se_atten FP32
tests = [
{
"se_atten precision": "float64",
"type embedding precision": "float64",
"smooth_type_embedding": True,
"precision_digit": 10,
},
{
"se_atten precision": "float64",
"type embedding precision": "float64",
"smooth_type_embedding": False,
"precision_digit": 10,
},
{
"se_atten precision": "float64",
"type embedding precision": "float32",
"smooth_type_embedding": True,
"precision_digit": 2,
},
{
"se_atten precision": "float32",
"type embedding precision": "float64",
"smooth_type_embedding": True,
"precision_digit": 2,
},
{
"se_atten precision": "float32",
"type embedding precision": "float32",
"smooth_type_embedding": True,
"precision_digit": 2,
},
]

Expand Down Expand Up @@ -158,10 +158,6 @@ def _init_models_exclude_types():
INPUTS_ET, FROZEN_MODELS_ET, COMPRESSED_MODELS_ET = _init_models_exclude_types()


def _get_default_places(nth_test):
return 10 if nth_test == 0 else 3


@unittest.skipIf(
parse_version(tf.__version__) < parse_version("2"),
f"The current tf version {tf.__version__} is too low to run the new testing model.",
Expand Down Expand Up @@ -200,7 +196,7 @@ def test_attrs(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

self.assertEqual(dp_original.get_ntypes(), 2)
self.assertAlmostEqual(dp_original.get_rcut(), 6.0, places=default_places)
Expand All @@ -218,7 +214,7 @@ def test_1frame(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

ee0, ff0, vv0 = dp_original.eval(
self.coords, self.box, self.atype, atomic=False
Expand All @@ -244,7 +240,7 @@ def test_1frame_atm(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

ee0, ff0, vv0, ae0, av0 = dp_original.eval(
self.coords, self.box, self.atype, atomic=True
Expand Down Expand Up @@ -276,7 +272,7 @@ def test_2frame_atm(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

coords2 = np.concatenate((self.coords, self.coords))
box2 = np.concatenate((self.box, self.box))
Expand Down Expand Up @@ -346,7 +342,7 @@ def test_1frame(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

ee0, ff0, vv0 = dp_original.eval(
self.coords, self.box, self.atype, atomic=False
Expand All @@ -372,7 +368,7 @@ def test_1frame_atm(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

ee0, ff0, vv0, ae0, av0 = dp_original.eval(
self.coords, self.box, self.atype, atomic=True
Expand Down Expand Up @@ -404,7 +400,7 @@ def test_2frame_atm(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

coords2 = np.concatenate((self.coords, self.coords))
ee0, ff0, vv0, ae0, av0 = dp_original.eval(
Expand Down Expand Up @@ -473,7 +469,7 @@ def test_1frame(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

ee0, ff0, vv0 = dp_original.eval(
self.coords, self.box, self.atype, atomic=False
Expand Down Expand Up @@ -505,7 +501,7 @@ def test_1frame_atm(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

ee0, ff0, vv0, ae0, av0 = dp_original.eval(
self.coords, self.box, self.atype, atomic=True
Expand Down Expand Up @@ -535,7 +531,7 @@ def test_1frame_atm(self):

def test_ase(self):
for i in range(len(tests)):
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]
from ase import (
Atoms,
)
Expand Down Expand Up @@ -628,7 +624,7 @@ def test_attrs(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

self.assertEqual(dp_original.get_ntypes(), 2)
self.assertAlmostEqual(dp_original.get_rcut(), 6.0, places=default_places)
Expand All @@ -646,7 +642,7 @@ def test_1frame(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

ee0, ff0, vv0 = dp_original.eval(
self.coords, self.box, self.atype, atomic=False
Expand All @@ -672,7 +668,7 @@ def test_1frame_atm(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

ee0, ff0, vv0, ae0, av0 = dp_original.eval(
self.coords, self.box, self.atype, atomic=True
Expand Down Expand Up @@ -704,7 +700,7 @@ def test_2frame_atm(self):
for i in range(len(tests)):
dp_original = self.dp_originals[i]
dp_compressed = self.dp_compresseds[i]
default_places = _get_default_places(i)
default_places = tests[i]["precision_digit"]

coords2 = np.concatenate((self.coords, self.coords))
box2 = np.concatenate((self.box, self.box))
Expand Down

0 comments on commit ee47e75

Please sign in to comment.