Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate full finetune into multi-gpu and single device recipes #482

Merged
merged 30 commits into from
Mar 19, 2024
Merged

Conversation

rohan-varma
Copy link
Member

@rohan-varma rohan-varma commented Mar 11, 2024

Context

Changelog

  • See above

Caveats

  • As mentioned above, full memory efficiency is follow up work and not yet enabled.

Test plan

  • Run recipe tests:
  • Run full ft: tune full_finetune_single_device --config recipes/configs/alpaca_llama2_full_finetune_single_device.yaml
  • Run distributed full ft: tune --nproc_per_node 2 full_finetune_distributed --config recipes/configs/alpaca_llama2_full_finetune_distributed.yaml

Comparison to fp32 runs

Loss curves for bf16 and fp32 are comparable. Still need to run e2e evals for bf16 runs for both full and LoRA finetunes.

bf16 loss curve:

image

fp32 loss curve:

image

@rohan-varma rohan-varma marked this pull request as draft March 11, 2024 09:29
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 11, 2024
Copy link

netlify bot commented Mar 11, 2024

Deploy Preview for torchtune-preview ready!

Name Link
🔨 Latest commit b3dbe4e
🔍 Latest deploy log https://app.netlify.com/sites/torchtune-preview/deploys/65f9dca6265ee5000846c0fb
😎 Deploy Preview https://deploy-preview-482--torchtune-preview.netlify.app
📱 Preview on mobile
Toggle QR Code...

QR Code

Use your smartphone camera to open QR code link.

To edit notification comments on pull requests, go to your Netlify site configuration.

Copy link

