Skip to content

Commit

Permalink
remove destroy_process_group() from finally wrapper as it can hang (#884)
Browse files Browse the repository at this point in the history

* remove cleanup from finally as it can hang

* we need destroy_group for tests; at least dump error message and try to exit gracefully

* patch tests

* fix linting :(
  • Loading branch information
misko authored Oct 23, 2024
1 parent f22dbdf commit 712511f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 65 deletions.
132 changes: 67 additions & 65 deletions src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,72 +1033,74 @@ class _TrainingContext:
distutils.setup(config)
if config["gp_gpus"] is not None:
gp_utils.setup_gp(config)
try:
setup_imports(config)
trainer_name = config.get("trainer", "ocp")
# backwards compatibility for older configs
if trainer_name in ["forces", "equiformerv2_forces"]:
task_name = "s2ef"
elif trainer_name in ["energy", "equiformerv2_energy"]:
task_name = "is2re"
elif "multitask" in trainer_name:
task_name = "multitask"
else:
task_name = "ocp"

trainer_cls = registry.get_trainer_class(trainer_name)
assert trainer_cls is not None, "Trainer not found"

trainer_config = {
"model": config["model"],
"optimizer": config["optim"],
"identifier": config["identifier"],
"timestamp_id": config.get("timestamp_id", None),
"run_dir": config.get("run_dir", "./"),
"is_debug": config.get("is_debug", False),
"print_every": config.get("print_every", 10),
"seed": config.get("seed", 0),
"logger": config.get("logger", "wandb"),
"local_rank": config["local_rank"],
"amp": config.get("amp", False),
"cpu": config.get("cpu", False),
"slurm": config.get("slurm", {}),
"name": task_name,
"gp_gpus": config.get("gp_gpus"),
}

if task_name == "multitask":
trainer_config.update(
{
"tasks": config.get("tasks", {}),
"dataset_configs": config["datasets"],
"combined_dataset_config": config.get("combined_dataset", {}),
"evaluations": config.get("evaluations", {}),
}
)
else:
trainer_config.update(
{
"task": config.get("task", {}),
"outputs": config.get("outputs", {}),
"dataset": config["dataset"],
"loss_functions": config.get("loss_functions", {}),
"evaluation_metrics": config.get("evaluation_metrics", {}),
}
)
trainer = trainer_cls(**trainer_config)

task_cls = registry.get_task_class(config["mode"])
assert task_cls is not None, "Task not found"
task = task_cls(config)
start_time = time.time()
ctx = _TrainingContext(config=original_config, task=task, trainer=trainer)
yield ctx
distutils.synchronize()
if distutils.is_master():
logging.info(f"Total time taken: {time.time() - start_time}")
finally:
distutils.cleanup()
setup_imports(config)
trainer_name = config.get("trainer", "ocp")
# backwards compatibility for older configs
if trainer_name in ["forces", "equiformerv2_forces"]:
task_name = "s2ef"
elif trainer_name in ["energy", "equiformerv2_energy"]:
task_name = "is2re"
elif "multitask" in trainer_name:
task_name = "multitask"
else:
task_name = "ocp"

trainer_cls = registry.get_trainer_class(trainer_name)
assert trainer_cls is not None, "Trainer not found"

trainer_config = {
"model": config["model"],
"optimizer": config["optim"],
"identifier": config["identifier"],
"timestamp_id": config.get("timestamp_id", None),
"run_dir": config.get("run_dir", "./"),
"is_debug": config.get("is_debug", False),
"print_every": config.get("print_every", 10),
"seed": config.get("seed", 0),
"logger": config.get("logger", "wandb"),
"local_rank": config["local_rank"],
"amp": config.get("amp", False),
"cpu": config.get("cpu", False),
"slurm": config.get("slurm", {}),
"name": task_name,
"gp_gpus": config.get("gp_gpus"),
}

if task_name == "multitask":
trainer_config.update(
{
"tasks": config.get("tasks", {}),
"dataset_configs": config["datasets"],
"combined_dataset_config": config.get("combined_dataset", {}),
"evaluations": config.get("evaluations", {}),
}
)
else:
trainer_config.update(
{
"task": config.get("task", {}),
"outputs": config.get("outputs", {}),
"dataset": config["dataset"],
"loss_functions": config.get("loss_functions", {}),
"evaluation_metrics": config.get("evaluation_metrics", {}),
}
)
trainer = trainer_cls(**trainer_config)

task_cls = registry.get_task_class(config["mode"])
assert task_cls is not None, "Task not found"
task = task_cls(config)
start_time = time.time()
ctx = _TrainingContext(config=original_config, task=task, trainer=trainer)
yield ctx
distutils.synchronize()
if distutils.is_master():
logging.info(f"Total time taken: {time.time() - start_time}")

logging.debug("Task complete. Running disutils cleanup")
distutils.cleanup()
logging.debug("Runner() complete")


def _resolve_scale_factor_submodule(model: nn.Module, name: str):
Expand Down
4 changes: 4 additions & 0 deletions tests/core/e2e/test_s2ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import numpy.testing as npt
import pytest
from fairchem.core.common import distutils
from test_e2e_commons import (
_run_main,
oc20_lmdb_train_and_val_from_paths,
Expand Down Expand Up @@ -285,6 +286,7 @@ def test_convert_checkpoint_and_config_to_hydra(self, configs, tutorial_val_src)
new_yaml_fn=hydra_yaml,
new_checkpoint_fn=hydra_checkpoint,
)
distutils.cleanup()

# not all models are tested with otf normalization estimation
# only gemnet_oc, escn, equiformer, and their hydra versions
Expand Down Expand Up @@ -415,6 +417,8 @@ def test_max_num_atoms(self, configs, tutorial_val_src, torch_deterministic):
input_yaml=configs["equiformer_v2_hydra"],
)

distutils.cleanup()

@pytest.mark.parametrize(
("world_size"),
[
Expand Down

0 comments on commit 712511f

Please sign in to comment.