Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grad accum in LoRA distributed recipe #644

Merged
merged 2 commits into from
Apr 4, 2024
Merged

Conversation

ebsmothers
Copy link
Contributor

Context

  • We have it in all our other finetune recipes, might as well add it to this one too

Implemented following the approach in full_finetune_distributed.py

Run without grad accumulation

CUDA_VISIBLE_DEVICES=5,6 tune run --nproc_per_node 2 --rdzv-backend=c10d --rdzv-endpoint=localhost:20000 lora_finetune_distributed --config llama2/7B_lora checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer checkpointer.checkpoint_dir=/data/users/ebs/checkpoints checkpointer.checkpoint_files=['llama2-7b-torchtune.pt'] checkpointer.output_dir=/data/users/ebs/checkpoints/new_tokenizer tokenizer.path=/data/users/ebs/checkpoints/lora-debug/tokenizer.model max_steps_per_epoch=20 metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=testing
...
1|20|Loss: 1.3451744318008423:   0%|▎            

Can see from wandb that there are 20 iterations logged

Run with grad accumulation

CUDA_VISIBLE_DEVICES=5,6 tune run --nproc_per_node 2 --rdzv-backend=c10d --rdzv-endpoint=localhost:20000 lora_finetune_distributed --config llama2/7B_lora checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer checkpointer.checkpoint_dir=/data/users/ebs/checkpoints checkpointer.checkpoint_files=['llama2-7b-torchtune.pt'] checkpointer.output_dir=/data/users/ebs/checkpoints/new_tokenizer tokenizer.path=/data/users/ebs/checkpoints/lora-debug/tokenizer.model max_steps_per_epoch=20 gradient_accumulation_steps=2 metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=testing

...
1|40|Loss: 1.5071550607681274:   0%|▌ 

Since tqdm logs iteration number, not step number, we can see that both cases run for the expected number of iterations.

Copy link

pytorch-bot bot commented Apr 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/644

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9d6b4c6 with merge base 32d66df (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 3, 2024
@ebsmothers ebsmothers changed the title grad accum in LoRA recipe grad accum in LoRA distributed recipe Apr 3, 2024
@@ -468,6 +472,17 @@ def save_checkpoint(
intermediate_checkpoint=intermediate_checkpoint,
)

def _should_update_weights(self, current_iteration: int) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a utility at this point since it's in every recipe

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I just don't know where to put it tbh. None of the existing files feel relevant and I don't wanna create a new file just for this. Lmk if you have suggestions here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe torchtune/utils/training.py or torchtune/utils/optim_utils.py? I don't see any harm in making a new file

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tbh I'm inclined to go the other way and say that we should just do this inline. I definitely do not want to add even more indirection than we already have. And this is literally just doing a mod check

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also @joecummings @kartikayk or @rohan-varma if you have thoughts here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do it inline.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm yeah, that's cleaner than I thought it would be. excellent choice

Copy link
Contributor

@kartikayk kartikayk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very clean - thanks for adding this!

@kartikayk
Copy link
Contributor

Since tqdm logs iteration number, not step number

How motivated are you to fix this? :) I think it should be a two line change tbh:

pbar = tqdm(desc=f"Training Epoch: {epoch+1}", total=self._dataloader)
....
# update pbar whenever we take a step
pbar.update(1)

I think this should work and then we'll no longer have the discrepancy between tqdm and our wndb logs. Can happen in a follow up as well.

@ebsmothers
Copy link
Contributor Author

Since tqdm logs iteration number, not step number

How motivated are you to fix this? :) I think it should be a two line change tbh:

pbar = tqdm(desc=f"Training Epoch: {epoch+1}", total=self._dataloader)
....
# update pbar whenever we take a step
pbar.update(1)

I think this should work and then we'll no longer have the discrepancy between tqdm and our wndb logs. Can happen in a follow up as well.

I will probably just punt to a follow-up if it's all the same

@ebsmothers ebsmothers merged commit aacaadd into main Apr 4, 2024
20 checks passed
@ebsmothers ebsmothers deleted the lora-distributed-grad-accum branch April 4, 2024 02:19
@ebsmothers ebsmothers mentioned this pull request Apr 5, 2024
Merged
tcapelle pushed a commit to tcapelle/torchtune that referenced this pull request Apr 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants