Skip to content

Commit

Permalink
[NeMo-UX] Rename weights path during resume (#10391) (#10516)
Browse files Browse the repository at this point in the history
* rename weights path to avoid confusion



* use pathlib utils rather than os



* update resume_from_path and context_path



* address comment



---------

Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Pablo Garay <palenq@gmail.com>
  • Loading branch information
ashors1 and pablo-garay authored Sep 25, 2024
1 parent 0a4bfce commit aac89b9
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions nemo/lightning/resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class AutoResume:

WEIGHTS_PATH = "weights"

def get_model_weights_path(self, path):
def get_weights_path(self, path):
return Path(path) / self.WEIGHTS_PATH

def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
Expand Down Expand Up @@ -127,9 +127,9 @@ def _try_import_model(

if adapter_path:

maybe_model_weights_path = self.get_model_weights_path(adapter_path)
if os.path.isdir(maybe_model_weights_path):
adapter_path = maybe_model_weights_path
maybe_weights_path = self.get_weights_path(adapter_path)
if maybe_weights_path.is_dir():
adapter_path = maybe_weights_path

new_path = AdapterPath(Path(adapter_path), base_model_path=new_path)

Expand Down Expand Up @@ -244,15 +244,15 @@ def get_context_path(self, model: Optional[io.ConnectorMixin] = None) -> Optiona
checkpoint = self._find_trainer_ckpt_path()

if checkpoint:
maybe_model_weights_path = Path(checkpoint) / "context"
if os.path.isdir(maybe_model_weights_path):
checkpoint = maybe_model_weights_path
maybe_context_path = Path(checkpoint) / "context"
if maybe_context_path.is_dir():
checkpoint = maybe_context_path
return checkpoint

def get_trainer_ckpt_path(self, model: Optional[io.ConnectorMixin] = None) -> Optional[Path]:
if self.resume_from_path:
maybe_model_weights_path = self.get_model_weights_path(self.resume_from_path)
return maybe_model_weights_path if os.path.isdir(maybe_model_weights_path) else self.resume_from_path
maybe_weights_path = self.get_weights_path(self.resume_from_path)
return maybe_weights_path if maybe_weights_path.is_dir() else self.resume_from_path

checkpoint = None
app_state = AppState()
Expand All @@ -261,9 +261,9 @@ def get_trainer_ckpt_path(self, model: Optional[io.ConnectorMixin] = None) -> Op
checkpoint = self._find_trainer_ckpt_path()

if checkpoint:
maybe_model_weights_path = self.get_model_weights_path(checkpoint)
if os.path.isdir(maybe_model_weights_path):
checkpoint = maybe_model_weights_path
maybe_weights_path = self.get_weights_path(checkpoint)
if maybe_weights_path.is_dir():
checkpoint = maybe_weights_path

if checkpoint:
if self.adapter_path:
Expand Down

0 comments on commit aac89b9

Please sign in to comment.