
Re-land: Fix model initialization. #6182

Merged 3 commits into master on Dec 15, 2023

Conversation

@ysiraichi (Collaborator) commented on Dec 15, 2023

Re-land: #6076

cc @JackCaoG @miladm

Initialize models on the given accelerator first. Later, move them to XLA, if
necessary.
@ysiraichi merged commit 38e3644 into master on Dec 15, 2023 (20 checks passed).
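As a rough illustration of the init-then-move order described in the commit message, here is a minimal Python sketch (helper names such as `make_benchmark` and `load_model` are assumptions for illustration; the actual logic lives in `benchmarks/torchbench_model.py`):

```python
def load_model(make_benchmark, accelerator, xla_device=None):
    # 1. Instantiate the benchmark on the requested accelerator (e.g. "cuda"),
    #    so framework-side setup such as default dtype casts runs there.
    benchmark = make_benchmark(device=accelerator)
    module, example_inputs = benchmark.get_module()

    # 2. Only afterwards move the already-initialized model (and, in the real
    #    code, its example inputs) to the XLA device, when one is in use.
    if xla_device is not None:
        module = module.to(xla_device)
    return module, example_inputs
```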
@@ -193,12 +193,23 @@ def load_benchmark(self):
# workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
# torch.backends.__allow_nonbracketed_mutation_flag = True

if self.benchmark_experiment.accelerator == "cpu":
Collaborator commented:
This causes issues for me. Probably an unintended rebase artifact? I see:

ERROR:torchbench_model:Cannot load benchmark model
Traceback (most recent call last):
  File "/usr/local/google/home/frgossen/pytorch/pytorch/xla/benchmarks/torchbench_model.py", line 235, in default_precision_flag
    benchmark = self.load_benchmark()
                ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/frgossen/pytorch/pytorch/xla/benchmarks/torchbench_model.py", line 216, in load_benchmark
    device=device,

Collaborator commented:
Fixed in #6195

@frgossen mentioned this pull request on Dec 18, 2023.

self.module, self.example_inputs = benchmark.get_module()

# Move the initialized model to XLA device.
Collaborator commented:
@ysiraichi I'm not sure why, but I think this breaks the default casts that happen inside, e.g., HuggingFace models. hf_Bert operates in fp16 by default. At revision dfcf306, the kernel profile looks like this.

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                  Torch-Compiled Region        36.28%       8.020ms        95.13%      21.030ms      21.030ms       0.000us         0.00%       8.137ms       8.137ms             1
                                               aten::mm         8.62%       1.905ms        12.74%       2.817ms      76.135us       3.588ms        44.09%       3.588ms      96.973us            37
ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us       3.060ms        37.61%       3.060ms     127.500us            24
                                            aten::addmm        11.58%       2.560ms        15.96%       3.528ms      95.351us       2.492ms        30.63%       2.492ms      67.351us            37
ampere_fp16_s16816gemm_fp16_256x128_ldg8_relu_f2f_st...         0.00%       0.000us         0.00%       0.000us       0.000us       1.390ms        17.08%       1.390ms      38.611us            36
ampere_fp16_s16816gemm_fp16_128x128_ldg8_relu_f2f_st...         0.00%       0.000us         0.00%       0.000us       0.000us       1.102ms        13.54%       1.102ms       1.102ms             1
              aten::_scaled_dot_product_flash_attention         1.42%     315.000us         9.26%       2.047ms     170.583us       0.000us         0.00%     806.000us      67.167us            12
                         aten::_flash_attention_forward         2.46%     544.000us         6.72%       1.486ms     123.833us     806.000us         9.91%     806.000us      67.167us            12
void pytorch_flash::flash_fwd_kernel<pytorch_flash::...         0.00%       0.000us         0.00%       0.000us       0.000us     806.000us         9.91%     806.000us      67.167us            12
                                        triton__0d1d2de         0.00%       0.000us         0.00%       0.000us       0.000us     672.000us         8.26%     672.000us      48.000us            14
                                triton_poi_fused_gelu_2         2.20%     486.000us         3.02%     667.000us      55.583us     601.000us         7.39%     601.000us      50.083us            12
               triton_per_fused_add_native_layer_norm_1         6.45%       1.425ms         8.43%       1.863ms      77.625us     519.000us         6.38%     519.000us      21.625us            24
                             triton__0d1d2d3d4d5d6de7de         0.00%       0.000us         0.00%       0.000us       0.000us     519.000us         6.38%     519.000us      21.625us            24
sm80_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x12...         0.00%       0.000us         0.00%       0.000us       0.000us     504.000us         6.19%     504.000us      38.769us            13
                                     triton_poi_fused_4         0.16%      35.000us         0.23%      51.000us      51.000us      66.000us         0.81%      66.000us      66.000us             1
     triton_per_fused_add_embedding_native_layer_norm_0         0.45%     100.000us         5.27%       1.166ms       1.166ms      33.000us         0.41%      33.000us      33.000us             1
                   triton__0d1d2d3d4d5d6d7d8d9d10de11de         0.00%       0.000us         0.00%       0.000us       0.000us      33.000us         0.41%      33.000us      33.000us             1
              triton_per_fused_gelu_native_layer_norm_3         0.23%      51.000us         0.31%      68.000us      68.000us      27.000us         0.33%      27.000us      27.000us             1
                               triton__0d1d2d3d4d5de6de         0.00%       0.000us         0.00%       0.000us       0.000us      27.000us         0.33%      27.000us      27.000us             1
                                        Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us      24.000us         0.29%      24.000us       2.000us            12
                                     triton_poi_fused_5         0.15%      33.000us         0.22%      48.000us      48.000us       5.000us         0.06%       5.000us       5.000us             1
                               TorchDynamo Cache Lookup         1.44%     318.000us         1.44%     318.000us     318.000us       0.000us         0.00%       0.000us       0.000us             1
                                            aten::empty         2.00%     442.000us         2.00%     442.000us       8.036us       0.000us         0.00%       0.000us       0.000us            55
                                         cuLaunchKernel         7.84%       1.733ms         7.84%       1.733ms      43.325us       0.000us         0.00%       0.000us       0.000us            40
                          inductor::_reinterpret_tensor         2.79%     616.000us         2.79%     616.000us       2.644us       0.000us         0.00%       0.000us       0.000us           233
          cudaOccupancyMaxActiveBlocksPerMultiprocessor         0.57%     127.000us         0.57%     127.000us       2.540us       0.000us         0.00%       0.000us       0.000us            50
                                       cudaLaunchKernel         8.19%       1.811ms         8.19%       1.811ms      21.058us       0.000us         0.00%       0.000us       0.000us            86
                                        aten::transpose         0.80%     176.000us         1.11%     245.000us       5.104us       0.000us         0.00%       0.000us       0.000us            48
                                       aten::as_strided         0.32%      70.000us         0.32%      70.000us       1.458us       0.000us         0.00%       0.000us       0.000us            48
                                       aten::empty_like         0.35%      77.000us         1.23%     272.000us      22.667us       0.000us         0.00%       0.000us       0.000us            12
                                    aten::empty_strided         0.96%     213.000us         0.96%     213.000us      16.385us       0.000us         0.00%       0.000us       0.000us            13
                                  cudaStreamIsCapturing         0.19%      42.000us         0.19%      42.000us       3.500us       0.000us         0.00%       0.000us       0.000us            12
                                   cudaFuncSetAttribute         0.45%     100.000us         0.45%     100.000us       2.632us       0.000us         0.00%       0.000us       0.000us            38
                                        cudaMemsetAsync         0.67%     149.000us         0.67%     149.000us      12.417us       0.000us         0.00%       0.000us       0.000us            12
                                  cudaDeviceSynchronize         3.43%     758.000us         3.43%     758.000us     379.000us       0.000us         0.00%       0.000us       0.000us             2
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 22.106ms
Self CUDA time total: 8.137ms

But now it looks like this:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                  Torch-Compiled Region        20.90%       4.577ms        50.67%      11.098ms      11.098ms       0.000us         0.00%      19.336ms      19.336ms             1
                                               aten::mm         3.43%     751.000us         5.14%       1.126ms      30.432us       6.505ms        33.64%       6.505ms     175.811us            37
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_1...         0.00%       0.000us         0.00%       0.000us       0.000us       6.418ms        33.19%       6.418ms     105.213us            61
          aten::_scaled_dot_product_efficient_attention         0.88%     192.000us         4.39%     961.000us      80.083us       0.000us         0.00%       6.237ms     519.750us            12
                     aten::_efficient_attention_forward         1.21%     265.000us         2.95%     647.000us      53.917us       6.237ms        32.26%       6.237ms     519.750us            12
fmha_cutlassF_f32_aligned_64x64_rf_sm80(PyTorchMemEf...         0.00%       0.000us         0.00%       0.000us       0.000us       6.237ms        32.26%       6.237ms     519.750us            12
                                            aten::addmm         5.93%       1.299ms         7.60%       1.664ms      44.973us       4.777ms        24.71%       4.777ms     129.108us            37
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_1...         0.00%       0.000us         0.00%       0.000us       0.000us       2.615ms        13.52%       2.615ms     217.917us            12
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_2...         0.00%       0.000us         0.00%       0.000us       0.000us       2.249ms        11.63%       2.249ms       2.249ms             1
                                        triton__0d1d2de         0.00%       0.000us         0.00%       0.000us       0.000us       1.063ms         5.50%       1.063ms      81.769us            13
                                triton_poi_fused_gelu_2         1.31%     286.000us         1.74%     381.000us      31.750us     925.000us         4.78%     925.000us      77.083us            12
               triton_per_fused_add_native_layer_norm_1         3.49%     764.000us         4.28%     938.000us      39.083us     675.000us         3.49%     675.000us      28.125us            24
                             triton__0d1d2d3d4d5d6de7de         0.00%       0.000us         0.00%       0.000us       0.000us     675.000us         3.49%     675.000us      28.125us            24
                                     triton_poi_fused_4         0.14%      30.000us         0.17%      38.000us      38.000us     138.000us         0.71%     138.000us     138.000us             1
     triton_per_fused_add_embedding_native_layer_norm_0         0.36%      79.000us         4.68%       1.025ms       1.025ms      44.000us         0.23%      44.000us      44.000us             1
                   triton__0d1d2d3d4d5d6d7d8d9d10de11de         0.00%       0.000us         0.00%       0.000us       0.000us      44.000us         0.23%      44.000us      44.000us             1
              triton_per_fused_gelu_native_layer_norm_3         0.14%      30.000us         0.17%      38.000us      38.000us      29.000us         0.15%      29.000us      29.000us             1
                               triton__0d1d2d3d4d5de6de         0.00%       0.000us         0.00%       0.000us       0.000us      29.000us         0.15%      29.000us      29.000us             1
                                     triton_poi_fused_5         0.08%      17.000us         0.11%      24.000us      24.000us       6.000us         0.03%       6.000us       6.000us             1
                                          triton__0d1d2         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         0.03%       6.000us       6.000us             1
                               TorchDynamo Cache Lookup         0.94%     205.000us         0.94%     205.000us     205.000us       0.000us         0.00%       0.000us       0.000us             1
                                            aten::empty         1.39%     304.000us         1.39%     304.000us       5.527us       0.000us         0.00%       0.000us       0.000us            55
                                         cuLaunchKernel         5.65%       1.238ms         5.65%       1.238ms      30.950us       0.000us         0.00%       0.000us       0.000us            40
                          inductor::_reinterpret_tensor         1.10%     240.000us         1.10%     240.000us       1.030us       0.000us         0.00%       0.000us       0.000us           233
cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFla...         0.37%      81.000us         0.37%      81.000us       0.730us       0.000us         0.00%       0.000us       0.000us           111
                                   cudaFuncSetAttribute         0.26%      56.000us         0.26%      56.000us       0.757us       0.000us         0.00%       0.000us       0.000us            74
                                       cudaLaunchKernel         3.33%     729.000us         3.33%     729.000us       8.477us       0.000us         0.00%       0.000us       0.000us            86
                                        aten::transpose         0.45%      99.000us         0.55%     121.000us       2.521us       0.000us         0.00%       0.000us       0.000us            48
                                       aten::as_strided         0.11%      23.000us         0.11%      23.000us       0.479us       0.000us         0.00%       0.000us       0.000us            48
                                  cudaStreamIsCapturing         0.12%      26.000us         0.12%      26.000us       2.167us       0.000us         0.00%       0.000us       0.000us            12
                                    aten::empty_strided         0.05%      12.000us         0.05%      12.000us      12.000us       0.000us         0.00%       0.000us       0.000us             1
                                  cudaDeviceSynchronize        48.40%      10.601ms        48.40%      10.601ms       5.301ms       0.000us         0.00%       0.000us       0.000us             2
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 21.904ms
Self CUDA time total: 19.336ms
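A minimal sketch of how one might confirm which precision the loaded model actually ended up in (not part of the benchmark harness; assumes `module` is the loaded `torch.nn.Module`):

```python
from collections import Counter

def param_dtypes(module):
    # Count parameter dtypes of a loaded torch.nn.Module.
    # A result dominated by 'torch.float16' suggests the HF default cast ran;
    # 'torch.float32' suggests it was skipped.
    return Counter(str(p.dtype) for p in module.parameters())
```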

Collaborator commented:
Confirmed that this PR causes significant regressions for Inductor inference.
Repro before/after this commit (the "after" run needs #6195 to run at all); output is saved to $output_dir=/tmp/foo_{before,after}:

$ python xla/benchmarks/experiment_runner.py --dynamo=inductor --xla=None --test=eval --suite-name=torchbench --accelerator=cuda --output-dirname=$output_dir --filter='^(resnet18|hf_Bert)$' --repeat=5
$ git diff --word-diff /tmp/foo_{before,after}/results.jsonl 
diff --git a/tmp/foo_before/results.jsonl b/tmp/foo_after/results.jsonl
index 213d1d5c578..b29da2bc371 100644
--- a/tmp/foo_before/results.jsonl
+++ b/tmp/foo_after/results.jsonl
@@ -1,2 +1,2 @@
{"model": {"suite_name": "torchbench", "model_name": "resnet18"}, "experiment": {"accelerator": "cuda", "accelerator_model": "NVIDIA A100-SXM4-40GB", "xla": null, "xla_flags": null, "dynamo": "inductor", "test": "eval", "batch_size": 256}, "repeat": 5, "iterations_per_run": 1, "metrics": {"total_time": [-[7.742884885054082, 0.008654866949655116, 0.007875610026530921, 0.007848781999200583, 0.007827074034139514],-]{+[7.2694849780527875, 0.021060320897959173, 0.019180824980139732, 0.01887879997957498, 0.018848669016733766],+} "per_iter_time": [-[7.742884885054082, 0.008654866949655116, 0.007875610026530921, 0.007848781999200583, 0.007827074034139514]},-]{+[7.2694849780527875, 0.021060320897959173, 0.019180824980139732, 0.01887879997957498, 0.018848669016733766]},+} "outputs_file": null}
{"model": {"suite_name": "torchbench", "model_name": "hf_Bert"}, "experiment": {"accelerator": "cuda", "accelerator_model": "NVIDIA A100-SXM4-40GB", "xla": null, "xla_flags": null, "dynamo": "inductor", "test": "eval", "batch_size": 8}, "repeat": 5, "iterations_per_run": 1, "metrics": {"total_time": [-[16.83604446600657, 0.00842809199821204, 0.008334961021319032, 0.008342717890627682, 0.008533166022971272],-]{+[13.953368615009822, 0.01880294701550156, 0.01858069305308163, 0.018933951971121132, 0.018568438943475485],+} "per_iter_time": [-[16.83604446600657, 0.00842809199821204, 0.008334961021319032, 0.008342717890627682, 0.008533166022971272]},-]{+[13.953368615009822, 0.01880294701550156, 0.01858069305308163, 0.018933951971121132, 0.018568438943475485]},+} "outputs_file": null}

Collaborator (author) commented:
This is odd. I will look into it. @golechwierowicz, do you mind sharing exactly what command you used for running hf_Bert?

Collaborator commented:
PJRT_DEVICE=CUDA python3 xla/benchmarks/experiment_runner.py --dynamo=openxla_eval --xla=PJRT --dynamo=inductor --xla=None --test=eval --filter='^hf_Bert$' --suite-name=torchbench --accelerator=cuda --progress-bar --output-dirname=/tmp/output --repeat=2 --print-subprocess --no-resume --dump-pytorch-profiles

device = self.benchmark_experiment.get_device()
self.module = self.module.to(device)
self.example_inputs = pytree.tree_map_only(torch.Tensor,
                                            lambda t: t.to(device),
Collaborator commented:
@golechwierowicz, could this copy be the issue?
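For context, the copy being questioned is the pytree-based move of the example inputs onto the target device; below is a minimal standalone sketch of what that call does (illustrative input shapes only, not the benchmark's actual inputs):

```python
import torch
import torch.utils._pytree as pytree

# Illustrative inputs only; real benchmarks produce their own example_inputs.
example_inputs = {"input_ids": torch.zeros(8, 128, dtype=torch.long),
                  "attention_mask": torch.ones(8, 128)}
device = "cuda" if torch.cuda.is_available() else "cpu"

# tree_map_only applies the lambda to every torch.Tensor leaf and leaves the
# surrounding container structure (dicts, tuples, scalars) untouched.
example_inputs = pytree.tree_map_only(torch.Tensor,
                                      lambda t: t.to(device),
                                      example_inputs)
```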

cota added a commit to cota/pytorch-xla that referenced this pull request Dec 20, 2023
This reverts commit 38e3644.
It is causing regressions for Inductor on inference -- see pytorch#6182 (comment)
cota added a commit that referenced this pull request Dec 21, 2023
This reverts commit 38e3644 i.e. #6182. It is causing regressions for Inductor on inference -- see #6182 (comment)
LukeBoyer pushed a commit that referenced this pull request Dec 25, 2023
This reverts commit 38e3644 i.e. #6182. It is causing regressions for Inductor on inference -- see #6182 (comment)
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Jan 3, 2024

This reverts commit 38e3644 i.e. pytorch#6182. It is causing regressions for Inductor on inference -- see pytorch#6182 (comment)
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
This reverts commit 38e3644 i.e. #6182. It is causing regressions for Inductor on inference -- see #6182 (comment)
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
This reverts commit 38e3644 i.e. #6182. It is causing regressions for Inductor on inference -- see #6182 (comment)