Separate full finetune into multi-gpu and single device recipes #482
Conversation
✅ Deploy Preview for torchtune-preview ready!
🔗 Helpful Links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/482. Note: links to docs will display an error until the docs builds have been completed.
❗ 1 currently active SEV. If your PR is affected, please view it below.
✅ No failures as of commit b3dbe4e with merge base 20c323a. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -6,7 +6,7 @@
 # Tokenizer
 tokenizer:
   _component_: torchtune.models.llama2.llama2_tokenizer
-  path: /tmp/llama2/tokenizer.model
+  path: /home/rvarm1/local/dev/assets/tokenizer.model
will revert these prior to land
bumping this
@@ -227,9 +228,13 @@ def _setup_model(
         )

         model.load_state_dict(model_state_dict)

+        # Validate model was loaded in with the expected dtype.
+        utils.validate_expected_param_dtype(model, dtype=self._training_precision)
Why do we need this?
This validates that all params in the model are of the expected type. It would be useful for catching issues where some parameters don't end up as fp32, e.g. due to an accidental overwrite or a state_dict hook manipulating them. Can take it out if needed.
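For reference, a minimal standalone sketch of the kind of check being described; the real helper is torchtune's utils.validate_expected_param_dtype, and this version only illustrates the idea:

import torch
import torch.nn as nn


def validate_expected_param_dtype(model: nn.Module, dtype: torch.dtype) -> None:
    # Raise if any parameter was not loaded in the expected dtype, e.g. because
    # a state_dict hook or an accidental overwrite changed it.
    for name, param in model.named_parameters():
        if param.dtype != dtype:
            raise ValueError(
                f"Parameter {name} has dtype {param.dtype}, expected {dtype}"
            )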
recipes/full_finetune_distributed.py
Outdated
)
self._optimizer.step()
if log_this_iteration:
    get_memory_summary(
These logs seem overly intrusive and overly frequent. We're currently printing this every N steps where N is tied to how frequently we log other metrics. I don't think we need memory stats to be logged this frequently. Can we just move this to one place (e.g., end of iteration) and specify this with a different frequency?
Different frequency sounds good. The reason for multiple calls within an iteration is to help debug memory spikes during different portions of training. For example, if we just log once at the end, we don't know whether memory peaked in the forward, backward, or optimizer step. More granular logs clearly show where memory usage spikes and isolate the memory debugging there.
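To make the point concrete, a hedged standalone sketch of per-phase logging (the helper name and call sites below are illustrative, not the recipe's actual code):

import torch


def log_cuda_memory(prefix: str) -> None:
    # torch.cuda memory counters are reported in bytes
    print(
        f"{prefix}: allocated={torch.cuda.memory_allocated() / 1e9:.2f} GB, "
        f"peak={torch.cuda.max_memory_allocated() / 1e9:.2f} GB"
    )


def train_step(model, batch, loss_fn, optimizer):
    logits = model(batch["input_ids"])
    loss = loss_fn(logits, batch["labels"])
    log_cuda_memory("after forward")         # spike here -> activations
    loss.backward()
    log_cuda_memory("after backward")        # spike here -> gradients
    optimizer.step()
    optimizer.zero_grad()
    log_cuda_memory("after optimizer step")  # spike here -> optimizer states
    return loss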
Yeah, I also wonder what the best approach here is. Personally I have been copy-pasting stuff analogous to this a ton, and it'd be nice to just have an easy way to configure it, so I think this is a nice step in that direction. But I do agree it's a bit intrusive. While it's useful for us, do you think most users will be debugging memory spikes in forward/backward/optimizer step on a regular basis? My inclination is no, but lmk your thoughts.
Moved to end of iteration for now.
# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 10
Why 10?
Responded below!
recipes/full_finetune_distributed.py
Outdated
         # logging attributes
         self._output_dir = cfg.output_dir
-        self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1
+        self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 10
Why 10? I would want to log loss a lot more frequently than this, right?
Each time a loss is logged, it requires a CPU/GPU synchronization, which in traces shows up as a long GPU-side wait. I think having this explicit host sync every iteration is unnecessarily expensive. If training for some nontrivially large number of steps, I feel like we don't lose much by logging the loss every 10 iterations instead of every 1?
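To illustrate the cost being described: reading a CUDA loss tensor on the host (e.g. via .item()) blocks until all queued GPU work producing it has finished. A hedged sketch of gating that sync on the logging frequency (names are assumptions, not the recipe's code):

import torch


def maybe_log_loss(loss: torch.Tensor, step: int, log_every_n_steps: int = 10) -> None:
    # loss.item() copies the scalar to the host and forces a device sync, so
    # only pay that cost on logging iterations.
    if step % log_every_n_steps == 0:
        print(f"step {step}: loss = {loss.item():.4f}")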
Separately, imo we should not spread our defaults across multiple places. Right now most defaults are in the yaml file; I think we should stay consistent with that here.
model.load_state_dict(model_state_dict)

# Validate model was loaded in with the expected dtype.
utils.validate_expected_param_dtype(model, dtype=self._training_precision)
Same question as above.
Responded in the other comment.
self._sampler.set_epoch(curr_epoch)

for idx, batch in enumerate(
    pbar := tqdm(self._dataloader, disable=not (rank == 0))
Don't need to disable this
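For context, a hedged sketch of why the disable flag exists in the distributed recipe and why it is unnecessary in the single-device one (the helper name is an assumption):

import torch.distributed as dist
from tqdm import tqdm


def progress_bar(dataloader):
    # In a distributed run, only rank 0 renders the bar so we don't get one
    # duplicate progress bar per process; in a single-device recipe there is
    # no rank to check and the bar can always be shown.
    rank = dist.get_rank() if dist.is_initialized() else 0
    return tqdm(dataloader, disable=(rank != 0))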
input_ids = input_ids.to(self._device)
labels = labels.to(self._device)
if log_this_iteration:
    get_memory_summary(
Same comment about this as distributed recipe. Let's reduce the frequency of these logs.
What about logging every 100 steps, but keeping the call sites after forward, after backward, etc.?
Made it just end of iteration for now.
         # logging attributes
         self._output_dir = cfg.output_dir
-        self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1
+        self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 10
+        self._log_peak_memory_every_n_steps = 100
Similar comment here, just define in the config?
There's been discussion in the past about what should be configurable so as to not bloat configs. I'll defer to @kartikayk on this.
In #514, we hardcoded 100 so sticking with that in a variable for now seems reasonable.
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" | ||
) | ||
|
||
init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") |
Is gloo moot if we don't support CPU training?
For unit tests, until we have GPU support.
- FSDP and activation checkpointing. This is enabled by default but can be
  configured using the ``enable_fsdp`` and ``enable_activation_checkpointing`` flags.
- Mixed precision training - fp32, fp16 and bf16 are supported.
- Full bf16 training via setting the ``dtype`` flag to bf16.
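For orientation, a hedged sketch (not the recipe's actual setup code) of roughly what the FSDP and bf16 flags correspond to; the activation checkpointing wrapper is omitted rather than guessing at its API:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_model(model: torch.nn.Module, enable_fsdp: bool, dtype: torch.dtype) -> torch.nn.Module:
    if dtype == torch.bfloat16:
        # "Full bf16 training": the parameters themselves are cast to bf16,
        # rather than using autocast-style mixed precision.
        model = model.to(dtype)
    if enable_fsdp:
        # Assumes init_process_group() has already been called.
        model = FSDP(model, device_id=torch.cuda.current_device())
    return model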
(Comment is on L38-39). We should make sure we're aligned on the right default for AC, as #514 changes the default for distributed LoRA to no AC
The default has been on for memory efficiency, and I can't tell why #514 turns it off by default (it doesn't appear to be explained in the PR description), so I'm sticking with leaving it on for now.
# Update the number of steps when the weights are updated
self.total_training_steps += 1
loss.backward()
Did we lose grad accumulation in here somewhere?
Great call (though again, unfortunately CI didn't catch it).
I think our grad accumulation test is not running for the distributed recipe. I will look into setting this up with the distributed tests.
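Since the tests missed it, here is a hedged sketch of the gradient accumulation pattern in question (function and variable names are illustrative, not the recipe's):

def train_epoch(model, dataloader, loss_fn, optimizer, gradient_accumulation_steps=4):
    # Returns the number of optimizer steps taken; gradients from several
    # micro-batches are accumulated before each step.
    total_training_steps = 0
    optimizer.zero_grad()
    for idx, batch in enumerate(dataloader):
        logits = model(batch["input_ids"])
        # Scale so the accumulated gradient matches a single large batch.
        loss = loss_fn(logits, batch["labels"]) / gradient_accumulation_steps
        loss.backward()
        if (idx + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            total_training_steps += 1  # count optimizer steps, not micro-batches
    return total_training_steps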
log_this_iteration = (
    self.total_training_steps % self._log_every_n_steps == 0
)
nit: this isn't really making the code clearer. If anything, I would define a variable for self.total_training_steps % self._log_peak_memory_every_n_steps.
logits = logits.transpose(1, 2)
# Compute loss
loss = self._loss_fn(logits, labels)
if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
Wondering about all these logs when we have grad accumulation turned on. In that case are we logging all this stuff separately for every iteration of the step?
Yeah, I moved the memory log to just the end of the iteration for now.
tests/recipes/test_full_finetune.py
Outdated
@@ -94,7 +95,8 @@ def fetch_checkpointer(self, ckpt):
         if ckpt == "small_test_ckpt_meta":
             return "FullModelMetaCheckpointer"

+    @pytest.mark.parametrize("single_device", [False])
     def test_loss(self, capsys, pytestconfig, tmpdir, monkeypatch):
Any particular reason for this?
Adding True back in; was just testing.
torchtune/utils/__init__.py
Outdated
@@ -53,6 +53,7 @@
     "transform_opt_state_dict",
     "validate_checkpoint",
     "get_autocast",
+    "get_memory_summary",
Should this be memory_stats_log?
Two more quick comments, otherwise looks good! Just make sure to run both recipes in the final state before landing.
Memory Allocated: {torch.cuda.memory_allocated() / 1000**3:.2f} GB
Memory Reserved: {torch.cuda.memory_reserved() / 1000**3:.2f} GB
Peak Memory: {torch.cuda.max_memory_allocated() / 1000**3:.2f} GB
def memory_stats_log(
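For context, a hedged reconstruction of a utility along these lines; the prefix argument and return type are assumptions, since the diff above only shows the body of the f-string and the new name:

import torch


def memory_stats_log(prefix: str = "") -> str:
    # Human-readable summary of current CUDA memory usage, in GB.
    return (
        f"{prefix}\n"
        f"Memory Allocated: {torch.cuda.memory_allocated() / 1000**3:.2f} GB\n"
        f"Memory Reserved: {torch.cuda.memory_reserved() / 1000**3:.2f} GB\n"
        f"Peak Memory: {torch.cuda.max_memory_allocated() / 1000**3:.2f} GB"
    )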
cc @kartikayk, this may cause an API conflict with #524.
Context

- The dtype flag, in accordance with RFC [RFC] Configuring low precision training in torchtune #504, is also enabled for both recipes.
- A print_peak_memory util to print the peak memory during training. We need to log this to wandB; this will be done in follow-up PRs.

Changelog

Caveats

Test plan

Comparison to fp32 runs

Loss curves for bf16 and fp32 are comparable. Still need to run e2e evals for bf16 runs for both full and LoRA finetunes.

bf16 loss curve:

fp32 loss curve: