
Add a few related optimization passes for fp8 gemm custom-calls. #16975

Closed

Conversation

elfiegg
Contributor

@elfiegg elfiegg commented Sep 10, 2024

The missing passes caused a convergence issue for FP8 training, tested on GPT-3 models:

Before:

NETWORK             BACKEND MATH SDPA XLA_EXTRAS      GPUs STEPS/SEC     LOSS WALLSECS
GPT5B                   XLA  fp8   FA    8     1.064 11.019     1571
[PAX STATUS]: Starting training loop.
[PAX STATUS] step_i: 100, training loss: 11.015041
[PAX STATUS] step_i: 200, training loss: 11.016165
[PAX STATUS] step_i: 300, training loss: 11.016386
[PAX STATUS] step_i: 400, training loss: 11.014653
[PAX STATUS] step_i: 500, training loss: 11.014734
[PAX STATUS] step_i: 600, training loss: 11.01613
[PAX STATUS] step_i: 700, training loss: 11.009399
[PAX STATUS] step_i: 800, training loss: 11.017071
[PAX STATUS] step_i: 900, training loss: 11.014582
[PAX STATUS] step_i: 1000, training loss: 11.013434
[PAX STATUS] step_i: 1100, training loss: 11.021271
[PAX STATUS] step_i: 1200, training loss: 11.008364
[PAX STATUS] step_i: 1300, training loss: 11.0198145
[PAX STATUS] step_i: 1400, training loss: 11.01253
[PAX STATUS] step_i: 1500, training loss: 11.019016

After:

NETWORK             BACKEND MATH SDPA GPUs STEPS/SEC  LOSS WALLSECS
GPT5B                   XLA  fp8   FA    8     1.020 3.797     1647
[PAX STATUS]: Starting training loop.
[PAX STATUS] step_i: 100, training loss: 6.150083
[PAX STATUS] step_i: 200, training loss: 5.8871064
[PAX STATUS] step_i: 300, training loss: 5.4491887
[PAX STATUS] step_i: 400, training loss: 5.6384015
[PAX STATUS] step_i: 500, training loss: 5.273538
[PAX STATUS] step_i: 600, training loss: 5.2011905
[PAX STATUS] step_i: 700, training loss: 4.903013
[PAX STATUS] step_i: 800, training loss: 4.62972
[PAX STATUS] step_i: 900, training loss: 4.507727
[PAX STATUS] step_i: 1000, training loss: 4.625259
[PAX STATUS] step_i: 1100, training loss: 4.428066
[PAX STATUS] step_i: 1200, training loss: 4.252451
[PAX STATUS] step_i: 1300, training loss: 3.8448389
[PAX STATUS] step_i: 1400, training loss: 3.8578327
[PAX STATUS] step_i: 1500, training loss: 3.796958

@elfiegg elfiegg changed the title Add a few related optimization pass for fp8 gemm rerwriter. Add a few related optimization passes for fp8 gemm rerwriter. Sep 10, 2024
@elfiegg elfiegg changed the title Add a few related optimization passes for fp8 gemm rerwriter. Add a few related optimization passes for fp8 gemm custom-calls. Sep 10, 2024
@NaiyerRizz NaiyerRizz self-assigned this Sep 10, 2024
@cheshire
Contributor

Is it possible to add tests? Do we have an explanation why those passes are required for correctness?

@kaixih
Contributor

kaixih commented Sep 10, 2024

Let me provide more context:

We are migrating from our original fake-quantization-like FP8 pattern to a new direct-quantization approach, where the dequantization (DQ) scaling is applied as the epilogue of the dot operation. This change lets us stop worrying about other XLA optimizer passes breaking our patterns: with direct quantization the dot takes FP8 inputs directly and is always lowered to an FP8 GEMM call.
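For illustration, here is a minimal HLO sketch of the direct-quantization shape of the pattern (the shapes, the combined-scale handling, and the f32 output type are assumptions for the example, not taken from this PR):

```
ENTRY main {
  %x = f8e4m3fn[16,32] parameter(0)
  %w = f8e4m3fn[32,16] parameter(1)
  %x_scale = f32[] parameter(2)
  %w_scale = f32[] parameter(3)
  // The dot consumes the FP8 operands directly, so it can always be lowered to an FP8 GEMM.
  %dot = f32[16,16] dot(%x, %w), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  // The dequantization (DQ) scaling is applied after the dot, i.e. as its epilogue.
  %scale = f32[] multiply(%x_scale, %w_scale)
  %scale_bcast = f32[16,16] broadcast(%scale), dimensions={}
  ROOT %out = f32[16,16] multiply(%dot, %scale_bcast)
}
```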

During the migration, @elfiegg found that when Triton GEMM falls back to the GEMM rewriter for cuBLAS, these three specific passes, i.e. LayoutNormalization, GpuAlgebraicSimplifier, and ScatterSimplifier, are necessary to ensure correctness. I think she is still investigating the specific reason and working on a unit test.

cc. @nouiz @sergachev @wenscarl @jprabhas

@cheshire
Contributor

OK thanks for the context!

@elfiegg
Contributor Author

elfiegg commented Sep 11, 2024

Thanks Kaixi for bringing everyone on the same page! Also sorry for the delay.

As mentioned, during debugging we found that layout normalization is crucial for numerical correctness. We consistently reproduced the numerical issue when using different operand layout permutations with cuBLAS: operands such as f8e4m3fn[12288,4096]{1,0} vs. f8e4m3fn[4096,12288]{0,1} produce different results. This suggests that the cuBLAS runtime thunk is sensitive to logical layout changes, making LayoutNormalization necessary for the cuBLAS GEMM rewriter.
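To make the layout difference concrete, here is a hand-written sketch (not the actual HLO from the model) of a dot whose LHS keeps the non-default {0,1} layout; the two parameter layouts below describe the same bytes, yet the results differ depending on which logical view reaches cuBLAS:

```
ENTRY main {
  // f8e4m3fn[4096,12288]{0,1} is bitcast-equivalent to f8e4m3fn[12288,4096]{1,0}.
  %x = f8e4m3fn[4096,12288]{0,1} parameter(0)
  %w = f8e4m3fn[12288,4096]{1,0} parameter(1)
  ROOT %dot = f32[4096,4096]{1,0} dot(%x, %w), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
```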

I've also added a unit test to ensure numerical correctness in the pipeline, both with Triton fusion falling back to cuBLAS and without Triton fusion (the test fails without the changes in this PR).

@elfiegg
Contributor Author

elfiegg commented Sep 16, 2024

Removed a duplicate GpuAlgebraicSimplifier pass that is already run at a later stage.

Upon further investigation, it seems the cuBLAS runtime thunk processes the layout correctly: the final MatrixLayout of both operands and the buffer assignments are identical for the cublasLt custom calls in both cases, f8e4m3fn[12288,4096]{1,0} and f8e4m3fn[4096,12288]{0,1}. Despite this, cublasLt produces different numerical results for the two calls. The root cause is still unclear.

Since LayoutNormalization has been used in previous cublasLT FP8 GEMM calls without Triton fp8 gemm, I believe it's safe to proceed with adding this pass for now, pending further investigation.

@elfiegg
Contributor Author

elfiegg commented Sep 17, 2024

gentle ping @cheshire @reedwm

@elfiegg
Contributor Author

elfiegg commented Sep 17, 2024

Further investigation shows that transpose seems broken for FP8 operands, as the following modules generate different numerics. I'll file another bug to track this.

  ENTRY main {
    %p0 = f8e5m2[12288,4096]{0,1} parameter(0)
    %b = f8e5m2[4096,12288]{1,0} bitcast(%p0)
    ROOT %transpose = f8e5m2[12288,4096]{1,0} transpose(%b), dimensions={1,0}
  }

  ENTRY main {
    %p0 = f8e5m2[12288,4096]{0,1} parameter(0)
    %transpose = f8e5m2[4096,12288]{0,1} transpose(%p0), dimensions={1,0}
    ROOT %bitcast = f8e5m2[12288,4096]{1,0} bitcast(%transpose)
  }

Member

@reedwm reedwm left a comment

I'm not sure why this is necessary. It seems the core issue is that certain layouts have incorrect numerics. That said, I'm ok taking this for now if it does fix the layout issue, but in the long term it's better to directly ensure different layouts still have correct numerics.

Does LayoutNormalization even affect a cublas gemm custom call? I see we pass NormalizeLayoutForGpuCustomCalls, but that only affects convolutions, not gemms.


HloModuleConfig config;
DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest();
triton_enabled_debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
Member

Why disable dynamic slice fusion?

Contributor Author

No shame in copy-pasting from the other test ;:^^)

Comment on lines 503 to 504
config.set_replica_count(1);
config.set_num_partitions(1);
Member

No need to set these, as they are the defaults

Contributor Author

Done.

Comment on lines 506 to 508
// Load autotuning DB. We shouldn't depend on actual execution times in a unit
// test.
std::string path =
    tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
                      "gpu_compiler_test_autotune_db.textproto");
Member

I would not load autotune results, but instead disable the cublas fallback for the triton case by calling triton_enabled_debug_options.set_xla_gpu_cublas_fallback(false). That way you don't have to make sure gpu_compiler_test_autotune_db.textproto has the exact gemm config that this HLO generates. Maybe disable autotuning as well if you want the outputs to be deterministic.

Contributor Author

Ah, let me clarify - the test compares the Triton-enabled path that falls back to cublasLt against the Triton-disabled path. Do you have any suggestions or concerns regarding that?
I've renamed the test to be clearer.

Member

That makes sense, and the test name clarifies things. I forgot that GemmRewriter is called twice: once to handle gemms that the Triton rewriter didn't handle, and again to handle formerly-Triton fusions that the autotuner decided to use cuBLAS for. So it's good to test that the two cases are numerically equivalent.

@kaixih
Contributor

kaixih commented Sep 17, 2024

Does LayoutNormalization even affect a cublas gemm custom call?

@elfie just filed an issue to hopefully narrow this down further. Here's the current understanding:

We’ve learned that the GEMM rewriter inserts a transpose for one operand as follows:

(x, y){0,1} -> transpose -> (y, x){0,1} -> gemm

However, directly running this results in incorrect outputs.

Upon investigation, we found that the layout normalization pass inserts bitcasts so that the following pattern is used instead, which works:

(x, y){0,1} -> bitcast -> (y, x){1,0} -> transpose -> (x, y){1,0} -> bitcast -> (y, x){0,1} -> gemm

This modification produces the correct results. Given this, we are curious whether there is a usage restriction that the transpose must operate on a {1,0} layout to work properly; that would suggest the transpose implementation, or the underlying operation, requires this specific layout to avoid issues.
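Spelled out as HLO with concrete shapes (borrowed from the earlier repro; dtype and dimensions are only for illustration), the normalized pattern is roughly:

```
ENTRY main {
  %p = f8e5m2[12288,4096]{0,1} parameter(0)                        // (x, y){0,1}
  %b0 = f8e5m2[4096,12288]{1,0} bitcast(%p)                        // (y, x){1,0}, same bytes
  %t = f8e5m2[12288,4096]{1,0} transpose(%b0), dimensions={1,0}    // (x, y){1,0}
  ROOT %b1 = f8e5m2[4096,12288]{0,1} bitcast(%t)                   // (y, x){0,1}, fed to the gemm
}
```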

@reedwm
Member

reedwm commented Sep 17, 2024

That makes sense. Thanks @elfiegg for the steps to reproduce the issue! CC @mooskagh

@elfiegg
Contributor Author

elfiegg commented Sep 17, 2024

Thanks again Kaixi for helping bridge the communication gaps.

Does LayoutNormalization even affect a cublas gemm custom call? I see we pass NormalizeLayoutForGpuCustomCalls, but that only affects convolutions, not gemms.

Regarding this, I had the same confusion. It seems the pipeline may still use legacy naming for historical reasons, which we should consider updating. However, as mentioned, the pass is indeed normalizing the non-default layouts of the transpose instructions to the default layout, and we observed different numerical results in execution with and without this normalization.


copybara-service bot pushed a commit that referenced this pull request Sep 17, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Sep 17, 2024
copybara-service bot pushed a commit that referenced this pull request Sep 18, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Sep 18, 2024
copybara-service bot pushed a commit that referenced this pull request Sep 18, 2024
copybara-service bot pushed a commit that referenced this pull request Oct 10, 2024
copybara-service bot pushed a commit that referenced this pull request Oct 10, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 10, 2024
copybara-service bot pushed a commit that referenced this pull request Oct 10, 2024
copybara-service bot pushed a commit that referenced this pull request Oct 10, 2024
copybara-service bot pushed a commit that referenced this pull request Oct 11, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 11, 2024
copybara-service bot pushed a commit that referenced this pull request Oct 12, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 12, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 12, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 12, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 12, 2024

This PR was rolled back in fd64718!

@derdrdirk
Member

Currently investigating if I can provide a fix and roll forward.

copybara-service bot pushed a commit that referenced this pull request Oct 15, 2024
…ization passes for fp8 gemm custom-calls.

Reverts fd64718

PiperOrigin-RevId: 686037932
@derdrdirk
Member

Created fix in #18342.

copybara-service bot pushed a commit that referenced this pull request Oct 15, 2024
…ization passes for fp8 gemm custom-calls.

Reverts fd64718

PiperOrigin-RevId: 686076980