pytorch-bot bot commented Mar 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/482

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit b3dbe4e with merge base 20c323a (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@rohan-varma rohan-varma marked this pull request as ready for review March 18, 2024 17:46
@@ -6,7 +6,7 @@
# Tokenizer
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/llama2/tokenizer.model
path: /home/rvarm1/local/dev/assets/tokenizer.model
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will revert these prior to land

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bumping this

@rohan-varma rohan-varma changed the title [WIP] Separate full finetune into multi-gpu and single device recipes Separate full finetune into multi-gpu and single device recipes Mar 18, 2024
@@ -227,9 +228,13 @@ def _setup_model(
)

model.load_state_dict(model_state_dict)

# Validate model was loaded in with the expected dtype.
utils.validate_expected_param_dtype(model, dtype=self._training_precision)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This validates that all params in the model are of the expected type. Would be useful for catching issues where some parameters dont end up as fp32, maybe due to accidental overwrite, or state_dict hook manipulating them, etc. Can take it out if needed.

)
self._optimizer.step()
if log_this_iteration:
get_memory_summary(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These logs seem overly intrusive and overly frequent. We're currently printing this every N steps where N is tied to how frequently we log other metrics. I don't think we need memory stats to be logged this frequently. Can we just move this to one place (eg: end of iteration) and specify this with a different frequency?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Different frequency sounds good. The reason of multiple calls within iteration is to help debug memory spikes during different portion of the training. For example, if we just log once at end, we don't know if memory peaked in forward, backward, or optim step. More granular logs help clearly show where the memory usage spikes and isolates the memory debugging there.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I also wonder what the best approach here is. Personally I have been copy-pasting stuff analogous to this a ton and it'd be nice to just have an easy way to configure it, so I think this is a nice step in that direction. But do agree it's a bit intrusive. While it's useful for us, do you think most users will be debugging memory spikes in forward/backward/optimizer step on a regular basis? My inclination is no, but lmk your thoughts

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to end of iteration for now.


# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 10?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Responded below!

# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1
self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 10? I would want to log loss a lot more frequently than this right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each time a loss is logged, it requires a CPU / GPU synchronization, which in traces reveal a long GPU-side wait. I think having this explicit host sync every iteration is unnecessarily expensive. If training in some nontrivially large N, I feel like I don't lose much by logging the loss every 10 instead of 1 iteration?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separately imo we should not spread our defaults across multiple places. Rn most defaults are in the yaml file, I think we should stay consistent with that here.

model.load_state_dict(model_state_dict)

# Validate model was loaded in with the expected dtype.
utils.validate_expected_param_dtype(model, dtype=self._training_precision)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question as above.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Responded in the other comment.

recipes/full_finetune_single_device.py Outdated Show resolved Hide resolved
self._sampler.set_epoch(curr_epoch)

for idx, batch in enumerate(
pbar := tqdm(self._dataloader, disable=not (rank == 0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't need to disable this

input_ids = input_ids.to(self._device)
labels = labels.to(self._device)
if log_this_iteration:
get_memory_summary(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment about this as distributed recipe. Let's reduce the frequency of these logs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about log every 100 steps, but keep the frequency in terms of it being after forward, after backward, etc?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made it just end of iteration for now.

# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1
self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 10
self._log_peak_memory_every_n_steps = 100
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comment here, just define in the config?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's been discussion in the past about what should be configurable so as to not bloat configs. I'll defer to @kartikayk on this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In #514, we hardcoded 100 so sticking with that in a variable for now seems reasonable.

"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)

init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is gloo moot if we don't support CPU training?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For unittest until we have GPU support.

Comment on lines 38 to +40
- FSDP and activation checkpointing. This is enabled by default but can be
configured using the ``enable_fsdp`` and ``enable_activation_checkpointing`` flags.
- Mixed precision training - fp32, fp16 and bf16 are supported.
- Full bf16 training via setting the ``dtype`` flag to bf16.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Comment is on L38-39). We should make sure we're aligned on the right default for AC, as #514 changes the default for distributed LoRA to no AC

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default has been on for memory efficiency, and I can't tell why #514 turns it off by default (doesn't appear to be in the PR description). So sticking with leaving it on for now.


# Update the number of steps when the weights are updated
self.total_training_steps += 1
loss.backward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we lose grad accumulation in here somewhere?

Copy link
Member Author

@rohan-varma rohan-varma Mar 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great call (but again, CI didn't catch it, unfortunate)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think our grad accumulation test is not running for the distributed recipe. I will look into setting this up with the distributed tests

Comment on lines 322 to 324
log_this_iteration = (
self.total_training_steps % self._log_every_n_steps == 0
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this isn't really making the code clearer. If anything would define a variable for self.total_training_steps % self._log_peak_memory_every_n_steps

logits = logits.transpose(1, 2)
# Compute loss
loss = self._loss_fn(logits, labels)
if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
Copy link
Contributor

@ebsmothers ebsmothers Mar 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering about all these logs when we have grad accumulation turned on. In that case are we logging all this stuff separately for every iteration of the step?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, made the memory log at just the end of iteration for now.

@@ -94,7 +95,8 @@ def fetch_checkpointer(self, ckpt):
if ckpt == "small_test_ckpt_meta":
return "FullModelMetaCheckpointer"

def test_loss(self, capsys, pytestconfig, tmpdir, monkeypatch):
@pytest.mark.parametrize("single_device", [False])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any particular reason for this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding true back in - was just testing.

@@ -53,6 +53,7 @@
"transform_opt_state_dict",
"validate_checkpoint",
"get_autocast",
"get_memory_summary",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be memory_stats_log?

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two more quick comments, otherwise looks good! Just make sure to run both recipes in the final state before landing.

Memory Allocated: {torch.cuda.memory_allocated() / 1000**3:.2f} GB
Memory Reserved: {torch.cuda.memory_reserved() / 1000**3:.2f} GB
Peak Memory: {torch.cuda.max_memory_allocated() / 1000**3:.2f} GB
def memory_stats_log(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @kartikayk, this may cause API confict with #524.

@rohan-varma rohan-varma merged commit 65aec15 into main Mar 19, 2024
21 checks passed
@joecummings joecummings deleted the ft branch April 11, 2024 15:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants