Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather #413

Merged
merged 24 commits into from
Jul 16, 2024

Conversation

weifengpy
Copy link
Contributor

@weifengpy weifengpy commented Jun 19, 2024

we have landed fp8 all-gather optimizations in float8_experimental pytorch-labs/float8_experimental#266

this PR proposes torchtitan changes. also include fp8 in CI

from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)

FSDP2 fp8 all-gather are added to CI

CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp

TP fp8 all-gather are locally tested. will add them to CI after uploading a new tokenizer with vacab size 2560 (divisible by 16)

CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2

precompute scales after optimizer.step
Screenshot 2024-07-12 at 5 11 14 PM

FSDP2 pre-all-gather do not have any small all-reduces
Screenshot 2024-07-12 at 5 13 04 PM

TODO

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 19, 2024
@weifengpy weifengpy marked this pull request as draft June 19, 2024 22:25
weifengpy and others added 12 commits June 19, 2024 16:05
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summaiy:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@weifengpy weifengpy changed the title [DO NOT REVIEW] fsdp fp8-all-gather enable FSDP2 + fp8 all-gather Jul 13, 2024
@@ -273,6 +273,39 @@ def build_test_list():
"fsdp2_mem_tracker",
ngpu=4,
),
OverrideDefinitions(
Copy link
Contributor Author

@weifengpy weifengpy Jul 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added followings to CI

  • 1D fsdp original dtype all-gather
  • 1D fsdp fp8 all-gather
  • 1D fsdp fp8 all-gather with precomputed dynamic scales

need follow ups to enable TP fp8 all-gather in CI: current CI tokenizer has 2556, not divisible by 16) #461

  • 1D TP fp8 all-gather
  • 2D FSDP + TP fp8 all-gather

@weifengpy weifengpy marked this pull request as ready for review July 13, 2024 00:09
test_runner.py Outdated
"--training.fp8_linear",
]
],
"FSDP2 with bf16 all-gather",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe noob question: does all-gather always happen in bf16, or is it determined by param_dtype of FSDP2

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good question! I should change this to "all-gather in original dtype"

when mixed_precision is turned off (no param_dtype configed), FSDP2 communicate in model's original dtype

when mixed_precision is turned on (param_dtype=xxx), FSDP2 communicates according to param_dtype

test_runner.py Show resolved Hide resolved
test_runner.py Show resolved Hide resolved
@@ -398,6 +399,9 @@ def loss_fn(pred, labels):
optimizers.step()
lr_schedulers.step()

if job_config.training.precompute_float8_dynamic_scale_for_fsdp:
precompute_float8_dynamic_scale_for_fsdp(model)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe a noob question: could you briefly explain what this is doing?
I wonder since we are already using context functions for FP8, can we have a context and run it in a .step() function here, just like optimizer, lr scheduler, and profiler. This would make the code consistent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you briefly explain what this is doing

precompute_float8_dynamic_scale_for_fsdp is a for-loop over model.parameters(). it issues a single all-reduce for all parameters, ie abs(max(param)) for param in model.parameters() and save amax/scale as param._precomputed_scale. this speed up the training loop since we do not need to compute amax/scale for each parameters in the training loop

we are already using context functions for FP8

do you refer to set_enable_fsdp_fp8_all_gather ? That's for model intiaitialization where we swap nn.Linear with user-defined float8 linear. precompute_float8_dynamic_scale_for_fsdp is for training loop

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

per suggestion, raise error if use_fp8_linear=False or enable_fsdp_fp8_all_gather =False

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noob q: do we eventually want to just put this in fsdp2?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has to be done after optimizer step (since parameter values change). Are you suggesting to run this in the root module's pre-forward?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah anywhere between the n-1th optimizer step and the first all-gather in the nth step where fsdp2 has control (if there's any).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. I think one concern is that FSDP is agnostic to the fp8 all-gather. FSDP does not know that the fsdp_pre_all_gather and fsdp_post_all_gather of the Float8Linear.weights are implemented to do fp8 all-gather, so at best, the user still would need to register a module forward pre-hook or something to run this method.

Copy link
Contributor

@yifuwang yifuwang Jul 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see. Somehow I thought fsdp2 was fp8-aware

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@weifengpy weifengpy requested a review from tianyu-l July 15, 2024 21:25
@@ -398,6 +404,17 @@ def loss_fn(pred, labels):
optimizers.step()
lr_schedulers.step()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add comment to explain precompute_float8_dynamic_scale_for_fsdp

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some comments to record what we discussed offline.

@@ -347,6 +347,18 @@ def __init__(self):
here: https://github.com/pytorch-labs/float8_experimental
""",
)
self.parser.add_argument(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, let's refactor fp8 configs, e.g. have a dedicated field for enabling fp8 or not.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed fp8_linear to enable_fp8_linear

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think one thing to note is that right now this is a boolean which will swap to the default float8 recipe
Dynamic scaling x Tensor wise ScalingGranularity x all tensors involved in the matmul [ input, weight, grad]

I think we should brainstorm on an elegant solutions for users to express their desired config here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good question. evetually we might have to expose args/kwargs from swap_linear_with_float8_linear for flexibility

@@ -27,8 +44,8 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig):
This will mutate the model inplace.
"""
use_fp8_linear = job_config.training.fp8_linear
enable_fsdp_fp8_all_gather = job_config.training.enable_fsdp_fp8_all_gather
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

discussed offline: please check if it makes sense to enable it only when dp_degree > 1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added check on parallel_dims.dp_enabled

train.py Outdated
@@ -218,6 +219,11 @@ def loss_fn(pred, labels):
# apply fp8 linear module swap
if job_config.training.fp8_linear:
build_fp8_linear(whole_model, job_config)
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can remove this to favor simplicity if it is a no-op flag when fp8_linear=False

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed ValueError on enable_fp8_linear=False

train.py Outdated
@@ -398,6 +404,17 @@ def loss_fn(pred, labels):
optimizers.step()
lr_schedulers.step()

if job_config.training.precompute_float8_dynamic_scale_for_fsdp:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

discussed offline: can refactor to make it simpler

Copy link
Contributor Author

@weifengpy weifengpy Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed ValueError when enable_fp8_linear/enable_fsdp_fp8_all_gather=False

@weifengpy weifengpy marked this pull request as draft July 16, 2024 00:38
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
"""Get the parallel strategy for the transformer model.

This function handles the special case of using float8 with tensor parallelism.
"""
if job_config.training.fp8_linear == "dynamic":
Copy link
Contributor Author

@weifengpy weifengpy Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fp8_linear == "dynamic" is outdated after recent unification of dynamic/delayed scaling (Float8Linear)
#436

update it in this PR to make TP fp8 all-gather work again

EDIT: Will enable TP in CI to prevention after having a new tokenizer with vacab size 2560

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
)

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noop Q: do we need this in a context manager to make testing + resetting easier?

Copy link
Contributor Author

@weifengpy weifengpy Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm. set_enable_fsdp_fp8_all_gather is a context manager right now. do you mean "why" it should be a context manager ?

EDIT: I also see you mentioned "make testing + resetting easier", which answered why. so I am not sure if it's a question for me

and m.scaling_type_w is TensorScalingType.DELAYED
for m in module.modules()
):
raise NotImplementedError("Only supports delayed scaling")
Copy link
Contributor

@drisspg drisspg Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be "only supports dynamic scaling" right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch! will change it

test_runner.py Outdated
"--training.tensor_parallel_degree 2",
]
],
"FSDP2 with fp8 all-gather and precomputed dynamic scales",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: comment for 2D

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the end I have to remove 2D from this PR. current CI tokenizer has vacab size = 2556. However, fp8 gemm need the vacab size to be divisible by 16 #461

I can follow up with you on how to have a tokenizer with vacab size = 2560 to unblock 1D TP + fp8, and 2D + fp8 in CI

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@weifengpy weifengpy marked this pull request as ready for review July 16, 2024 19:36
@weifengpy weifengpy requested a review from tianyu-l July 16, 2024 19:36
@weifengpy weifengpy changed the title enable FSDP2 + fp8 all-gather enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather Jul 16, 2024
Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR looks good to me, except that there is an ongoing discussion on what's the best way to support fp8 GEMM on tensors of irregular shapes (when not divisible by 16) e.g. coming from TP + test tokenizer model in this repo (see #461). Maybe we should allow irregular shapes but with potential perf regression. Will need to discuss more on this later.

"""
This function converts the linear layers to `Float8Linear`. Note that today,
only dynamic tensor scaling (the default) is supported.

This will mutate the model inplace.
"""
use_fp8_linear = job_config.training.fp8_linear
enable_fp8_linear = job_config.training.enable_fp8_linear
if not enable_fp8_linear:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's remove this redundant check

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@weifengpy weifengpy merged commit f025335 into pytorch:main Jul 16, 2024
3 of 5 checks passed
@@ -19,6 +19,7 @@

