Skip to content

Commit

Permalink
Add tests for hydra gemnet OC scaling factor generation and loading; …
Browse files Browse the repository at this point in the history
…Raise error on failure to load scaling factors (#831)

* raise on gemnet oc fail to load scaling; add scaling tests for hydra oc as well

* remove assert

* always raise error if scalings factors cannot be loaded

* add debug to figure out why github ci is failing

* rename for always ddp

* remove loading scaling factors for num blocks >3
  • Loading branch information
misko authored Oct 1, 2024
1 parent 21a5ea2 commit 3012925
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 9 deletions.
1 change: 0 additions & 1 deletion docs/legacy_tutorials/OCP_Tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,6 @@ model = {
"otf_graph": False,
"output_init": "HeOrthogonal",
"activation": "silu",
"scale_file": "./gemnet-dT.json",
"regress_forces": False,
"direct_forces": False,
}
Expand Down
9 changes: 8 additions & 1 deletion src/fairchem/core/modules/scaling/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,18 @@ def load_scales_compat(module: nn.Module, scale_file: str | ScaleDict | None) ->
logging.debug(
f"Found the following scale factors: {[(k, name) for k, (_, name) in scale_factors.items()]}"
)
missing_keys = set(scale_factors.keys()) - set(scale_dict.keys())
if len(missing_keys) > 0:
raise ValueError(
"Failed to load scaling values. Missing entries for,",
missing_keys,
"\nHave",
scale_dict.keys(),
)
for name, scale in scale_dict.items():
if name not in scale_factors:
logging.warning(f"Scale factor {name} not found in model")
continue

scale_module, module_name = scale_factors[name]
logging.debug(f"Loading scale factor {scale} for ({name} => {module_name})")
scale_module.set_(scale)
10 changes: 9 additions & 1 deletion src/fairchem/core/modules/scaling/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import math
import re
import readline
import sys
from itertools import islice
Expand Down Expand Up @@ -202,9 +203,16 @@ def index_fn(name: str = name) -> None:
trainer.config["cmd"]["checkpoint_dir"] = ckpt_file.parent
trainer.is_debug = False

def rename_module(name):
    """Normalize a scale-factor parameter name before saving.

    Strips the trailing ``.scale_factor`` suffix, then removes a leading
    ``module.`` prefix (added by the DDP wrapper) and a leading
    ``backbone.`` prefix (added by hydra models) so the saved keys match
    the plain scale-factor names expected when loading.
    """
    name = name.replace(".scale_factor", "")
    # remove DDP wrapper; dot is escaped so names like "modulex.foo"
    # are not mangled (the original unescaped "." matched any character)
    name = re.sub(r"^module\.", "", name)
    # remove hydra backbone prefix (same escaping fix)
    return re.sub(r"^backbone\.", "", name)

torch.save(
{
x[0].replace(".scale_factor", ""): x[1]
rename_module(x[0]): x[1]
for x in trainer.model.to("cpu").named_parameters()
if ".scale_" in x[0]
},
Expand Down
26 changes: 20 additions & 6 deletions tests/core/e2e/test_s2ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,15 @@ def smoke_test_train(
energy_from_train, energy_from_checkpoint, rtol=1e-6, atol=1e-6
)

def test_gemnet_fit_scaling(self, configs, tutorial_val_src):
@pytest.mark.parametrize(
("model_name"),
[
("gemnet_oc"),
("gemnet_oc_hydra"),
("gemnet_oc_hydra_grad"),
],
)
def test_gemnet_fit_scaling(self, model_name, configs, tutorial_val_src):

with tempfile.TemporaryDirectory() as tempdirname:
# (1) generate scaling factors for gemnet config
Expand All @@ -130,7 +138,7 @@ def test_gemnet_fit_scaling(self, configs, tutorial_val_src):
]
)
update_yaml_with_dict(
configs["gemnet_oc"],
configs[model_name],
config_yaml,
update_dict_with={
"dataset": oc20_lmdb_train_and_val_from_paths(
Expand All @@ -143,24 +151,30 @@ def test_gemnet_fit_scaling(self, configs, tutorial_val_src):
config = build_config(args, override_args)

# (2) if existing scaling factors are present remove them
if "scale_file" in config["model"]:
config["model"].pop("scale_file")
config["model"].pop("scale_file", None)
if "backbone" in config["model"]:
config["model"]["backbone"].pop("scale_file", None)

compute_scaling_factors(config)

model_config_change = (
{"backbone": {"scale_file": scaling_pt}}
if "backbone" in config["model"]
else {"scale_file": scaling_pt}
)
# (3) try to run the config with the newly generated scaling factors
_ = _run_main(
rundir=tempdirname,
update_dict_with={
"optim": {"max_epochs": 1},
"model": {"use_pbc_single": True, "scale_file": scaling_pt},
"model": model_config_change,
"dataset": oc20_lmdb_train_and_val_from_paths(
train_src=str(tutorial_val_src),
val_src=str(tutorial_val_src),
test_src=str(tutorial_val_src),
),
},
input_yaml=configs["gemnet_oc"],
input_yaml=configs[model_name],
)

def test_convert_checkpoint_and_config_to_hydra(self, configs, tutorial_val_src):
Expand Down

0 comments on commit 3012925

Please sign in to comment.