When I finish one epoch of training, the main_worker function calls ts.collect_state_dict(model, state_dict).
But because of limited GPU resources on my machine, calling ts.collect_state_dict(model, state_dict) raises an out-of-memory error.
I found that it gathers the state_dict on the GPU. Is there any way to gather it on the CPU?
It is not possible to perform the gather operation on the CPU because it relies on the NCCL backend. However, there is a way to avoid gathering on the GPU on the fly: save the state_dict of each shard locally, then write a post-processing script to merge them together. For example, when training with 16 GPUs across 16 ranks, save 16 checkpoints during training, such as model_state_rank_001.pth, model_state_rank_002.pth, …, model_state_rank_016.pth. After training finishes, write a post-processing script to gather these 16 checkpoints into one. Pay attention to keeping the correct order of the shard states, and run an inference test to verify the result.
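A minimal sketch of that workflow is below. The helper names `save_local_shard` and `merge_shards`, the checkpoint directory, and especially the shard-merging rule are assumptions for illustration, not a torchshard API: the concatenation dimension depends on how each layer was parallelized in your model, so adjust it per layer.

```python
import os
import torch
import torch.distributed as dist


def save_local_shard(model, save_dir):
    """Each rank saves its own (already sharded) state_dict to disk,
    moving tensors to CPU first so no extra GPU memory is needed."""
    rank = dist.get_rank()
    cpu_state = {k: v.cpu() for k, v in model.state_dict().items()}
    path = os.path.join(save_dir, f"model_state_rank_{rank + 1:03d}.pth")
    torch.save(cpu_state, path)
    return path


def merge_shards(save_dir, world_size, out_path="model_state_full.pth"):
    """Post-processing on a single machine: load all rank checkpoints on CPU
    and stitch them back into one full state_dict, keeping rank order."""
    shards = [
        torch.load(os.path.join(save_dir, f"model_state_rank_{r + 1:03d}.pth"),
                   map_location="cpu")
        for r in range(world_size)  # rank order matters
    ]
    merged = {}
    for key in shards[0]:
        tensors = [s[key] for s in shards]
        if all(torch.equal(t, tensors[0]) for t in tensors):
            # Replicated (non-parallel) parameter: identical on every rank.
            merged[key] = tensors[0]
        else:
            # Sharded parameter: concatenate the per-rank pieces in rank order.
            # NOTE: the correct dim depends on how the layer was parallelized
            # (e.g. dim 0 vs. dim 1); adjust this per layer for your model.
            merged[key] = torch.cat(tensors, dim=0)
    torch.save(merged, out_path)
    return out_path
```

After merging, load the resulting checkpoint into an unsharded copy of the model and run the inference test mentioned above to confirm the shards were recombined in the right order.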