import torch
import torch.nn.functional as F
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
Copy link
Contributor

@wanchaol wanchaol Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@weifengpy I think we should hide this import to the path where enable_fp8_allgather path happened?

The problem here is that for every feature that requires an additional install from other dependency, we should try to hide the import to the path that uses it instead of import it globally, otherwise for users who didn't install the float8_experimental, if they rebase, and it would just fail to train for them.

Please submit a follow up PR to fix this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got you. I am moving it from top-level to if-else now #464

thanks for the timely reminder

tianyu-l added a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 13, 2024
* Set `record_shapes=True` for profiler

ghstack-source-id: 6f1ed49d15ce311f1bf118820965cdb5309a8030
Pull Request resolved: pytorch#419

* Improved `repeat_kv` eager perf

ghstack-source-id: 39e484954814e61cdfb2ba661f0a98c83bc0ce60
Pull Request resolved: pytorch#418

* Adding FSDP Memory Tracking and Estimation

ghstack-source-id: c8ed20fc585957bd164dd963307616a53991615d
Pull Request resolved: pytorch#425

* Adding integration test for FSDP Memory Tracking and Estimation

ghstack-source-id: cc224db8951ec7a133fd769845a4765cbedc6454
Pull Request resolved: pytorch#426

* by default disable heavy memory profiling

ghstack-source-id: cad7b3c41fd60ec19c0e6e7d058e8aa00602a187
Pull Request resolved: pytorch#430

* Add the option to turn on async-TP

ghstack-source-id: 0a03379eeb3a63b2d1ad4dff84d0e61ca82b1bbf
Pull Request resolved: pytorch#429

* Modifying memory estimation options and minor changes

ghstack-source-id: 5f09824cddaed6585cc094095e1e95dd070d76f4
Pull Request resolved: pytorch#435

* add comment pointing to Sequence Parallel optimization example

ghstack-source-id: 6fa0dcd4bca876e10a6a8349283fb940a59ad234
Pull Request resolved: pytorch#438

* switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)

Summary:

After pytorch-labs/float8_experimental#300,
`Float8Linear` with default settings is equivalent to
`Float8DynamicLinear`. This PR changes `torchtitan` to use
`Float8Linear`.

To support the new UX of `float8_experimental` better, I also switched
the `fp8_linear` configuration to be a boolean on whether to swap the
linears or not. In the future we can add new options on how to configure
each linear (scaling type, scaling granularity, etc) - saving that for a
future PR.

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fdsp_fp8_allgather=True
// 4. compile, float8, fdsp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`

ghstack-source-id: 50b2d0c2b4c22e2f045cafd8630c16f3a8c6d35f
Pull Request resolved: pytorch#444

* Reordered TP parallel plan to follow execution order

ghstack-source-id: b4924952adeb5f16d08b60faa54690762841c422
Pull Request resolved: pytorch#445

* Made some stylistic changes to `apply_dp`

ghstack-source-id: fb78e9eb8aa406ba87d6ad6cf2229c1027dae42f
Pull Request resolved: pytorch#446

* Refactored activation checkpointing

ghstack-source-id: 785c7e47651cda97ea22d0147d14b8d061ce042d
Pull Request resolved: pytorch#447

* compiled RMSNorm

ghstack-source-id: c4efb81ec6acc5442955908cc376df3e6d889af3
Pull Request resolved: pytorch#442

* Renamed parallel styles for transformer block weights

ghstack-source-id: 5fb0bf3d08cacf27242ec0f85d5dd3cdc03b739e
Pull Request resolved: pytorch#448

* Added type annotations and more stylistic changes

ghstack-source-id: 1bd5b9d5abc8644785132f8eb2baaf8b1cfc5fb5
Pull Request resolved: pytorch#449

* [Cleanup] Remove libuv from run_llama_train.sh

libuv is now enabled by default.

we can proably do without the educational blurb there, and don't need
the env either since the default has landed.

ghstack-source-id: 68c8d2abe7eb0777e2add8df7634367c31b7ec06
Pull Request resolved: pytorch#453

* [Cleanup] Organize run_llama_train.sh options

Just a little code motion but it looks cleaner to me this way

ghstack-source-id: 055fbd557cd9cf189e6b9bd6a7048f1204e1dc5c
Pull Request resolved: pytorch#454

* [Cleanup] Split run_llama_train.sh and run_memory_estimation.sh

Make each script simpler to read

ghstack-source-id: ba3aa65feb6e304736c73daf5bc8ab5fb254f196
Pull Request resolved: pytorch#455

* [Cleanup] Remove unused TRAINER_DIR

This argument seems to be left over from older times- it is not used
anywhere in the codebase.

ghstack-source-id: abbcf82ed4d1b8fbb71c6a6b48acbc1296dbec64
Pull Request resolved: pytorch#456

* Add educational code pointers to top level README

ghstack-source-id: 522aa2fa0bf1679f55d9f3a8a38fdcd319d5e3df
Pull Request resolved: pytorch#457

* enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (pytorch#413)

we have landed fp8 all-gather optimizations in float8_experimental
pytorch-labs/float8_experimental#266

this PR proposes torchtitan changes. also include fp8 in CI
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather are added to CI
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather are locally tested. will add them to CI after
uploading a new tokenizer with vacab size 2560 (divisible by 16)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

precompute scales after optimizer.step
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather do not have any small all-reduces
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload tokenizer with vacab size 2560 to enable CI on TP fp8
all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about best config option to express
fp8
* compare perf between delayed scaling and dynamic scaling
https://github.com/pytorch-labs/float8_experimental/pull/312/files

* import float8_experimental only when fp8 is enabled and install it in CI (pytorch#464)

make sure to only import float8_experimental when fp8 is enabled

for 4 gpu CI, make sure we can import float8_experimental correctly in
CI

`python -m pip install
git+https://github.com/pytorch-labs/float8_experimental.git`

* skip fp8 CI on non-H100 GPUs (pytorch#465)

skip fp8 tests on non-H100 GPUs by checking
`torch.cuda.get_device_capability() >= (9, 0)`

this makes 4 GPU CI healthy again

* clean up float8 configs in torchtitan (pytorch#466)

Summary:

1. standardizes on `float8` instead of `fp8` for config names
2. removes usage of non-public objects such as `Float8Linear`

Test Plan:

```
with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Add support of DDP and experimental CompiledAutograd

Summary:
Address the comments in pytorch#319 and resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600   --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: 81dc85d42df13df4ed727bebd825681879af936b
Pull Request resolved: pytorch#432

* add torch.compile + FSDP2 float8 all-gather in CI (pytorch#468)

fixed my bug in float8_experimental. now we can torch.compile
transfromer blocks with FSDP float8 all-gather
pytorch-labs/float8_experimental#321

local test: `CONFIG_FILE="./train_configs/debug_model.toml"
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

profiler traces: I can see compiled region in cpu thread and float8
malmul `sm90_xmma_gemm_e4m3bf16...` in cuda stream
<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM"
src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">

* [float8] keep model.output as `nn.Linear` (high precision, not fp8) (pytorch#469)

**keep model.output as nn.Linear**: it's a common practice to NOT apply
fp8 on final output layer
* specify `skip_fqn_list` in swapping
* when applying TP to model.output, use plain `ColwiseParallel` instead
of `Float8ColwiseParallel`

credit to @awgu, we do not need tokentizer vacab size to be divisible by
16 pytorch#461

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4
--training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: see float8 and all-gather in the trace
<img width="1611" alt="Screenshot 2024-07-19 at 1 16 59 PM"
src="https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472">

2D + float8 all-gather trace: see float8 and FSDP collectives and TP
collectives
<img width="1038" alt="Screenshot 2024-07-19 at 1 29 59 PM"
src="https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7">

* remove CI for FSDP2 + fp8 all-gather (pytorch#470)

per discussion from
pytorch#469 (comment)

we are planning BC breaking changes in float8_experimental. remove CI
for FSDP2 + fp8 all-gather for now. When public APIs are finalized, we
can discuss bringing it back

* dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (pytorch#471)

This PR adds some enhancements for supporting async tp:

1 - if async tp is active, auto updates the torch.dynamo cache limit to
10K. If this is not updated, async tp will not be activated on larger
models as it will quietly stop compilation due to 'cache limit reached'
with no info for the user.
This config update is logged. 

2 - if async tp is enabled, verifies that torch.compile is set to true
for this job config. If not, it warns and then activates torch.compile
to ensure user gets working async tp. (see WARNING in below screenshot)

<img width="1345" alt="Screenshot 2024-07-20 at 4 33 04 PM"
src="https://github.com/user-attachments/assets/26e5a48e-4bb8-4f33-b1b5-8939c1517c1d">

3 - Updates the 'Applied Tensor Parallel' to the model to be 'Applied
Async Tensor Parallel' when async tp is active to make it clear in the
logs which TP is active. (see above screenshot)

* Fix 8gpu PP failure due to 2D DCP disablement

DCP recently added safeties to avoid using it for 2D/3D since strided
sharding (a feature needed for safe 2D/3D resharding) is not ready yet.

PP uses DCP to load a seed checkpoint.  Disabling the safety mechanism
is enough to make 3D/PP still work (for the case where we train from the
beginning or do not re-shard.

(Resharding refers to saving a checkpoint from one world
size/parallelism config and loading/resuming under a different one).

ghstack-source-id: c069d2186c79517c72f5b3c99485cebdc15df08f
Pull Request resolved: pytorch#460

* update float8 integration after UX changes (pytorch#484)

Summary:

float8_experimental landed various BC-breaking UX changes last week.
This PR updates torchtitan to work with the version of
float8_experimental after
pytorch-labs/float8_experimental#332 and
pytorch-labs/float8_experimental#337

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Re-enable FSDP2 Mem Tracker integration tests

ghstack-source-id: 8344603f7a5596cb2909c9bf04dd1b9e4730c9b8
Pull Request resolved: pytorch#485

* Used `partial` instead of global vars for LR scheduling

ghstack-source-id: 12c4418b0574d93e1441f4ca3d1de79c8aad7a40
Pull Request resolved: pytorch#487

* [EZ] Add logs for some basic training params so that we can verify in… (pytorch#491)

As title, while testing on 405B model, I found that we need to somehow
need the logs for some training params. So added some here. Tested
locally and the logging is shown as in the screenshot:


<img width="900" alt="image"
src="https://github.com/user-attachments/assets/b94e34f5-3e88-4c5f-94ed-75f50dde9786">

* make float8 scaling type configurable (pytorch#489)

Summary:

Adds config options to configure float8 scaling type for input, weight,
grad_output.

Performance is not ideal yet, but that's because we have not optimized
it.

Test Plan:

```
// repeat for input, weight, grad_out
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.float8_scaling_type_weight delayed --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [PP] add flexible interleaved 1f1b schedule pytorch#490 (pytorch#493)

This was approved in pytorch#490, but
merged into the wrong branch, merging this into main

* move float8 callsites to torchao.float8 (pytorch#492)

Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [BE][1/n] simplify train.py

ghstack-source-id: 3879e764e7b33afde5d778810c71d1d2a8f82f6d
Pull Request resolved: pytorch#494

* [BE][2/n] use proper method signatures in parallelize_llama

ghstack-source-id: 17a1ee9f03f13423a30183c5c8d7ad30f8c8dbfc
Pull Request resolved: pytorch#495

* [BE][3/n] wrap fp8 logic using Float8Handler

ghstack-source-id: e94c7f6f4fad87c5432262c54beabd02de5541b8
Pull Request resolved: pytorch#496

* Bring LLaMa 3.1 405B to TorchTitan family (pytorch#481)

With the official launch of LLaMa 3.1 model, we want to add the config
to TorchTitan. Of course, there are more work to be done, but we want to
go an incremental way. So more PRs will be needed.

For now, we try on 128 GPUs with current config (TP=8, FSDP=16). The
perf number is wps: 109 mfu: 29%.

Loss curve for 3000 steps with 600 warmup (lr = 0.8e-4).
<img width="1037" alt="image"
src="https://github.com/user-attachments/assets/f57dd3fa-07d8-4ef4-8f68-8f7a08e9652e">


Loss curve for 3000 steps with 600 warmup (lr = 1.1e-4).

![image](https://github.com/user-attachments/assets/429b9738-94cb-4b37-90ef-049a5587ddd0)

* [TP] Infer local n_heads instead of ad-hoc model changes

ghstack-source-id: 587e3d6e5270714ca734b8031ce41a962e6394ea
Pull Request resolved: pytorch#498

* some compile-related updates

ghstack-source-id: 63af8025c184fd5ad34f2f57bf78a37dda2cd33d
Pull Request resolved: pytorch#443

* [EZ][405B] Use scientific notation for 405B model lr (pytorch#504)

As title, use `8e-5` rather than `0.8e-4`.

* [BE][4/n] split pipeline_llama into a separate file

ghstack-source-id: 5ebb4adf3152f413fa33a923c272c9aa3ce1f775
Pull Request resolved: pytorch#499

* [fix] float8 should be applied on all model_parts

ghstack-source-id: 52ed6836de39e82c4c5824a40ecfc1d9ec7ed2bd
Pull Request resolved: pytorch#500

* Add warning to compile rmsnorm (pytorch#505)

as titled, add warning to compile rmsnorm as it's not fully ready yet,
i.e. this issue pytorch#497

We can remove this warning once we fix the issue

* add float8 to README (pytorch#509)

add float8 link in README so we can redirect people from dev-discuss
post to torchtitan repo


README looks like this after rendering
<img width="518" alt="Screenshot 2024-08-06 at 5 42 10 PM"
src="https://github.com/user-attachments/assets/50af99d7-93be-459a-89d7-8c08b8fb95d4">

float8.md looks like this
<img width="563" alt="Screenshot 2024-08-06 at 5 04 17 PM"
src="https://github.com/user-attachments/assets/06d30aad-4133-4cec-9037-cfcf155b45c4">

I tried the command locally and traces are looking good
<img width="726" alt="Screenshot 2024-08-06 at 5 00 00 PM"
src="https://github.com/user-attachments/assets/bdfa3d7e-efe1-4009-92a1-0f5c310013fb">

* address TODOs as 2D recompiles is fixed

ghstack-source-id: 2927f0a8082171da3e9f59a5d04f8325cbdf3653
Pull Request resolved: pytorch#508

* [BE][5/n] simply pp vs. non-pp set up

ghstack-source-id: 003bfbfbcf1511ddbd18e15d031b39f597d8e7db
Pull Request resolved: pytorch#510

* [BE][6/n] replace large c4_mini datasets by c4_test with the first 2K entries

ghstack-source-id: 319f4961b092778703101b98937803073132afa1
Pull Request resolved: pytorch#512

* Create composability.md (pytorch#511)

Explain the rationale and challenges behind certain changes we made to
llama model to support 3D parallelism.

---------

Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>

* depend on torchdata 0.8.0 instead of nightly

ghstack-source-id: 1965d3122885fed3c28e2e058c55581187e7816c
Pull Request resolved: pytorch#513

---------

Co-authored-by: Andrew Gu <andgu@fb.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@meta.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
Co-authored-by: Will Constable <whc@meta.com>
Co-authored-by: Wei (Will) Feng <134637289+weifengpy@users.noreply.github.com>
Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Less Wright <lessw@etrillium.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@fb.com>
Co-authored-by: Hugo <6937752+fduwjj@users.noreply.github.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
Co-authored-by: Ke Wen <kw2501@meta.com>
Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Co-authored-by: Will Constable <willconstable@gmail.com>
tianyu-l added a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 13, 2024
* Set `record_shapes=True` for profiler

ghstack-source-id: 6f1ed49d15ce311f1bf118820965cdb5309a8030
Pull Request resolved: pytorch#419

* Improved `repeat_kv` eager perf

ghstack-source-id: 39e484954814e61cdfb2ba661f0a98c83bc0ce60
Pull Request resolved: pytorch#418

* Adding FSDP Memory Tracking and Estimation

ghstack-source-id: c8ed20fc585957bd164dd963307616a53991615d
Pull Request resolved: pytorch#425

* Adding integration test for FSDP Memory Tracking and Estimation

ghstack-source-id: cc224db8951ec7a133fd769845a4765cbedc6454
Pull Request resolved: pytorch#426

* by default disable heavy memory profiling

ghstack-source-id: cad7b3c41fd60ec19c0e6e7d058e8aa00602a187
Pull Request resolved: pytorch#430

* Add the option to turn on async-TP

ghstack-source-id: 0a03379eeb3a63b2d1ad4dff84d0e61ca82b1bbf
Pull Request resolved: pytorch#429

* Modifying memory estimation options and minor changes

ghstack-source-id: 5f09824cddaed6585cc094095e1e95dd070d76f4
Pull Request resolved: pytorch#435

* add comment pointing to Sequence Parallel optimization example

ghstack-source-id: 6fa0dcd4bca876e10a6a8349283fb940a59ad234
Pull Request resolved: pytorch#438

* switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)

Summary:

After pytorch-labs/float8_experimental#300,
`Float8Linear` with default settings is equivalent to
`Float8DynamicLinear`. This PR changes `torchtitan` to use
`Float8Linear`.

To support the new UX of `float8_experimental` better, I also switched
the `fp8_linear` configuration to be a boolean on whether to swap the
linears or not. In the future we can add new options on how to configure
each linear (scaling type, scaling granularity, etc) - saving that for a
future PR.

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fdsp_fp8_allgather=True
// 4. compile, float8, fdsp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`

ghstack-source-id: 50b2d0c2b4c22e2f045cafd8630c16f3a8c6d35f
Pull Request resolved: pytorch#444

* Reordered TP parallel plan to follow execution order

ghstack-source-id: b4924952adeb5f16d08b60faa54690762841c422
Pull Request resolved: pytorch#445

* Made some stylistic changes to `apply_dp`

ghstack-source-id: fb78e9eb8aa406ba87d6ad6cf2229c1027dae42f
Pull Request resolved: pytorch#446

* Refactored activation checkpointing

ghstack-source-id: 785c7e47651cda97ea22d0147d14b8d061ce042d
Pull Request resolved: pytorch#447

* compiled RMSNorm

ghstack-source-id: c4efb81ec6acc5442955908cc376df3e6d889af3
Pull Request resolved: pytorch#442

* Renamed parallel styles for transformer block weights

ghstack-source-id: 5fb0bf3d08cacf27242ec0f85d5dd3cdc03b739e
Pull Request resolved: pytorch#448

* Added type annotations and more stylistic changes

ghstack-source-id: 1bd5b9d5abc8644785132f8eb2baaf8b1cfc5fb5
Pull Request resolved: pytorch#449

* [Cleanup] Remove libuv from run_llama_train.sh

libuv is now enabled by default.

we can proably do without the educational blurb there, and don't need
the env either since the default has landed.

ghstack-source-id: 68c8d2abe7eb0777e2add8df7634367c31b7ec06
Pull Request resolved: pytorch#453

* [Cleanup] Organize run_llama_train.sh options

Just a little code motion but it looks cleaner to me this way

ghstack-source-id: 055fbd557cd9cf189e6b9bd6a7048f1204e1dc5c
Pull Request resolved: pytorch#454

* [Cleanup] Split run_llama_train.sh and run_memory_estimation.sh

Make each script simpler to read

ghstack-source-id: ba3aa65feb6e304736c73daf5bc8ab5fb254f196
Pull Request resolved: pytorch#455

* [Cleanup] Remove unused TRAINER_DIR

This argument seems to be left over from older times- it is not used
anywhere in the codebase.

ghstack-source-id: abbcf82ed4d1b8fbb71c6a6b48acbc1296dbec64
Pull Request resolved: pytorch#456

* Add educational code pointers to top level README

ghstack-source-id: 522aa2fa0bf1679f55d9f3a8a38fdcd319d5e3df
Pull Request resolved: pytorch#457

* enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (pytorch#413)

we have landed fp8 all-gather optimizations in float8_experimental
pytorch-labs/float8_experimental#266

this PR proposes torchtitan changes. also include fp8 in CI
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather are added to CI
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather are locally tested. will add them to CI after
uploading a new tokenizer with vacab size 2560 (divisible by 16)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

precompute scales after optimizer.step
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather do not have any small all-reduces
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload tokenizer with vacab size 2560 to enable CI on TP fp8
all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about best config option to express
fp8
* compare perf between delayed scaling and dynamic scaling
https://github.com/pytorch-labs/float8_experimental/pull/312/files

* import float8_experimental only when fp8 is enabled and install it in CI (pytorch#464)

make sure to only import float8_experimental when fp8 is enabled

for 4 gpu CI, make sure we can import float8_experimental correctly in
CI

`python -m pip install
git+https://github.com/pytorch-labs/float8_experimental.git`

* skip fp8 CI on non-H100 GPUs (pytorch#465)

skip fp8 tests on non-H100 GPUs by checking
`torch.cuda.get_device_capability() >= (9, 0)`

this makes 4 GPU CI healthy again

* clean up float8 configs in torchtitan (pytorch#466)

Summary:

1. standardizes on `float8` instead of `fp8` for config names
2. removes usage of non-public objects such as `Float8Linear`

Test Plan:

```
with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Add support of DDP and experimental CompiledAutograd

Summary:
Address the comments in pytorch#319 and resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600   --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: 81dc85d42df13df4ed727bebd825681879af936b
Pull Request resolved: pytorch#432

* add torch.compile + FSDP2 float8 all-gather in CI (pytorch#468)

fixed my bug in float8_experimental. now we can torch.compile
transfromer blocks with FSDP float8 all-gather
pytorch-labs/float8_experimental#321

local test: `CONFIG_FILE="./train_configs/debug_model.toml"
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

profiler traces: I can see compiled region in cpu thread and float8
malmul `sm90_xmma_gemm_e4m3bf16...` in cuda stream
<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM"
src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">

* [float8] keep model.output as `nn.Linear` (high precision, not fp8) (pytorch#469)

**keep model.output as nn.Linear**: it's a common practice to NOT apply
fp8 on final output layer
* specify `skip_fqn_list` in swapping
* when applying TP to model.output, use plain `ColwiseParallel` instead
of `Float8ColwiseParallel`

credit to @awgu, we do not need tokentizer vacab size to be divisible by
16 pytorch#461

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4
--training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: see float8 and all-gather in the trace
<img width="1611" alt="Screenshot 2024-07-19 at 1 16 59 PM"
src="https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472">

2D + float8 all-gather trace: see float8 and FSDP collectives and TP
collectives
<img width="1038" alt="Screenshot 2024-07-19 at 1 29 59 PM"
src="https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7">

* remove CI for FSDP2 + fp8 all-gather (pytorch#470)

per discussion from
pytorch#469 (comment)

we are planning BC breaking changes in float8_experimental. remove CI
for FSDP2 + fp8 all-gather for now. When public APIs are finalized, we
can discuss bringing it back

* dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (pytorch#471)

This PR adds some enhancements for supporting async tp:

1 - if async tp is active, auto updates the torch.dynamo cache limit to
10K. If this is not updated, async tp will not be activated on larger
models as it will quietly stop compilation due to 'cache limit reached'
with no info for the user.
This config update is logged. 

2 - if async tp is enabled, verifies that torch.compile is set to true
for this job config. If not, it warns and then activates torch.compile
to ensure user gets working async tp. (see WARNING in below screenshot)

<img width="1345" alt="Screenshot 2024-07-20 at 4 33 04 PM"
src="https://github.com/user-attachments/assets/26e5a48e-4bb8-4f33-b1b5-8939c1517c1d">

3 - Updates the 'Applied Tensor Parallel' to the model to be 'Applied
Async Tensor Parallel' when async tp is active to make it clear in the
logs which TP is active. (see above screenshot)

* Fix 8gpu PP failure due to 2D DCP disablement

DCP recently added safeties to avoid using it for 2D/3D since strided
sharding (a feature needed for safe 2D/3D resharding) is not ready yet.

PP uses DCP to load a seed checkpoint.  Disabling the safety mechanism
is enough to make 3D/PP still work (for the case where we train from the
beginning or do not re-shard.

(Resharding refers to saving a checkpoint from one world
size/parallelism config and loading/resuming under a different one).

ghstack-source-id: c069d2186c79517c72f5b3c99485cebdc15df08f
Pull Request resolved: pytorch#460

* update float8 integration after UX changes (pytorch#484)

Summary:

float8_experimental landed various BC-breaking UX changes last week.
This PR updates torchtitan to work with the version of
float8_experimental after
pytorch-labs/float8_experimental#332 and
pytorch-labs/float8_experimental#337

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Re-enable FSDP2 Mem Tracker integration tests

ghstack-source-id: 8344603f7a5596cb2909c9bf04dd1b9e4730c9b8
Pull Request resolved: pytorch#485

* Used `partial` instead of global vars for LR scheduling

ghstack-source-id: 12c4418b0574d93e1441f4ca3d1de79c8aad7a40
Pull Request resolved: pytorch#487

* [EZ] Add logs for some basic training params so that we can verify in… (pytorch#491)

As title, while testing on 405B model, I found that we need to somehow
need the logs for some training params. So added some here. Tested
locally and the logging is shown as in the screenshot:


<img width="900" alt="image"
src="https://github.com/user-attachments/assets/b94e34f5-3e88-4c5f-94ed-75f50dde9786">

* make float8 scaling type configurable (pytorch#489)

Summary:

Adds config options to configure float8 scaling type for input, weight,
grad_output.

Performance is not ideal yet, but that's because we have not optimized
it.

Test Plan:

```
// repeat for input, weight, grad_out
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.float8_scaling_type_weight delayed --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [PP] add flexible interleaved 1f1b schedule pytorch#490 (pytorch#493)

This was approved in pytorch#490, but
merged into the wrong branch, merging this into main

* move float8 callsites to torchao.float8 (pytorch#492)

Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [BE][1/n] simplify train.py

ghstack-source-id: 3879e764e7b33afde5d778810c71d1d2a8f82f6d
Pull Request resolved: pytorch#494

* [BE][2/n] use proper method signatures in parallelize_llama

ghstack-source-id: 17a1ee9f03f13423a30183c5c8d7ad30f8c8dbfc
Pull Request resolved: pytorch#495

* [BE][3/n] wrap fp8 logic using Float8Handler

ghstack-source-id: e94c7f6f4fad87c5432262c54beabd02de5541b8
Pull Request resolved: pytorch#496

* Bring LLaMa 3.1 405B to TorchTitan family (pytorch#481)

With the official launch of LLaMa 3.1 model, we want to add the config
to TorchTitan. Of course, there are more work to be done, but we want to
go an incremental way. So more PRs will be needed.

For now, we try on 128 GPUs with current config (TP=8, FSDP=16). The
perf number is wps: 109 mfu: 29%.

Loss curve for 3000 steps with 600 warmup (lr = 0.8e-4).
<img width="1037" alt="image"
src="https://github.com/user-attachments/assets/f57dd3fa-07d8-4ef4-8f68-8f7a08e9652e">


Loss curve for 3000 steps with 600 warmup (lr = 1.1e-4).

![image](https://github.com/user-attachments/assets/429b9738-94cb-4b37-90ef-049a5587ddd0)

* [TP] Infer local n_heads instead of ad-hoc model changes

ghstack-source-id: 587e3d6e5270714ca734b8031ce41a962e6394ea
Pull Request resolved: pytorch#498

* some compile-related updates

ghstack-source-id: 63af8025c184fd5ad34f2f57bf78a37dda2cd33d
Pull Request resolved: pytorch#443

* [EZ][405B] Use scientific notation for 405B model lr (pytorch#504)

As title, use `8e-5` rather than `0.8e-4`.

* [BE][4/n] split pipeline_llama into a separate file

ghstack-source-id: 5ebb4adf3152f413fa33a923c272c9aa3ce1f775
Pull Request resolved: pytorch#499

* [fix] float8 should be applied on all model_parts

ghstack-source-id: 52ed6836de39e82c4c5824a40ecfc1d9ec7ed2bd
Pull Request resolved: pytorch#500

* Add warning to compile rmsnorm (pytorch#505)

as titled, add warning to compile rmsnorm as it's not fully ready yet,
i.e. this issue pytorch#497

We can remove this warning once we fix the issue

* add float8 to README (pytorch#509)

add float8 link in README so we can redirect people from dev-discuss
post to torchtitan repo


README looks like this after rendering
<img width="518" alt="Screenshot 2024-08-06 at 5 42 10 PM"
src="https://github.com/user-attachments/assets/50af99d7-93be-459a-89d7-8c08b8fb95d4">

float8.md looks like this
<img width="563" alt="Screenshot 2024-08-06 at 5 04 17 PM"
src="https://github.com/user-attachments/assets/06d30aad-4133-4cec-9037-cfcf155b45c4">

I tried the command locally and traces are looking good
<img width="726" alt="Screenshot 2024-08-06 at 5 00 00 PM"
src="https://github.com/user-attachments/assets/bdfa3d7e-efe1-4009-92a1-0f5c310013fb">

* address TODOs as 2D recompiles is fixed

ghstack-source-id: 2927f0a8082171da3e9f59a5d04f8325cbdf3653
Pull Request resolved: pytorch#508

* [BE][5/n] simply pp vs. non-pp set up

ghstack-source-id: 003bfbfbcf1511ddbd18e15d031b39f597d8e7db
Pull Request resolved: pytorch#510

* [BE][6/n] replace large c4_mini datasets by c4_test with the first 2K entries

ghstack-source-id: 319f4961b092778703101b98937803073132afa1
Pull Request resolved: pytorch#512

* Create composability.md (pytorch#511)

Explain the rationale and challenges behind certain changes we made to
llama model to support 3D parallelism.

---------

Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>

* depend on torchdata 0.8.0 instead of nightly

ghstack-source-id: 1965d3122885fed3c28e2e058c55581187e7816c
Pull Request resolved: pytorch#513

* add support for torchbench

---------

Co-authored-by: Andrew Gu <andgu@fb.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@meta.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
Co-authored-by: Will Constable <whc@meta.com>
Co-authored-by: Wei (Will) Feng <134637289+weifengpy@users.noreply.github.com>
Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Less Wright <lessw@etrillium.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@fb.com>
Co-authored-by: Hugo <6937752+fduwjj@users.noreply.github.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
Co-authored-by: Ke Wen <kw2501@meta.com>
Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Co-authored-by: Will Constable <willconstable@gmail.com>
tianyu-l added a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
* Set `record_shapes=True` for profiler

ghstack-source-id: 6f1ed49d15ce311f1bf118820965cdb5309a8030
Pull Request resolved: pytorch#419

* Improved `repeat_kv` eager perf

ghstack-source-id: 39e484954814e61cdfb2ba661f0a98c83bc0ce60
Pull Request resolved: pytorch#418

* Adding FSDP Memory Tracking and Estimation

ghstack-source-id: c8ed20fc585957bd164dd963307616a53991615d
Pull Request resolved: pytorch#425

* Adding integration test for FSDP Memory Tracking and Estimation

ghstack-source-id: cc224db8951ec7a133fd769845a4765cbedc6454
Pull Request resolved: pytorch#426

* by default disable heavy memory profiling

ghstack-source-id: cad7b3c41fd60ec19c0e6e7d058e8aa00602a187
Pull Request resolved: pytorch#430

* Add the option to turn on async-TP

ghstack-source-id: 0a03379eeb3a63b2d1ad4dff84d0e61ca82b1bbf
Pull Request resolved: pytorch#429

* Modifying memory estimation options and minor changes

ghstack-source-id: 5f09824cddaed6585cc094095e1e95dd070d76f4
Pull Request resolved: pytorch#435

* add comment pointing to Sequence Parallel optimization example

ghstack-source-id: 6fa0dcd4bca876e10a6a8349283fb940a59ad234
Pull Request resolved: pytorch#438

* switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)

Summary:

After pytorch-labs/float8_experimental#300,
`Float8Linear` with default settings is equivalent to
`Float8DynamicLinear`. This PR changes `torchtitan` to use
`Float8Linear`.

To support the new UX of `float8_experimental` better, I also switched
the `fp8_linear` configuration to be a boolean on whether to swap the
linears or not. In the future we can add new options on how to configure
each linear (scaling type, scaling granularity, etc) - saving that for a
future PR.

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fdsp_fp8_allgather=True
// 4. compile, float8, fdsp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`

ghstack-source-id: 50b2d0c2b4c22e2f045cafd8630c16f3a8c6d35f
Pull Request resolved: pytorch#444

* Reordered TP parallel plan to follow execution order

ghstack-source-id: b4924952adeb5f16d08b60faa54690762841c422
Pull Request resolved: pytorch#445

* Made some stylistic changes to `apply_dp`

ghstack-source-id: fb78e9eb8aa406ba87d6ad6cf2229c1027dae42f
Pull Request resolved: pytorch#446

* Refactored activation checkpointing

ghstack-source-id: 785c7e47651cda97ea22d0147d14b8d061ce042d
Pull Request resolved: pytorch#447

* compiled RMSNorm

ghstack-source-id: c4efb81ec6acc5442955908cc376df3e6d889af3
Pull Request resolved: pytorch#442

* Renamed parallel styles for transformer block weights

ghstack-source-id: 5fb0bf3d08cacf27242ec0f85d5dd3cdc03b739e
Pull Request resolved: pytorch#448

* Added type annotations and more stylistic changes

ghstack-source-id: 1bd5b9d5abc8644785132f8eb2baaf8b1cfc5fb5
Pull Request resolved: pytorch#449

* [Cleanup] Remove libuv from run_llama_train.sh

libuv is now enabled by default.

we can proably do without the educational blurb there, and don't need
the env either since the default has landed.

ghstack-source-id: 68c8d2abe7eb0777e2add8df7634367c31b7ec06
Pull Request resolved: pytorch#453

* [Cleanup] Organize run_llama_train.sh options

Just a little code motion but it looks cleaner to me this way

ghstack-source-id: 055fbd557cd9cf189e6b9bd6a7048f1204e1dc5c
Pull Request resolved: pytorch#454

* [Cleanup] Split run_llama_train.sh and run_memory_estimation.sh

Make each script simpler to read

ghstack-source-id: ba3aa65feb6e304736c73daf5bc8ab5fb254f196
Pull Request resolved: pytorch#455

* [Cleanup] Remove unused TRAINER_DIR

This argument seems to be left over from older times- it is not used
anywhere in the codebase.

ghstack-source-id: abbcf82ed4d1b8fbb71c6a6b48acbc1296dbec64
Pull Request resolved: pytorch#456

* Add educational code pointers to top level README

ghstack-source-id: 522aa2fa0bf1679f55d9f3a8a38fdcd319d5e3df
Pull Request resolved: pytorch#457

* enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (pytorch#413)

we have landed fp8 all-gather optimizations in float8_experimental
pytorch-labs/float8_experimental#266

this PR proposes torchtitan changes. also include fp8 in CI
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather are added to CI
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather are locally tested. will add them to CI after
uploading a new tokenizer with vacab size 2560 (divisible by 16)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

precompute scales after optimizer.step
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather do not have any small all-reduces
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload tokenizer with vacab size 2560 to enable CI on TP fp8
all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about best config option to express
fp8
* compare perf between delayed scaling and dynamic scaling
https://github.com/pytorch-labs/float8_experimental/pull/312/files

* import float8_experimental only when fp8 is enabled and install it in CI (pytorch#464)

make sure to only import float8_experimental when fp8 is enabled

for 4 gpu CI, make sure we can import float8_experimental correctly in
CI

`python -m pip install
git+https://github.com/pytorch-labs/float8_experimental.git`

* skip fp8 CI on non-H100 GPUs (pytorch#465)

skip fp8 tests on non-H100 GPUs by checking
`torch.cuda.get_device_capability() >= (9, 0)`

this makes 4 GPU CI healthy again

* clean up float8 configs in torchtitan (pytorch#466)

Summary:

1. standardizes on `float8` instead of `fp8` for config names
2. removes usage of non-public objects such as `Float8Linear`

Test Plan:

```
with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Add support of DDP and experimental CompiledAutograd

Summary:
Address the comments in pytorch#319 and resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600   --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: 81dc85d42df13df4ed727bebd825681879af936b
Pull Request resolved: pytorch#432

* add torch.compile + FSDP2 float8 all-gather in CI (pytorch#468)

fixed my bug in float8_experimental. now we can torch.compile
transfromer blocks with FSDP float8 all-gather
pytorch-labs/float8_experimental#321

local test: `CONFIG_FILE="./train_configs/debug_model.toml"
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

profiler traces: I can see compiled region in cpu thread and float8
malmul `sm90_xmma_gemm_e4m3bf16...` in cuda stream
<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM"
src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">

* [float8] keep model.output as `nn.Linear` (high precision, not fp8) (pytorch#469)

**keep model.output as nn.Linear**: it's a common practice to NOT apply
fp8 on final output layer
* specify `skip_fqn_list` in swapping
* when applying TP to model.output, use plain `ColwiseParallel` instead
of `Float8ColwiseParallel`

credit to @awgu, we do not need tokentizer vacab size to be divisible by
16 pytorch#461

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4
--training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: see float8 and all-gather in the trace
<img width="1611" alt="Screenshot 2024-07-19 at 1 16 59 PM"
src="https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472">

2D + float8 all-gather trace: see float8 and FSDP collectives and TP
collectives
<img width="1038" alt="Screenshot 2024-07-19 at 1 29 59 PM"
src="https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7">

* remove CI for FSDP2 + fp8 all-gather (pytorch#470)

per discussion from
pytorch#469 (comment)

we are planning BC breaking changes in float8_experimental. remove CI
for FSDP2 + fp8 all-gather for now. When public APIs are finalized, we
can discuss bringing it back

* dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (pytorch#471)

This PR adds some enhancements for supporting async tp:

1 - if async tp is active, auto updates the torch.dynamo cache limit to
10K. If this is not updated, async tp will not be activated on larger
models as it will quietly stop compilation due to 'cache limit reached'
with no info for the user.
This config update is logged. 

2 - if async tp is enabled, verifies that torch.compile is set to true
for this job config. If not, it warns and then activates torch.compile
to ensure user gets working async tp. (see WARNING in below screenshot)

<img width="1345" alt="Screenshot 2024-07-20 at 4 33 04 PM"
src="https://github.com/user-attachments/assets/26e5a48e-4bb8-4f33-b1b5-8939c1517c1d">

3 - Updates the 'Applied Tensor Parallel' to the model to be 'Applied
Async Tensor Parallel' when async tp is active to make it clear in the
logs which TP is active. (see above screenshot)

* Fix 8gpu PP failure due to 2D DCP disablement

DCP recently added safeties to avoid using it for 2D/3D since strided
sharding (a feature needed for safe 2D/3D resharding) is not ready yet.

PP uses DCP to load a seed checkpoint.  Disabling the safety mechanism
is enough to make 3D/PP still work (for the case where we train from the
beginning or do not re-shard.

(Resharding refers to saving a checkpoint from one world
size/parallelism config and loading/resuming under a different one).

ghstack-source-id: c069d2186c79517c72f5b3c99485cebdc15df08f
Pull Request resolved: pytorch#460

* update float8 integration after UX changes (pytorch#484)

Summary:

float8_experimental landed various BC-breaking UX changes last week.
This PR updates torchtitan to work with the version of
float8_experimental after
pytorch-labs/float8_experimental#332 and
pytorch-labs/float8_experimental#337

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Re-enable FSDP2 Mem Tracker integration tests

ghstack-source-id: 8344603f7a5596cb2909c9bf04dd1b9e4730c9b8
Pull Request resolved: pytorch#485

* Used `partial` instead of global vars for LR scheduling

ghstack-source-id: 12c4418b0574d93e1441f4ca3d1de79c8aad7a40
Pull Request resolved: pytorch#487

* [EZ] Add logs for some basic training params so that we can verify in… (pytorch#491)

As title, while testing on 405B model, I found that we need to somehow
need the logs for some training params. So added some here. Tested
locally and the logging is shown as in the screenshot:


<img width="900" alt="image"
src="https://github.com/user-attachments/assets/b94e34f5-3e88-4c5f-94ed-75f50dde9786">

* make float8 scaling type configurable (pytorch#489)

Summary:

Adds config options to configure float8 scaling type for input, weight,
grad_output.

Performance is not ideal yet, but that's because we have not optimized
it.

Test Plan:

```
// repeat for input, weight, grad_out
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.float8_scaling_type_weight delayed --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [PP] add flexible interleaved 1f1b schedule pytorch#490 (pytorch#493)

This was approved in pytorch#490, but
merged into the wrong branch, merging this into main

* move float8 callsites to torchao.float8 (pytorch#492)

Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [BE][1/n] simplify train.py

ghstack-source-id: 3879e764e7b33afde5d778810c71d1d2a8f82f6d
Pull Request resolved: pytorch#494

* [BE][2/n] use proper method signatures in parallelize_llama

ghstack-source-id: 17a1ee9f03f13423a30183c5c8d7ad30f8c8dbfc
Pull Request resolved: pytorch#495

* [BE][3/n] wrap fp8 logic using Float8Handler

ghstack-source-id: e94c7f6f4fad87c5432262c54beabd02de5541b8
Pull Request resolved: pytorch#496

* Bring LLaMa 3.1 405B to TorchTitan family (pytorch#481)

With the official launch of LLaMa 3.1 model, we want to add the config
to TorchTitan. Of course, there are more work to be done, but we want to
go an incremental way. So more PRs will be needed.

For now, we try on 128 GPUs with current config (TP=8, FSDP=16). The
perf number is wps: 109 mfu: 29%.

Loss curve for 3000 steps with 600 warmup (lr = 0.8e-4).
<img width="1037" alt="image"
src="https://github.com/user-attachments/assets/f57dd3fa-07d8-4ef4-8f68-8f7a08e9652e">


Loss curve for 3000 steps with 600 warmup (lr = 1.1e-4).

![image](https://github.com/user-attachments/assets/429b9738-94cb-4b37-90ef-049a5587ddd0)

* [TP] Infer local n_heads instead of ad-hoc model changes

ghstack-source-id: 587e3d6e5270714ca734b8031ce41a962e6394ea
Pull Request resolved: pytorch#498

* some compile-related updates

ghstack-source-id: 63af8025c184fd5ad34f2f57bf78a37dda2cd33d
Pull Request resolved: pytorch#443

* [EZ][405B] Use scientific notation for 405B model lr (pytorch#504)

As title, use `8e-5` rather than `0.8e-4`.

* [BE][4/n] split pipeline_llama into a separate file

ghstack-source-id: 5ebb4adf3152f413fa33a923c272c9aa3ce1f775
Pull Request resolved: pytorch#499

* [fix] float8 should be applied on all model_parts

ghstack-source-id: 52ed6836de39e82c4c5824a40ecfc1d9ec7ed2bd
Pull Request resolved: pytorch#500

* Add warning to compile rmsnorm (pytorch#505)

as titled, add warning to compile rmsnorm as it's not fully ready yet,
i.e. this issue pytorch#497

We can remove this warning once we fix the issue

* add float8 to README (pytorch#509)

add float8 link in README so we can redirect people from dev-discuss
post to torchtitan repo


README looks like this after rendering
<img width="518" alt="Screenshot 2024-08-06 at 5 42 10 PM"
src="https://github.com/user-attachments/assets/50af99d7-93be-459a-89d7-8c08b8fb95d4">

float8.md looks like this
<img width="563" alt="Screenshot 2024-08-06 at 5 04 17 PM"
src="https://github.com/user-attachments/assets/06d30aad-4133-4cec-9037-cfcf155b45c4">

I tried the command locally and traces are looking good
<img width="726" alt="Screenshot 2024-08-06 at 5 00 00 PM"
src="https://github.com/user-attachments/assets/bdfa3d7e-efe1-4009-92a1-0f5c310013fb">

* address TODOs as 2D recompiles is fixed

ghstack-source-id: 2927f0a8082171da3e9f59a5d04f8325cbdf3653
Pull Request resolved: pytorch#508

* [BE][5/n] simply pp vs. non-pp set up

ghstack-source-id: 003bfbfbcf1511ddbd18e15d031b39f597d8e7db
Pull Request resolved: pytorch#510

* [BE][6/n] replace large c4_mini datasets by c4_test with the first 2K entries

ghstack-source-id: 319f4961b092778703101b98937803073132afa1
Pull Request resolved: pytorch#512

* Create composability.md (pytorch#511)

Explain the rationale and challenges behind certain changes we made to
llama model to support 3D parallelism.

---------

Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>

* depend on torchdata 0.8.0 instead of nightly

ghstack-source-id: 1965d3122885fed3c28e2e058c55581187e7816c
Pull Request resolved: pytorch#513

---------

Co-authored-by: Andrew Gu <andgu@fb.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@meta.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
Co-authored-by: Will Constable <whc@meta.com>
Co-authored-by: Wei (Will) Feng <134637289+weifengpy@users.noreply.github.com>
Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Less Wright <lessw@etrillium.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@fb.com>
Co-authored-by: Hugo <6937752+fduwjj@users.noreply.github.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
Co-authored-by: Ke Wen <kw2501@meta.com>
Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Co-authored-by: Will Constable <willconstable@gmail.com>
tianyu-l added a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
* Set `record_shapes=True` for profiler

ghstack-source-id: 6f1ed49d15ce311f1bf118820965cdb5309a8030
Pull Request resolved: pytorch#419

* Improved `repeat_kv` eager perf

ghstack-source-id: 39e484954814e61cdfb2ba661f0a98c83bc0ce60
Pull Request resolved: pytorch#418

* Adding FSDP Memory Tracking and Estimation

ghstack-source-id: c8ed20fc585957bd164dd963307616a53991615d
Pull Request resolved: pytorch#425

* Adding integration test for FSDP Memory Tracking and Estimation

ghstack-source-id: cc224db8951ec7a133fd769845a4765cbedc6454
Pull Request resolved: pytorch#426

* by default disable heavy memory profiling

ghstack-source-id: cad7b3c41fd60ec19c0e6e7d058e8aa00602a187
Pull Request resolved: pytorch#430

* Add the option to turn on async-TP

ghstack-source-id: 0a03379eeb3a63b2d1ad4dff84d0e61ca82b1bbf
Pull Request resolved: pytorch#429

* Modifying memory estimation options and minor changes

ghstack-source-id: 5f09824cddaed6585cc094095e1e95dd070d76f4
Pull Request resolved: pytorch#435

* add comment pointing to Sequence Parallel optimization example

ghstack-source-id: 6fa0dcd4bca876e10a6a8349283fb940a59ad234
Pull Request resolved: pytorch#438

* switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)

Summary:

After pytorch-labs/float8_experimental#300,
`Float8Linear` with default settings is equivalent to
`Float8DynamicLinear`. This PR changes `torchtitan` to use
`Float8Linear`.

To support the new UX of `float8_experimental` better, I also switched
the `fp8_linear` configuration to be a boolean on whether to swap the
linears or not. In the future we can add new options on how to configure
each linear (scaling type, scaling granularity, etc) - saving that for a
future PR.

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fdsp_fp8_allgather=True
// 4. compile, float8, fdsp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`

ghstack-source-id: 50b2d0c2b4c22e2f045cafd8630c16f3a8c6d35f
Pull Request resolved: pytorch#444

* Reordered TP parallel plan to follow execution order

ghstack-source-id: b4924952adeb5f16d08b60faa54690762841c422
Pull Request resolved: pytorch#445

* Made some stylistic changes to `apply_dp`

ghstack-source-id: fb78e9eb8aa406ba87d6ad6cf2229c1027dae42f
Pull Request resolved: pytorch#446

* Refactored activation checkpointing

ghstack-source-id: 785c7e47651cda97ea22d0147d14b8d061ce042d
Pull Request resolved: pytorch#447

* compiled RMSNorm

ghstack-source-id: c4efb81ec6acc5442955908cc376df3e6d889af3
Pull Request resolved: pytorch#442

* Renamed parallel styles for transformer block weights

ghstack-source-id: 5fb0bf3d08cacf27242ec0f85d5dd3cdc03b739e
Pull Request resolved: pytorch#448

* Added type annotations and more stylistic changes

ghstack-source-id: 1bd5b9d5abc8644785132f8eb2baaf8b1cfc5fb5
Pull Request resolved: pytorch#449

* [Cleanup] Remove libuv from run_llama_train.sh

libuv is now enabled by default.

we can proably do without the educational blurb there, and don't need
the env either since the default has landed.

ghstack-source-id: 68c8d2abe7eb0777e2add8df7634367c31b7ec06
Pull Request resolved: pytorch#453

* [Cleanup] Organize run_llama_train.sh options

Just a little code motion but it looks cleaner to me this way

ghstack-source-id: 055fbd557cd9cf189e6b9bd6a7048f1204e1dc5c
Pull Request resolved: pytorch#454

* [Cleanup] Split run_llama_train.sh and run_memory_estimation.sh

Make each script simpler to read

ghstack-source-id: ba3aa65feb6e304736c73daf5bc8ab5fb254f196
Pull Request resolved: pytorch#455

* [Cleanup] Remove unused TRAINER_DIR

This argument seems to be left over from older times- it is not used
anywhere in the codebase.

ghstack-source-id: abbcf82ed4d1b8fbb71c6a6b48acbc1296dbec64
Pull Request resolved: pytorch#456

* Add educational code pointers to top level README

ghstack-source-id: 522aa2fa0bf1679f55d9f3a8a38fdcd319d5e3df
Pull Request resolved: pytorch#457

* enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (pytorch#413)

we have landed fp8 all-gather optimizations in float8_experimental
pytorch-labs/float8_experimental#266

this PR proposes torchtitan changes. also include fp8 in CI
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather are added to CI
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather are locally tested. will add them to CI after
uploading a new tokenizer with vacab size 2560 (divisible by 16)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

precompute scales after optimizer.step
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather do not have any small all-reduces
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload tokenizer with vacab size 2560 to enable CI on TP fp8
all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about best config option to express
fp8
* compare perf between delayed scaling and dynamic scaling
https://github.com/pytorch-labs/float8_experimental/pull/312/files

* import float8_experimental only when fp8 is enabled and install it in CI (pytorch#464)

make sure to only import float8_experimental when fp8 is enabled

for 4 gpu CI, make sure we can import float8_experimental correctly in
CI

`python -m pip install
git+https://github.com/pytorch-labs/float8_experimental.git`

* skip fp8 CI on non-H100 GPUs (pytorch#465)

skip fp8 tests on non-H100 GPUs by checking
`torch.cuda.get_device_capability() >= (9, 0)`

this makes 4 GPU CI healthy again

* clean up float8 configs in torchtitan (pytorch#466)

Summary:

1. standardizes on `float8` instead of `fp8` for config names
2. removes usage of non-public objects such as `Float8Linear`

Test Plan:

```
with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Add support of DDP and experimental CompiledAutograd

Summary:
Address the comments in pytorch#319 and resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600   --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: 81dc85d42df13df4ed727bebd825681879af936b
Pull Request resolved: pytorch#432

* add torch.compile + FSDP2 float8 all-gather in CI (pytorch#468)

fixed my bug in float8_experimental. now we can torch.compile
transfromer blocks with FSDP float8 all-gather
pytorch-labs/float8_experimental#321

local test: `CONFIG_FILE="./train_configs/debug_model.toml"
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

profiler traces: I can see compiled region in cpu thread and float8
malmul `sm90_xmma_gemm_e4m3bf16...` in cuda stream
<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM"
src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">

* [float8] keep model.output as `nn.Linear` (high precision, not fp8) (pytorch#469)

**keep model.output as nn.Linear**: it's a common practice to NOT apply
fp8 on final output layer
* specify `skip_fqn_list` in swapping
* when applying TP to model.output, use plain `ColwiseParallel` instead
of `Float8ColwiseParallel`

credit to @awgu, we do not need tokentizer vacab size to be divisible by
16 pytorch#461

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4
--training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: see float8 and all-gather in the trace
<img width="1611" alt="Screenshot 2024-07-19 at 1 16 59 PM"
src="https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472">

2D + float8 all-gather trace: see float8 and FSDP collectives and TP
collectives
<img width="1038" alt="Screenshot 2024-07-19 at 1 29 59 PM"
src="https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7">

* remove CI for FSDP2 + fp8 all-gather (pytorch#470)

per discussion from
pytorch#469 (comment)

we are planning BC breaking changes in float8_experimental. remove CI
for FSDP2 + fp8 all-gather for now. When public APIs are finalized, we
can discuss bringing it back

* dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (pytorch#471)

This PR adds some enhancements for supporting async tp:

1 - if async tp is active, auto updates the torch.dynamo cache limit to
10K. If this is not updated, async tp will not be activated on larger
models as it will quietly stop compilation due to 'cache limit reached'
with no info for the user.
This config update is logged. 

2 - if async tp is enabled, verifies that torch.compile is set to true
for this job config. If not, it warns and then activates torch.compile
to ensure user gets working async tp. (see WARNING in below screenshot)

<img width="1345" alt="Screenshot 2024-07-20 at 4 33 04 PM"
src="https://github.com/user-attachments/assets/26e5a48e-4bb8-4f33-b1b5-8939c1517c1d">

3 - Updates the 'Applied Tensor Parallel' to the model to be 'Applied
Async Tensor Parallel' when async tp is active to make it clear in the
logs which TP is active. (see above screenshot)

* Fix 8gpu PP failure due to 2D DCP disablement

DCP recently added safeties to avoid using it for 2D/3D since strided
sharding (a feature needed for safe 2D/3D resharding) is not ready yet.

PP uses DCP to load a seed checkpoint.  Disabling the safety mechanism
is enough to make 3D/PP still work (for the case where we train from the
beginning or do not re-shard.

(Resharding refers to saving a checkpoint from one world
size/parallelism config and loading/resuming under a different one).

ghstack-source-id: c069d2186c79517c72f5b3c99485cebdc15df08f
Pull Request resolved: pytorch#460

* update float8 integration after UX changes (pytorch#484)

Summary:

float8_experimental landed various BC-breaking UX changes last week.
This PR updates torchtitan to work with the version of
float8_experimental after
pytorch-labs/float8_experimental#332 and
pytorch-labs/float8_experimental#337

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Re-enable FSDP2 Mem Tracker integration tests

ghstack-source-id: 8344603f7a5596cb2909c9bf04dd1b9e4730c9b8
Pull Request resolved: pytorch#485

* Used `partial` instead of global vars for LR scheduling

ghstack-source-id: 12c4418b0574d93e1441f4ca3d1de79c8aad7a40
Pull Request resolved: pytorch#487

* [EZ] Add logs for some basic training params so that we can verify in… (pytorch#491)

As title, while testing on 405B model, I found that we need to somehow
need the logs for some training params. So added some here. Tested
locally and the logging is shown as in the screenshot:


<img width="900" alt="image"
src="https://github.com/user-attachments/assets/b94e34f5-3e88-4c5f-94ed-75f50dde9786">

* make float8 scaling type configurable (pytorch#489)

Summary:

Adds config options to configure float8 scaling type for input, weight,
grad_output.

Performance is not ideal yet, but that's because we have not optimized
it.

Test Plan:

```
// repeat for input, weight, grad_out
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.float8_scaling_type_weight delayed --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [PP] add flexible interleaved 1f1b schedule pytorch#490 (pytorch#493)

This was approved in pytorch#490, but
merged into the wrong branch, merging this into main

* move float8 callsites to torchao.float8 (pytorch#492)

Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [BE][1/n] simplify train.py

ghstack-source-id: 3879e764e7b33afde5d778810c71d1d2a8f82f6d
Pull Request resolved: pytorch#494

* [BE][2/n] use proper method signatures in parallelize_llama

ghstack-source-id: 17a1ee9f03f13423a30183c5c8d7ad30f8c8dbfc
Pull Request resolved: pytorch#495

* [BE][3/n] wrap fp8 logic using Float8Handler

ghstack-source-id: e94c7f6f4fad87c5432262c54beabd02de5541b8
Pull Request resolved: pytorch#496

* Bring LLaMa 3.1 405B to TorchTitan family (pytorch#481)

With the official launch of LLaMa 3.1 model, we want to add the config
to TorchTitan. Of course, there are more work to be done, but we want to
go an incremental way. So more PRs will be needed.

For now, we try on 128 GPUs with current config (TP=8, FSDP=16). The
perf number is wps: 109 mfu: 29%.

Loss curve for 3000 steps with 600 warmup (lr = 0.8e-4).
<img width="1037" alt="image"
src="https://github.com/user-attachments/assets/f57dd3fa-07d8-4ef4-8f68-8f7a08e9652e">


Loss curve for 3000 steps with 600 warmup (lr = 1.1e-4).

![image](https://github.com/user-attachments/assets/429b9738-94cb-4b37-90ef-049a5587ddd0)

* [TP] Infer local n_heads instead of ad-hoc model changes

ghstack-source-id: 587e3d6e5270714ca734b8031ce41a962e6394ea
Pull Request resolved: pytorch#498

* some compile-related updates

ghstack-source-id: 63af8025c184fd5ad34f2f57bf78a37dda2cd33d
Pull Request resolved: pytorch#443

* [EZ][405B] Use scientific notation for 405B model lr (pytorch#504)

As title, use `8e-5` rather than `0.8e-4`.

* [BE][4/n] split pipeline_llama into a separate file

ghstack-source-id: 5ebb4adf3152f413fa33a923c272c9aa3ce1f775
Pull Request resolved: pytorch#499

* [fix] float8 should be applied on all model_parts

ghstack-source-id: 52ed6836de39e82c4c5824a40ecfc1d9ec7ed2bd
Pull Request resolved: pytorch#500

* Add warning to compile rmsnorm (pytorch#505)

as titled, add warning to compile rmsnorm as it's not fully ready yet,
i.e. this issue pytorch#497

We can remove this warning once we fix the issue

* add float8 to README (pytorch#509)

add float8 link in README so we can redirect people from dev-discuss
post to torchtitan repo


README looks like this after rendering
<img width="518" alt="Screenshot 2024-08-06 at 5 42 10 PM"
src="https://github.com/user-attachments/assets/50af99d7-93be-459a-89d7-8c08b8fb95d4">

float8.md looks like this
<img width="563" alt="Screenshot 2024-08-06 at 5 04 17 PM"
src="https://github.com/user-attachments/assets/06d30aad-4133-4cec-9037-cfcf155b45c4">

I tried the command locally and traces are looking good
<img width="726" alt="Screenshot 2024-08-06 at 5 00 00 PM"
src="https://github.com/user-attachments/assets/bdfa3d7e-efe1-4009-92a1-0f5c310013fb">

* address TODOs as 2D recompiles is fixed

ghstack-source-id: 2927f0a8082171da3e9f59a5d04f8325cbdf3653
Pull Request resolved: pytorch#508

* [BE][5/n] simply pp vs. non-pp set up

ghstack-source-id: 003bfbfbcf1511ddbd18e15d031b39f597d8e7db
Pull Request resolved: pytorch#510

* [BE][6/n] replace large c4_mini datasets by c4_test with the first 2K entries

ghstack-source-id: 319f4961b092778703101b98937803073132afa1
Pull Request resolved: pytorch#512

* Create composability.md (pytorch#511)

Explain the rationale and challenges behind certain changes we made to
llama model to support 3D parallelism.

---------

Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>

* depend on torchdata 0.8.0 instead of nightly

ghstack-source-id: 1965d3122885fed3c28e2e058c55581187e7816c
Pull Request resolved: pytorch#513

* add support for torchbench

---------

Co-authored-by: Andrew Gu <andgu@fb.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@meta.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
Co-authored-by: Will Constable <whc@meta.com>
Co-authored-by: Wei (Will) Feng <134637289+weifengpy@users.noreply.github.com>
Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Less Wright <lessw@etrillium.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@fb.com>
Co-authored-by: Hugo <6937752+fduwjj@users.noreply.github.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
Co-authored-by: Ke Wen <kw2501@meta.com>
Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Co-authored-by: Will Constable <willconstable@gmail.com>
tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
we have landed fp8 all-gather optimizations in float8_experimental
pytorch-labs/float8_experimental#266

this PR proposes torchtitan changes. also include fp8 in CI
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather are added to CI
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather are locally tested. will add them to CI after
uploading a new tokenizer with vacab size 2560 (divisible by 16)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

precompute scales after optimizer.step
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather do not have any small all-reduces
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload tokenizer with vacab size 2560 to enable CI on TP fp8
all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about best config option to express
fp8
* compare perf between delayed scaling and dynamic scaling
https://github.com/pytorch-labs/float8_experimental/pull/312/files
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
we have landed fp8 all-gather optimizations in float8_experimental
pytorch-labs/float8_experimental#266

this PR proposes torchtitan changes. also include fp8 in CI
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather are added to CI
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather are locally tested. will add them to CI after
uploading a new tokenizer with vacab size 2560 (divisible by 16)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

precompute scales after optimizer.step
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather do not have any small all-reduces
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload tokenizer with vacab size 2560 to enable CI on TP fp8
all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about best config option to express
fp8
* compare perf between delayed scaling and dynamic scaling
https://github.com/pytorch-labs/float8_experimental/pull/312/files
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants