Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the option to turn on async-TP #429

Merged
merged 2 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@ def __init__(self):
action="store_true",
help="Whether to apply loss parallel when sequence parallel is enabled",
)
self.parser.add_argument(
"--experimental.enable_async_tensor_parallel",
default=False,
action="store_true",
help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_degree",
type=int,
Expand Down
6 changes: 6 additions & 0 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
parallelize_plan=layer_plan,
)

if job_config.experimental.enable_async_tensor_parallel:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

logger.info("Applied Tensor Parallelism to the model")
return model

Expand Down
Loading