Commit

address comments
mikaylagawarecki committed Sep 9, 2024
1 parent 8a16c76 commit 91f7d43
Showing 2 changed files with 73 additions and 76 deletions.
10 changes: 6 additions & 4 deletions recipes/lora_finetune_single_device.py
@@ -214,6 +214,10 @@ def setup(self, cfg: DictConfig) -> None:
         self._compile = cfg.compile
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

+        # hack to toggle to the low cpu ram version of the reparametrize_as_dtype
+        # hook based on the config.
+        common_utils._use_low_cpu_ram = cfg.get("low_cpu_ram", False)
+
         # set up model
         self._model = self._setup_model(
             cfg_model=cfg.model,
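
The four added lines read an optional low_cpu_ram key from the recipe config and mirror it onto the module-level common_utils._use_low_cpu_ram flag during setup, replacing the check that previously lived in recipe_main (see the last hunk of this file). Below is a minimal sketch of how that cfg.get lookup resolves with OmegaConf; the config contents are illustrative, not taken from a real torchtune config.

# Illustrative only: how DictConfig.get falls back when the key is absent.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"compile": False})   # config without a low_cpu_ram key
print(cfg.get("low_cpu_ram", False))         # -> False (the supplied default)

cfg = OmegaConf.create({"compile": False, "low_cpu_ram": True})
print(cfg.get("low_cpu_ram", False))         # -> True (e.g. set in the YAML or as a command-line override)
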
@@ -702,12 +706,12 @@ def train(self) -> None:
                 prof.step()

             self.epochs_run += 1
-            start_save_checkpoint = time.time()
+            start_save_checkpoint = time.perf_counter()
             log.info("Starting checkpoint save...")
             self.save_checkpoint(epoch=curr_epoch)
             log.info(
                 "Checkpoint saved in {:.2f} seconds.".format(
-                    time.time() - start_save_checkpoint
+                    time.perf_counter() - start_save_checkpoint
                 )
             )

@@ -725,8 +729,6 @@ def recipe_main(cfg: DictConfig) -> None:
         - Overwritten by arguments from the command-line
     """
     config.log_config(recipe_name="LoRAFinetuneRecipeSingleDevice", cfg=cfg)
-    if cfg.get("low_cpu_ram", False):
-        common_utils._use_low_cpu_ram = True
     recipe = LoRAFinetuneRecipeSingleDevice(cfg=cfg)
     recipe.setup(cfg=cfg)
     recipe.train()
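
A note on the timing change in train() above: time.perf_counter() is a monotonic, high-resolution clock intended for measuring elapsed durations, whereas time.time() reports wall-clock time and can jump if the system clock is adjusted mid-save. A tiny standalone illustration of the pattern follows; save_checkpoint here is a hypothetical stand-in, not the recipe method.

import time


def save_checkpoint() -> None:
    # hypothetical stand-in for the recipe's checkpoint save
    time.sleep(0.1)


start_save_checkpoint = time.perf_counter()   # monotonic start point, not wall-clock time
save_checkpoint()
elapsed = time.perf_counter() - start_save_checkpoint
print("Checkpoint saved in {:.2f} seconds.".format(elapsed))
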
139 changes: 67 additions & 72 deletions torchtune/modules/common_utils.py
@@ -56,82 +56,76 @@ def reparametrize_as_dtype_state_dict_post_hook(
                 state_dict[k] = state_dict[k].cpu()


-# mmap.MAP_SHARED is not supported on Windows but this change targets colab.
-if torch.__version__ >= "2.5.0.dev20240906" and not sys.platform == "win32":
-
-    def _low_ram_reparametrize_as_dtype_state_dict_post_hook(
-        model: nn.Module,
-        state_dict: Dict[str, Any],
-        *args: Tuple[Any, ...],
-        dtype: torch.dtype = torch.bfloat16,
-        offload_to_cpu: bool = True,
-        **kwargs: Dict[Any, Any],
-    ):
-        """
-        A state_dict hook that replaces NF4 tensors with their restored
-        higher-precision weight and optionally offloads the restored weight to CPU.
-        Use this hook to avoid increased peak GPU memory usage during checkpoint
-        save when training with QLoRA.
-
-        This hook is similar to ``reparametrize_as_dtype_state_dict_post_hook`` but uses
-        FakeTensor and mmap(2) to avoid CPU OOM on colab.
-
-        This function is meant to be used with PyTorch's ``nn.Module._register_state_dict_hook``, i.e.
-
-        >>> m = MyModule()
-        >>> m._register_state_dict_hook(reparametrize_as_dtype_state_dict_post_hook)
-
-        If the hook is registered per the above process, this hook will be called _after_ the module's
-        ``state_dict`` method is called. The hook will replace all ``NF4Tensor`` instances by unquantizing
-        them to the original dtype, and optionally offload the restored weight to CPU.
-
-        Args:
-            model (nn.Module): the model to take ``state_dict()`` on
-            state_dict (Dict[str, Any]): the state dict to modify
-            *args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook.
-            dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
-            offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.
-            **kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook.
-        """
-        # Create a state dict of FakeTensors that matches the state_dict
-        mode = FakeTensorMode()
-        converter = FakeTensorConverter()
-        fake_state_dict = OrderedDict()
-        for k, v in state_dict.items():
-            if isinstance(v, NF4Tensor):
-                fake_state_dict[k] = converter.from_real_tensor(mode, v).to(dtype)
-            else:
-                fake_state_dict[k] = converter.from_real_tensor(mode, v)
-
-            if offload_to_cpu:
-                fake_state_dict[k] = fake_state_dict[k].cpu()
-
-        # Create a state_dict on disk with space reserved for storage bytes
-        # Then load with mmap and MAP_SHARED (can writeback to disk file)
-        dest_state_dict_path = "/tmp/fake_state_dict.pt"
-        with torch.serialization.skip_data(materialize_fake_tensors=True):
-            torch.save(fake_state_dict, dest_state_dict_path)
-        with torch.serialization.set_default_mmap_options(mmap.MAP_SHARED):
-            dest_state_dict = torch.load(
-                dest_state_dict_path, mmap=True, weights_only=True
-            )
-
-        # Do D2H and upcast one by one and since dest_state_dict is backed by mmap --> won't OOM
-        # even when there is no swap space (e.g. colab)
-        for k in state_dict.keys():
-            if isinstance(state_dict[k], NF4Tensor):
-                dest_state_dict[k].copy_(state_dict[k].to(torch.bfloat16))
-            else:
-                dest_state_dict[k].copy_(state_dict[k])
-
-        # In place update original state_dict object. Although the private state dict
-        # post hook supports out of place behavior, the semantic actually buggy. We eventually want
-        # to use the public state_dict post hook which does not support out of place behavior.
-        for k in state_dict.keys():
-            state_dict[k] = dest_state_dict[k]
-
-else:
-    _low_ram_reparametrize_as_dtype_state_dict_post_hook = None
+def _low_ram_reparametrize_as_dtype_state_dict_post_hook(
+    model: nn.Module,
+    state_dict: Dict[str, Any],
+    *args: Tuple[Any, ...],
+    dtype: torch.dtype = torch.bfloat16,
+    offload_to_cpu: bool = True,
+    **kwargs: Dict[Any, Any],
+):
+    """
+    A state_dict hook that replaces NF4 tensors with their restored
+    higher-precision weight and optionally offloads the restored weight to CPU.
+    Use this hook to avoid increased peak GPU memory usage during checkpoint
+    save when training with QLoRA.
+
+    This hook is similar to ``reparametrize_as_dtype_state_dict_post_hook`` but uses
+    FakeTensor and mmap(2) to avoid CPU OOM on colab.
+
+    This function is meant to be used with PyTorch's ``nn.Module._register_state_dict_hook``, i.e.
+
+    >>> m = MyModule()
+    >>> m._register_state_dict_hook(reparametrize_as_dtype_state_dict_post_hook)
+
+    If the hook is registered per the above process, this hook will be called _after_ the module's
+    ``state_dict`` method is called. The hook will replace all ``NF4Tensor`` instances by unquantizing
+    them to the original dtype, and optionally offload the restored weight to CPU.
+
+    Args:
+        model (nn.Module): the model to take ``state_dict()`` on
+        state_dict (Dict[str, Any]): the state dict to modify
+        *args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook.
+        dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
+        offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.
+        **kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook.
+    """
+    # Create a state dict of FakeTensors that matches the state_dict
+    mode = FakeTensorMode()
+    converter = FakeTensorConverter()
+    fake_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        if isinstance(v, NF4Tensor):
+            fake_state_dict[k] = converter.from_real_tensor(mode, v).to(dtype)
+        else:
+            fake_state_dict[k] = converter.from_real_tensor(mode, v)
+
+        if offload_to_cpu:
+            fake_state_dict[k] = fake_state_dict[k].cpu()
+
+    # Create a state_dict on disk with space reserved for storage bytes
+    # Then load with mmap and MAP_SHARED (can writeback to disk file)
+    dest_state_dict_path = "/tmp/fake_state_dict.pt"
+    with torch.serialization.skip_data(materialize_fake_tensors=True):
+        torch.save(fake_state_dict, dest_state_dict_path)
+    with torch.serialization.set_default_mmap_options(mmap.MAP_SHARED):
+        dest_state_dict = torch.load(
+            dest_state_dict_path, mmap=True, weights_only=True
+        )
+
+    # Do D2H and upcast one by one and since dest_state_dict is backed by mmap --> won't OOM
+    # even when there is no swap space (e.g. colab)
+    for k in state_dict.keys():
+        if isinstance(state_dict[k], NF4Tensor):
+            dest_state_dict[k].copy_(state_dict[k].to(dtype))
+        else:
+            dest_state_dict[k].copy_(state_dict[k])
+
+    # In place update original state_dict object. Although the private state dict
+    # post hook supports out of place behavior, the semantic actually buggy. We eventually want
+    # to use the public state_dict post hook which does not support out of place behavior.
+    for k in state_dict.keys():
+        state_dict[k] = dest_state_dict[k]


 def _register_reparametrize_state_dict_hooks(
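
For readers unfamiliar with the reserve-then-mmap trick used by the hook above, here is a minimal standalone sketch of the same pattern applied to a toy CPU state dict. It assumes a PyTorch nightly of at least 2.5.0.dev20240906 on a non-Windows platform; the file path and tensor names are illustrative, and this is not torchtune code.

# Illustrative sketch of the skip_data + mmap.MAP_SHARED pattern.
import mmap
from collections import OrderedDict

import torch
from torch._subclasses.fake_tensor import FakeTensorConverter, FakeTensorMode

real_sd = OrderedDict(weight=torch.randn(1024, 1024))

# 1) Build a FakeTensor mirror of the state dict: same shapes and dtypes, no real storage.
mode = FakeTensorMode()
converter = FakeTensorConverter()
fake_sd = OrderedDict((k, converter.from_real_tensor(mode, v)) for k, v in real_sd.items())

# 2) Save it with storage bytes skipped: space is reserved on disk but no data is copied.
path = "/tmp/mmap_sketch.pt"
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(fake_sd, path)

# 3) Load it back memory-mapped with MAP_SHARED so writes go to the file, not to RAM.
with torch.serialization.set_default_mmap_options(mmap.MAP_SHARED):
    dest_sd = torch.load(path, mmap=True, weights_only=True)

# 4) Copy tensors in one at a time; peak resident memory stays around a single tensor.
for k in real_sd:
    dest_sd[k].copy_(real_sd[k])
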
@@ -159,6 +153,7 @@ def _register_reparametrize_state_dict_hooks(
                 "Low RAM reparametrize_as_dtype_state_dict_post_hook requires PyTorch 2.5.0.dev20240906 or later."
             )
         elif sys.platform == "win32":
+            # mmap.MAP_SHARED is not supported on Windows but this change targets colab.
             raise RuntimeError(
                 "Low RAM reparametrize_as_dtype_state_dict_post_hook is not supported on Windows."
             )
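
The added comment sits inside the guard in _register_reparametrize_state_dict_hooks, which now raises at registration time rather than defining the hook to None at import time (the change in the previous hunk). A simplified sketch of that shape of API follows; the helper name and the explicit hook argument are illustrative, not the actual torchtune signature.

# Illustrative pattern: validate the environment only when low-RAM mode is requested,
# then attach the hook via nn.Module's (private) state dict hook registration.
import sys

import torch
from torch import nn


def register_low_ram_hook(module: nn.Module, hook) -> None:
    if torch.__version__ < "2.5.0.dev20240906":
        raise RuntimeError(
            "Low RAM reparametrize_as_dtype_state_dict_post_hook requires "
            "PyTorch 2.5.0.dev20240906 or later."
        )
    if sys.platform == "win32":
        # mmap.MAP_SHARED is not supported on Windows.
        raise RuntimeError(
            "Low RAM reparametrize_as_dtype_state_dict_post_hook is not supported on Windows."
        )
    module._register_state_dict_hook(hook)
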
