Skip to content

Commit

Permalink
fix post merge
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Aug 21, 2024
1 parent 2878bf6 commit da81a89
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,9 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
self.dtype = data.pos.dtype
self.device = data.pos.device
atomic_numbers = data.atomic_numbers.long()
assert (
atomic_numbers.max().item() < self.max_num_elements
), "Atomic number exceeds that given in model config"
graph = self.generate_graph(
data,
enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly,
Expand Down
4 changes: 2 additions & 2 deletions tests/core/e2e/test_s2ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,15 @@ def test_max_num_atoms(self, configs, tutorial_val_src, torch_deterministic):
rundir=str(tempdir),
update_dict_with={
"optim": {"max_epochs": 1},
"model": {"max_num_elements": 2},
"model": {"backbone": {"max_num_elements": 2}},
"dataset": oc20_lmdb_train_and_val_from_paths(
train_src=str(tutorial_val_src),
val_src=str(tutorial_val_src),
test_src=str(tutorial_val_src),
),
},
update_run_args_with=extra_args,
input_yaml=configs["equiformer_v2"],
input_yaml=configs["equiformer_v2_hydra"],
)

@pytest.mark.parametrize(
Expand Down

0 comments on commit da81a89

Please sign in to comment.