
[do NOT land] CP+torch.compile debugging attempt
ghstack-source-id: 9e02b6203ccce720f6558508343f12380b37c86c
Pull Request resolved: #791
XilunWu committed Jan 17, 2025
1 parent 95677cb commit 3a054dc
Showing 4 changed files with 16 additions and 2 deletions.
2 changes: 1 addition & 1 deletion torchtitan/models/llama/__init__.py
@@ -14,7 +14,7 @@
     "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
     "8B": ModelArgs(
         dim=4096,
-        n_layers=32,
+        n_layers=1,
         n_heads=32,
         n_kv_heads=8,
         ffn_dim_multiplier=1.3,
11 changes: 11 additions & 0 deletions torchtitan/models/llama/model.py
@@ -169,6 +169,14 @@ def init_weights(self, init_std: float):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
 
+    @torch.compiler.disable
+    def SDPA(self, *args, **kwargs):
+        return F.scaled_dot_product_attention(*args, **kwargs)
+
+    @torch.compiler.disable
+    def noop(self):
+        return None
+
     def forward(
         self,
         x: torch.Tensor,
@@ -205,8 +213,11 @@ def forward(
         xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
 
+        # self.noop()
         # we use casual mask for training
         output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        # output = self.SDPA(xq, xk, xv, is_causal=True)
+        # self.noop()
         output = output.transpose(
             1, 2
         ).contiguous() # (bs, seqlen, n_local_heads, head_dim)
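
The two helpers added above are debugging probes: torch.compiler.disable makes the wrapped call run eagerly and inserts a graph break at each call site, so toggling the commented-out calls in forward moves the break point around the SDPA call. Below is a minimal, standalone sketch of the same pattern; the TinyAttention module and tensor shapes are illustrative and not part of this commit.

import torch
import torch.nn.functional as F

class TinyAttention(torch.nn.Module):
    @torch.compiler.disable
    def sdpa(self, q, k, v):
        # Runs eagerly; torch.compile splits the graph around this call.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

    def forward(self, q, k, v):
        return self.sdpa(q, k, v)

compiled = torch.compile(TinyAttention(), fullgraph=False)  # fullgraph=True would error on the graph break
q = k = v = torch.randn(2, 4, 16, 8)  # (bs, n_heads, seqlen, head_dim)
print(compiled(q, k, v).shape)  # torch.Size([2, 4, 16, 8])
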
4 changes: 3 additions & 1 deletion torchtitan/parallelisms/parallelize_llama.py
@@ -299,8 +299,10 @@ def apply_compile(model: nn.Module):
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
+    # torch._inductor.config.force_disable_caches = True
+
     for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = torch.compile(transformer_block, fullgraph=True)
+        transformer_block = torch.compile(transformer_block, fullgraph=False)
         model.layers.register_module(layer_id, transformer_block)
 
     logger.info("Compiling each TransformerBlock with torch.compile")
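
The docstring in this hunk carries the reasoning: compiling each TransformerBlock separately keeps compilation cheap because the blocks share one structure, and this commit relaxes fullgraph to tolerate the graph breaks introduced by the torch.compiler.disable helpers above. A rough sketch of that per-block loop on a toy model follows; Block, ToyModel, and the sizes are assumptions for illustration, not torchtitan code.

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.ff(x))

class ToyModel(nn.Module):
    def __init__(self, dim: int = 32, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): Block(dim) for i in range(n_layers)})

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x

def apply_compile(model: nn.Module):
    # Compile each block in place; fullgraph=False tolerates graph breaks
    # such as those created by torch.compiler.disable.
    for layer_id, block in model.layers.named_children():
        model.layers.register_module(layer_id, torch.compile(block, fullgraph=False))

model = ToyModel()
apply_compile(model)
print(model(torch.randn(2, 32)).shape)  # torch.Size([2, 32])
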
1 change: 1 addition & 0 deletions torchtitan/utils.py
@@ -213,6 +213,7 @@ def context(cp_context: Optional[Generator[None, None, None]] = None):
                 stack.enter_context(
                     sdpa_kernel(
                         [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+                        # [SDPBackend.CUDNN_ATTENTION]
                     )
                 )
                 stack.enter_context(cp_context)
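
This last hunk pins the SDPA backends allowed inside the context-parallel region, with the cuDNN backend left as a commented-out experiment. Here is a minimal sketch of how the sdpa_kernel context manager restricts backend selection; the device handling and CPU fallback are assumptions added only to keep the example runnable.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda" if torch.cuda.is_available() else "cpu"
q = k = v = torch.randn(2, 8, 128, 64, device=device)

backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
if device == "cpu":
    backends.append(SDPBackend.MATH)  # CPU builds may lack the fused kernels

# Only the listed backends may be selected inside the block; all others are excluded.
with sdpa_kernel(backends):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 8, 128, 64])
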
