Skip to content

Commit

Permalink
Fix load_ignore_keys with rng (#2803)
Browse files Browse the repository at this point in the history
* fix rng load

* lint
  • Loading branch information
mvpatel2000 authored Jan 2, 2024
1 parent 910223e commit db424e5
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
2 changes: 1 addition & 1 deletion composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def _restore_checkpoint(
exclude_algorithms=exclude_algorithms,
algorithm_passes=algorithm_passes,
)
return state_dict['rng']
return state_dict.get('rng', None)


def save_checkpoint(
Expand Down
42 changes: 41 additions & 1 deletion tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from composer.optim import ExponentialScheduler
from composer.trainer import trainer
from composer.trainer.trainer import Trainer
from composer.utils import dist, is_tar
from composer.utils import dist, is_tar, reproducibility
from composer.utils.checkpoint import _ensure_valid_checkpoint, glob_filter
from composer.utils.object_store.object_store import ObjectStore
from composer.utils.object_store.s3_object_store import S3ObjectStore
Expand Down Expand Up @@ -740,6 +740,46 @@ def test_load_weights(self, device, load_weights_only, save_metrics):
if save_metrics:
assert metrics_equal

@pytest.mark.parametrize(
    'load_ignore_keys,weights_equal,callbacks_equal,rng_equal',
    [
        ('state/model/*', False, True, True),
        ('state/callbacks/*', True, False, True),
        ('rng', True, True, False),
    ],
)
@pytest.mark.filterwarnings('ignore:.* is not in the state_dict.*:UserWarning')
def test_load_ignore_keys(self, load_ignore_keys, weights_equal, callbacks_equal, rng_equal):
    """Check that ``load_ignore_keys`` skips only the requested checkpoint entry.

    A first trainer is fit with ``save_folder='first'`` and the RNG state is
    captured right after ``fit()``. A second trainer is then built from
    ``first/ep2.pt`` with the parametrized key passed as the single entry of
    ``load_ignore_keys``.

    Args:
        load_ignore_keys: The one key pattern to ignore while loading.
        weights_equal: Whether the model weights of both trainers should match.
        callbacks_equal: Whether the stateful callbacks of both trainers should match.
        rng_equal: Whether the second trainer's ``_rng_state`` should match the
            RNG state captured from the first run.
    """
    # Produce the checkpoint and remember the RNG state at the end of training.
    source_trainer = self.get_trainer(save_folder='first')
    source_trainer.fit()
    rng_after_fit = reproducibility.get_rng_state()
    source_trainer.close()

    # presumably 'ep2.pt' is the last checkpoint written by get_trainer's
    # default configuration -- verify against get_trainer.
    resumed_trainer = self.get_trainer(
        load_path=os.path.join('first', 'ep2.pt'),
        load_ignore_keys=[load_ignore_keys],
    )

    # Weights: the equivalence helper must pass when the weights were loaded,
    # and must raise AssertionError when they were ignored.
    if weights_equal:
        weights_expectation = contextlib.nullcontext()
    else:
        weights_expectation = pytest.raises(AssertionError)
    with weights_expectation:
        self._assert_weights_equivalent(
            source_trainer.state.model,
            resumed_trainer.state.model,
        )

    # Callback state: must match exactly when it was not ignored.
    callbacks_match = self._stateful_callbacks_equal(
        source_trainer.state.callbacks,
        resumed_trainer.state.callbacks,
    )
    assert callbacks_match if callbacks_equal else not callbacks_match

    # NOTE(review): when 'rng' is ignored nothing is asserted about
    # resumed_trainer._rng_state; consider pinning the expected value once
    # the Trainer's behavior for a missing 'rng' entry is confirmed.
    if not rng_equal:
        return
    assert rng_after_fit is not None
    deep_compare(rng_after_fit, resumed_trainer._rng_state)

@pytest.mark.remote
@device('cpu')
@pytest.mark.parametrize('load_weights_only', [True, False])
Expand Down

0 comments on commit db424e5

Please sign in to comment.