Skip to content

Commit

Permalink
chore: fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Oct 31, 2024
1 parent b188648 commit dc52fe0
Show file tree
Hide file tree
Showing 9 changed files with 16 additions and 27 deletions.
4 changes: 2 additions & 2 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
class CommonTest(ABC):
data: ClassVar[dict]
"""Arguments data."""
addtional_data: ClassVar[dict] = {}
additional_data: ClassVar[dict] = {}
"""Additional data that will not be checked."""
tf_class: ClassVar[Optional[type]]
"""TensorFlow model class."""
Expand Down Expand Up @@ -128,7 +128,7 @@ def init_backend_cls(self, cls) -> Any:

def pass_data_to_cls(self, cls, data) -> Any:
"""Pass data to the class."""
return cls(**data, **self.addtional_data)
return cls(**data, **self.additional_data)

@abstractmethod
def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def setUp(self):
self.atype.sort()

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def setUp(self):
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def setUp(self):
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def setUp(self):
self.atype.sort()

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def setUp(self):
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/model/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def pass_data_to_cls(self, cls, data) -> Any:
return get_model_pt(data)
elif cls is EnergyModelJAX:
return get_model_jax(data)
return cls(**data, **self.addtional_data)
return cls(**data, **self.additional_data)

def setUp(self):
CommonTest.setUp(self)
Expand Down
25 changes: 7 additions & 18 deletions source/tests/consistent/model/test_zbl_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@

import numpy as np

from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP
from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP
from deepmd.dpmodel.model.model import get_model as get_model_dp
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)

from ..common import (
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
SKIP_FLAG,
Expand All @@ -26,24 +25,17 @@

if INSTALLED_PT:
from deepmd.pt.model.model import get_model as get_model_pt
from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT

from deepmd.pt.model.model.dp_zbl_model import DPZBLModel as DPZBLModelPT
else:
EnergyModelPT = None
DPZBLModelPT = None
if INSTALLED_TF:
from deepmd.tf.model.ener import EnerModel as EnergyModelTF
from deepmd.tf.model.linear import EnerModel as DPZBLModelTF
else:
EnergyModelTF = None
DPZBLModelTF = None
from deepmd.utils.argcheck import (
model_args,
)

if INSTALLED_JAX:
from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX
from deepmd.jax.model.model import get_model as get_model_jax
else:
EnergyModelJAX = None


@parameterized(
(
Expand Down Expand Up @@ -92,7 +84,6 @@ def data(self) -> dict:
tf_class = EnergyModelTF
dp_class = EnergyModelDP
pt_class = EnergyModelPT
jax_class = EnergyModelJAX
args = model_args()

def get_reference_backend(self):
Expand All @@ -119,7 +110,7 @@ def skip_tf(self):

@property
def skip_jax(self):
return not INSTALLED_JAX
return True

def pass_data_to_cls(self, cls, data) -> Any:
"""Pass data to the class."""
Expand All @@ -128,9 +119,7 @@ def pass_data_to_cls(self, cls, data) -> Any:
return get_model_dp(data)
elif cls is EnergyModelPT:
return get_model_pt(data)
elif cls is EnergyModelJAX:
return get_model_jax(data)
return cls(**data, **self.addtional_data)
return cls(**data, **self.additional_data)

def setUp(self):
CommonTest.setUp(self)
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/test_type_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def data(self) -> dict:
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down

0 comments on commit dc52fe0

Please sign in to comment.