[Core Aten Op] bring aten_replication_pad3d test back #6537

Merged 1 commit into master on Feb 15, 2024

Conversation

@ManfeiBai (Collaborator) commented Feb 15, 2024

Fixes #5892 ([Core ATen Opset] Lower aten_replication_pad3d)


Testing

Both tests pass when run directly:

# pytest test/test_core_aten_ops.py -k test_aten_replication_pad3d_0
=========================================== test session starts ===========================================
platform linux -- Python 3.10.13, pytest-8.0.0, pluggy-1.4.0
rootdir: /root/pytorch
configfile: pytest.ini
plugins: hypothesis-6.97.4
collected 499 items / 498 deselected / 1 selected                                                         

test/test_core_aten_ops.py WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707956891.272721  694064 pjrt_api.cc:100] GetPjrtApi was found for tpu at /root/miniconda3/envs/torch310/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1707956891.272807  694064 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1707956891.272820  694064 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
.                                                                        [100%]

==================================== 1 passed, 498 deselected in 6.90s ====================================
# pytest test/test_core_aten_ops.py -k test_aten_replication_pad3d_1
=========================================== test session starts ===========================================
platform linux -- Python 3.10.13, pytest-8.0.0, pluggy-1.4.0
rootdir: /root/pytorch
configfile: pytest.ini
plugins: hypothesis-6.97.4
collected 499 items / 498 deselected / 1 selected                                                         

test/test_core_aten_ops.py WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707956908.792722  695638 pjrt_api.cc:100] GetPjrtApi was found for tpu at /root/miniconda3/envs/torch310/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1707956908.792808  695638 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1707956908.792822  695638 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
.                                                                        [100%]

==================================== 1 passed, 498 deselected in 6.97s ====================================

@ManfeiBai marked this pull request as ready for review on February 15, 2024
@wonjoolee95 (Collaborator) left a comment:

Thanks! LGTM pending CI.

@ManfeiBai merged commit b7c760d into master Feb 15, 2024
18 checks passed
@wonjoolee95 (Collaborator)

@ManfeiBai, looking back at this, I actually didn't see a lowering for aten_replication_pad3d in torch_xla. I assume this test succeeded because the op fell back to CPU?

@ManfeiBai (Collaborator, Author)

> @ManfeiBai, looking back at this, I actually didn't see a lowering for aten_replication_pad3d in torch_xla. I assume this test succeeded because the op fell back to CPU?

Thanks, @wonjoolee95. Let me confirm it locally too.

@wonjoolee95 (Collaborator)

What you can do is write a simple Python script that runs that op through torch_xla, then print the metrics report to check whether the aten:: (CPU fallback) or xla:: (lowered) version got called. Let me know if you need more help.

@ManfeiBai (Collaborator, Author) commented Feb 15, 2024

> What you can do is write a simple Python script that runs that op through torch_xla, then print the metrics report to check whether the aten:: (CPU fallback) or xla:: (lowered) version got called. Let me know if you need more help.

One simple test locally:

# PJRT_DEVICE=TPU python
Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch_xla
>>> a = torch.randn((1, 3, 2, 10)).to(torch.float32)
>>> b = [1, 1, 1, 1, 1, 1, ]
>>> kwargs = dict()
>>> 
>>> torch.ops.aten.replication_pad3d(a, b)
tensor([[[[-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455]],

         [[-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455]],

         [[ 1.0400,  1.0400, -1.7993,  1.6323,  1.0308, -0.3612,  0.1286,
           -0.1385,  2.5149, -0.7805, -1.4634, -1.4634],
          [ 1.0400,  1.0400, -1.7993,  1.6323,  1.0308, -0.3612,  0.1286,
           -0.1385,  2.5149, -0.7805, -1.4634, -1.4634],
          [-0.7135, -0.7135,  0.2299,  0.6553,  0.8867, -0.4578, -0.4595,
            0.3352, -0.1528, -0.1920,  0.3373,  0.3373],
          [-0.7135, -0.7135,  0.2299,  0.6553,  0.8867, -0.4578, -0.4595,
            0.3352, -0.1528, -0.1920,  0.3373,  0.3373]],

         [[ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389]],

         [[ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389]]]])
>>> 
>>> import torch_xla.core.xla_model as xm
>>> c = a.to(xm.xla_device())
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1708039355.708649  753016 pjrt_api.cc:100] GetPjrtApi was found for tpu at /root/miniconda3/envs/torch310/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1708039355.708730  753016 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1708039355.708754  753016 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
>>> d = b.to(xm.xla_device())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'list' object has no attribute 'to'
>>> torch.ops.aten.replication_pad3d(c, b)
tensor([[[[-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455]],

         [[-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455]],

         [[ 1.0400,  1.0400, -1.7993,  1.6323,  1.0308, -0.3612,  0.1286,
           -0.1385,  2.5149, -0.7805, -1.4634, -1.4634],
          [ 1.0400,  1.0400, -1.7993,  1.6323,  1.0308, -0.3612,  0.1286,
           -0.1385,  2.5149, -0.7805, -1.4634, -1.4634],
          [-0.7135, -0.7135,  0.2299,  0.6553,  0.8867, -0.4578, -0.4595,
            0.3352, -0.1528, -0.1920,  0.3373,  0.3373],
          [-0.7135, -0.7135,  0.2299,  0.6553,  0.8867, -0.4578, -0.4595,
            0.3352, -0.1528, -0.1920,  0.3373,  0.3373]],

         [[ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389]],

         [[ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389]]]], device='xla:0')
>>> 

It looks like the second result is a tensor on the XLA device too; does that mean the op did not fall back to CPU?
By the way, do we have guidance on how to print the metrics for that test case?

@wonjoolee95 (Collaborator)

You can import torch_xla.debug.metrics as met and call print(met.metrics_report()). Can you paste the output of the metrics report?
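
A minimal sketch of that check (assuming PJRT_DEVICE=TPU or another XLA device is configured; a CPU fallback would show up as an aten::replication_pad3d counter, while a real lowering would not):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
a = torch.randn(1, 3, 2, 10, device=device)
out = torch.ops.aten.replication_pad3d(a, [1, 1, 1, 1, 1, 1])
xm.mark_step()  # force execution so the counters get populated

# Any aten:: counter here means the op fell back to CPU.
print([name for name in met.counter_names() if "replication_pad3d" in name])
print(met.metrics_report())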

@ManfeiBai (Collaborator, Author)

> print(met.metrics_report())

Thanks, I pasted the output of the metrics report at https://gist.github.com/ManfeiBai/36799aa983d518a2233337dc701d294b

@ManfeiBai (Collaborator, Author)

Thanks. I dumped a similar example for the 2D op, and its HLO shows a dedicated lowering, so this 3D op needs the same kind of op-lowering logic as the 2D case.
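
For reference, one way to dump the HLO that the 2D op generates (a rough sketch; torch_xla._XLAC._get_xla_tensors_hlo is an internal, unstable helper, and the XLA_SAVE_TENSORS_FILE / XLA_SAVE_TENSORS_FMT environment variables are the documented way to save graphs to a file):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(1, 3, 10, device=device)
out = torch.ops.aten.replication_pad2d(a, [1, 1, 1, 1])

# Internal helper: returns the HLO text of the pending graph behind these tensors.
print(torch_xla._XLAC._get_xla_tensors_hlo([out]))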

@wonjoolee95 (Collaborator)

Yep, so it needs an actual lowering. We already lower the reflection pad 2d op so I think the logic may be somewhat similar. Let me know if you have any questions. Meanwhile, I'll revert this PR.
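
Not the actual torch_xla lowering (that lives in the C++ lowering code), but a rough Python sketch of what replication padding reduces to: repeating the edge slice along each padded dimension using slice, expand, and concat, all of which XLA already supports.

import torch

def replication_pad3d_reference(x, pad):
    # pad follows the aten convention: (left, right, top, bottom, front, back)
    left, right, top, bottom, front, back = pad
    # Pad the last dim (width) by repeating the first/last column.
    x = torch.cat([x[..., :1].expand(*x.shape[:-1], left), x,
                   x[..., -1:].expand(*x.shape[:-1], right)], dim=-1)
    # Pad the second-to-last dim (height) by repeating the first/last row.
    x = torch.cat([x[..., :1, :].expand(*x.shape[:-2], top, x.shape[-1]), x,
                   x[..., -1:, :].expand(*x.shape[:-2], bottom, x.shape[-1])], dim=-2)
    # Pad the third-to-last dim (depth) by repeating the first/last plane.
    x = torch.cat([x[..., :1, :, :].expand(*x.shape[:-3], front, *x.shape[-2:]), x,
                   x[..., -1:, :, :].expand(*x.shape[:-3], back, *x.shape[-2:])], dim=-3)
    return x

a = torch.randn(1, 3, 2, 10)
pad = [1, 1, 1, 1, 1, 1]
assert torch.equal(replication_pad3d_reference(a, pad),
                   torch.ops.aten.replication_pad3d(a, pad))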
