add act checkpointing to escn (#852)
* add act checkpointing to escn

* remove re-entrant from non checkpoint

* assign output correctly

* remove escneqv2heads

* fix up comment
misko authored Sep 24, 2024
1 parent 83fd9d2 commit 21a5ea2
Showing 2 changed files with 59 additions and 8 deletions.
21 changes: 13 additions & 8 deletions src/fairchem/core/models/escn/escn.py
@@ -88,6 +88,7 @@ def __init__(
         distance_resolution: float = 0.02,
         show_timing_info: bool = False,
         resolution: int | None = None,
+        activation_checkpoint: bool | None = False,
     ) -> None:
         if mmax_list is None:
             mmax_list = [2]
@@ -101,6 +102,7 @@ def __init__(
             logging.error("You need to install the e3nn library to use the SCN model")
             raise ImportError
 
+        self.activation_checkpoint = activation_checkpoint
         self.regress_forces = regress_forces
         self.use_pbc = use_pbc
         self.use_pbc_single = use_pbc_single
@@ -287,22 +289,19 @@ def forward(self, data):
         ###############################################################
 
         for i in range(self.num_layers):
-            if i > 0:
-                x_message = self.layer_blocks[i](
+            if self.activation_checkpoint:
+                x_message = torch.utils.checkpoint.checkpoint(
+                    self.layer_blocks[i],
                     x,
                     atomic_numbers,
                     graph.edge_distance,
                     graph.edge_index,
                     self.SO3_edge_rot,
                     mappingReduced,
+                    use_reentrant=not self.training,
                 )
-
-                # Residual layer for all layers past the first
-                x.embedding = x.embedding + x_message.embedding
-
             else:
-                # No residual for the first layer
-                x = self.layer_blocks[i](
+                x_message = self.layer_blocks[i](
                     x,
                     atomic_numbers,
                     graph.edge_distance,
@@ -311,6 +310,12 @@ def forward(self, data):
                     mappingReduced,
                 )
 
+            if i > 0:
+                # Residual layer for all layers past the first
+                x.embedding = x.embedding + x_message.embedding
+            else:
+                x = x_message
+
         # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere.
         # These values are fed into the output blocks.
         x_pt = torch.tensor([], device=device)
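A note on the pattern for readers unfamiliar with it: torch.utils.checkpoint.checkpoint runs the wrapped module without storing its intermediate activations and recomputes them during the backward pass, trading extra compute for lower peak memory. Passing use_reentrant=not self.training selects the non-reentrant implementation (the variant PyTorch recommends) whenever the model is training, i.e. whenever a backward pass can occur; in eval mode no backward is taken. The residual update is hoisted out of the branch so the checkpointed and plain paths share it. Below is a minimal sketch of the call pattern, using a hypothetical stand-in Block module rather than the real eSCN layer blocks:

import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    # Hypothetical stand-in for one entry of self.layer_blocks.
    def __init__(self, dim: int = 16) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))

block = Block()
x = torch.randn(4, 16, requires_grad=True)

# Checkpointed call: activations inside `block` are discarded after the
# forward pass and recomputed during backward, reducing peak memory.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()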
46 changes: 46 additions & 0 deletions tests/core/e2e/test_s2ef.py
@@ -316,6 +316,52 @@ def test_train_and_predict(
                 num_workers=0,
             )
 
+    # test that both escn and equiv2 run with activation checkpointing
+    @pytest.mark.parametrize(
+        ("model_name"),
+        [
+            ("escn_hydra"),
+            ("equiformer_v2_hydra"),
+        ],
+    )
+    def test_train_and_predict_with_checkpointing(
+        self,
+        model_name,
+        configs,
+        tutorial_val_src,
+    ):
+        with tempfile.TemporaryDirectory() as tempdirname:
+            # first train a very simple model, checkpoint
+            train_rundir = Path(tempdirname) / "train"
+            train_rundir.mkdir()
+            checkpoint_path = str(train_rundir / "checkpoint.pt")
+            training_predictions_filename = str(train_rundir / "train_predictions.npz")
+            update_dict = {
+                "optim": {
+                    "max_epochs": 2,
+                    "eval_every": 8,
+                    "batch_size": 5,
+                    "num_workers": 1,
+                },
+                "dataset": oc20_lmdb_train_and_val_from_paths(
+                    train_src=str(tutorial_val_src),
+                    val_src=str(tutorial_val_src),
+                    test_src=str(tutorial_val_src),
+                ),
+            }
+            if "hydra" in model_name:
+                update_dict["model"] = {"backbone": {"activation_checkpoint": True}}
+            else:
+                update_dict["model"] = {"activation_checkpoint": True}
+            acc = _run_main(
+                rundir=str(train_rundir),
+                input_yaml=configs[model_name],
+                update_dict_with=update_dict,
+                save_checkpoint_to=checkpoint_path,
+                save_predictions_to=training_predictions_filename,
+                world_size=1,
+            )
+
     def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic):
         with tempfile.TemporaryDirectory() as tempdirname:
             tempdir = Path(tempdirname)
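The end-to-end test above drives checkpointing through the full training loop. As a quicker local sanity check, one can also confirm that checkpointing is numerically transparent; a minimal sketch, assuming any small deterministic module (the layer below is illustrative and not part of the test suite):

import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layer = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.SiLU())

x_plain = torch.randn(2, 8, requires_grad=True)
x_ckpt = x_plain.detach().clone().requires_grad_(True)

out_plain = layer(x_plain)  # stores activations as usual
out_ckpt = checkpoint(layer, x_ckpt, use_reentrant=False)  # recomputes them

out_plain.sum().backward()
out_ckpt.sum().backward()

# Outputs and input gradients should agree to numerical precision.
assert torch.allclose(out_plain, out_ckpt)
assert torch.allclose(x_plain.grad, x_ckpt.grad)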
