
[float8] FP8 GPT 1.5B Delayed Scaling 2x Slower than BF16 #1297

Closed
OrenLeung opened this issue Nov 16, 2024 · 5 comments

OrenLeung commented Nov 16, 2024

Hi Torch Team,

I am currently experimenting with native torch float8 training and comparing it to Transformer Engine with the delayed scaling recipe, on GPT 1.5B at batch=12, seq=1024, on a 700W H100 SXM 80G SKU.

I see that fp8 Transformer Engine provides a slight perf improvement over autocast bf16, but unfortunately torchao.float8 is almost 2x slower. I attempted to improve performance by enabling fp8 and bf16 autocast at the same time, but I ran into a `ValueError: All layers must have the same last seen input_dtype, got {torch.float32, torch.bfloat16}` error. Enabling fp8 together with bf16 autocast is something that TE does, but I am not sure whether it is needed for torchao.
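
A minimal sketch of the combination I mean by enabling fp8 and bf16 autocast at the same time (toy model and shapes for illustration only; the torchao functions are the same ones used in the full repro below):

```python
import torch
import torch.nn as nn
from torchao.float8 import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
    Float8LinearConfig,
    ScalingType,
    CastConfig,
)

# delayed scaling for input, weight, and grad_output, as in the repro below
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)

# toy stand-in for one GPT MLP block
model = nn.Sequential(nn.Linear(1600, 6400), nn.GELU(), nn.Linear(6400, 1600)).to('cuda:0')
convert_to_float8_training(model, config=config)  # swap nn.Linear for float8 linears

x = torch.randn(12 * 1024, 1600, device='cuda:0')
with torch.amp.autocast('cuda', torch.bfloat16):  # bf16 autocast on top of fp8, TE-style
    loss = model(x).float().pow(2).mean()
loss.backward()
sync_float8_amax_and_scale_history(model)  # required once per step for delayed scaling
```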

Can you provide some guidance on how to improve performance on torchao.float8?

Thanks!

BF16 Autocast: 493.17 TFLOP/s
FP8 TE: 501.2 TFLOP/s
torchao.float8: 240.67 TFLOP/s 

Repro Script

```python
import torch
import torch.nn as nn
from torchao.float8 import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
    Float8LinearConfig,
    ScalingType,
    CastConfig,
)
import torch.nn.functional as F
import fire

class CausalSelfAttention(nn.Module):
    def __init__(self, d_embd, n_heads, **kwargs):
        super().__init__()
        self.d_head = d_embd // n_heads  # D
        self.attn_proj = nn.Linear(d_embd, 3*d_embd)
        self.out_proj = nn.Linear(d_embd, d_embd)
 
    def forward(self, x_BTE):
        qkv = self.attn_proj(x_BTE).split(x_BTE.size(-1), -1)
        split_attn_head = lambda z: z.unflatten(-1, [-1, self.d_head]).transpose(1, 2)
        q_BHTD, k_BHTD, v_BHTD = map(split_attn_head, qkv)
        o_BHTD = F.scaled_dot_product_attention(q_BHTD, k_BHTD, v_BHTD, dropout_p=0.0, is_causal=True)
        o_BTE = o_BHTD.transpose(1, 2).flatten(-2)
        y_BTE = self.out_proj(o_BTE)
        return y_BTE

class GPTBlock(nn.Module):
    def __init__(self, d_embd, **kwargs):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_embd)
        self.attn = CausalSelfAttention(d_embd, **kwargs)
        self.ffn_norm = nn.LayerNorm(d_embd)
        self.ffn = nn.Sequential(
            nn.Linear(d_embd, 4*d_embd),
            nn.GELU(),
            nn.Linear(4*d_embd, d_embd)
        )

    def forward(self, x_BTE):
        x_BTE = x_BTE + self.attn(self.attn_norm(x_BTE))
        y_BTE = x_BTE + self.ffn(self.ffn_norm(x_BTE))
        return y_BTE

class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT, **kwargs):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV

# configure delayed scaling
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
    # enable_amax_init=False,  # only needed for autocast + compile + FSDP +  float8 delayed
    # enable_pre_and_post_forward=False  # only needed for autocast + compile + FSDP +  float8 delayed
)

def main(enable_fp8=True):
    torch.manual_seed(3985)
    torch.cuda.set_device(0)
    
    # GPT 1.5B
    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
        "arch_name": "gpt"
    }
    model = GPT(**cfg_json).to('cuda:0')
    
    N = sum(p.numel() for p in model.parameters())  # get param count

    flops_per_iter = 6 * N * 16 * 1024  # 6 * params * tokens per iter (note: the inputs below use batch 12, not 16)
    
    optimizer = torch.optim.AdamW(model.parameters(), fused=True)

    if enable_fp8:
        convert_to_float8_training(model, config=config)

    model = torch.compile(model)
    
    for step_idx in range(100):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        input_BT = torch.randint(50304, [12, 1024], dtype=torch.int64).to('cuda:0')
        label_BT = torch.randint(50304, [12, 1024], dtype=torch.int64).to('cuda:0')

        start.record()
        if not enable_fp8:
            with torch.amp.autocast('cuda', torch.bfloat16):
                logits_BTV = model(input_BT)
                loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
        else:
            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
        loss.backward()
        if enable_fp8:
            sync_float8_amax_and_scale_history(model)

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        end.record()

        torch.cuda.synchronize()
        t = start.elapsed_time(end) / 1e3
        flops_per_sec = flops_per_iter / t
        print(f"finish {step_idx} step: {(flops_per_sec/1e12):.2f} TFLOP/s")

if __name__ == "__main__":
    fire.Fire(main)
```

Dependencies

```
$ pip list | grep torch
pytorch-triton               3.1.0+cf34004b8a
torch                        2.6.0.dev20241030+cu124
torch-tb-profiler            0.4.3
torchao                      0.7.0.dev20241112+cu121
```

@vkuzo vkuzo self-assigned this Nov 18, 2024

vkuzo (Contributor) commented Nov 18, 2024

Hi @OrenLeung , thanks for the repro! This looks like a bug in how we handle delayed scaling + autocast, let me take a look.

vkuzo added a commit that referenced this issue Nov 18, 2024
Summary:

Fixes a bug with delayed scaling + autocast.

Before, the last input dtype when in autocast was queried from the input
to `torch._scaled_mm`:

```
x_hp -> {query_dtype_here} -> to_autocast -> torch._scaled_mm
```

This is incorrect because the dtype was saved from before the place
where autocast could change it.  This happened to work if `x_hp` was
already of the correct dtype, but did not work in cases such as the new
test case added in this PR, or real models such as the repro from
#1297.  The reason we haven't caught
this for so long is we've been using FSDP's mixed precision and not
single-GPU autocast.

The fix I'm taking here is to query the original post-autocast dtype based
on the output of `torch._scaled_mm`.  Since this dtype is based on the
dtype of the input to `torch._scaled_mm`, this will properly capture
autocasting:

```
x_hp -> to_autocast -> x_autocast_dtype -> to_fp8 -> x_fp8 -> torch._scaled_mm -> {query_dtype_here}
```

Test Plan:

```
// first, test the updated test case - it passes

// second - test a modified version of the repro in
// #1297:
// code: https://gist.github.com/vkuzo/6c53a1deca19856238d38746b1e52ee7
// logs: https://gist.github.com/vkuzo/60846b1f6b2822f91d2dfa67cab10a10
// we now see a speedup with float8
```

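As an illustration of the autocast behavior described above (a standalone snippet, not code from the PR): the dtype observed upstream of an autocast region stays float32, while the dtype that actually reaches the matmul under autocast is bfloat16, which is why recording the dtype before the cast can store the wrong value.

```python
import torch
import torch.nn as nn

lin = nn.Linear(16, 16).cuda()
x_hp = torch.randn(4, 16, device='cuda')  # high-precision (float32) input

y_plain = lin(x_hp)
with torch.amp.autocast('cuda', torch.bfloat16):
    y_ac = lin(x_hp)

print(y_plain.dtype)  # torch.float32 -- dtype seen without autocast
print(y_ac.dtype)     # torch.bfloat16 -- dtype the matmul actually ran in under autocast
```
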
vkuzo added a commit that referenced this issue Nov 18, 2024

vkuzo (Contributor) commented Nov 18, 2024

Thanks again for the report, #1306 should fix this. With that PR on my H100 machine:

  • baseline (bf16 + compile): 418 TFLOP/s
  • experiment (fp8 delayed + compile): 436 TFLOP/s (+4.3%)

OrenLeung (Author) commented

@vkuzo Thanks for the quick fix!

I am guessing you did your benchmark on the 500W H100 version?

I can confirm the fix using #1306! I am seeing the following:

vkuzo (Contributor) commented Nov 19, 2024

> I am guessing you did your benchmark on the 500W H100 version?

Yes, that's correct.

vkuzo added a commit that referenced this issue Nov 19, 2024

vkuzo (Contributor) commented Nov 19, 2024

Closing since the fix landed.

@vkuzo vkuzo closed this as completed Nov 19, 2024
sunjiweiswift pushed a commit to sunjiweiswift/ao that referenced this issue Nov 25, 2024