Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to support torch._export.aot_compile #2832

Merged
merged 25 commits into from
Dec 21, 2023

Conversation

agunapal
Copy link
Collaborator

@agunapal agunapal commented Dec 5, 2023

Description

This PR

  • Adds support for torch._export.aot_compile
  • Includes an example with ResNet18 with max-autotune and dynamic_shapes
  • Tested equivalence with multiple loads

Comparison with torch.compile

<style type="text/css"></style>

Model Mode Model loading(ms) First Inference Time (ms)
       
ResNet 18 compile + max-autotune 4111 25918
VGG 16 compile + max-autotune 5284 64960
ResNet 18 torch._export.aot_compile + max-autotune 15876 2704
VGG 16 torch._export.aot_compile + max-autotune 16042 2794

Fixes #(issue)

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Feature/Issue validation/testing

pytest -v test_torch_export.py 
================================================================================== test session starts ===================================================================================
platform linux -- Python 3.10.13, pytest-7.3.1, pluggy-1.0.0 -- /home/ubuntu/anaconda3/envs/ts_sam_Dec1/bin/python
cachedir: .pytest_cache
rootdir: /home/ubuntu/serve
plugins: cov-4.1.0, mock-3.12.0
collected 1 item                                                                                                                                                                         

test_torch_export.py::test_torch_export_aot_compile PASSED                                                                                                                         [100%]

==================================================================================== warnings summary ====================================================================================
test_torch_export.py:4
  /home/ubuntu/serve/test/pytest/test_torch_export.py:4: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
    from pkg_resources import packaging

../../../anaconda3/envs/ts_sam_Dec1/lib/python3.10/site-packages/pkg_resources/__init__.py:2868
  /home/ubuntu/anaconda3/envs/ts_sam_Dec1/lib/python3.10/site-packages/pkg_resources/__init__.py:2868: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
    declare_namespace(pkg)

../../../anaconda3/envs/ts_sam_Dec1/lib/python3.10/site-packages/pkg_resources/__init__.py:2868
  /home/ubuntu/anaconda3/envs/ts_sam_Dec1/lib/python3.10/site-packages/pkg_resources/__init__.py:2868: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.logging')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
    declare_namespace(pkg)

../../../anaconda3/envs/ts_sam_Dec1/lib/python3.10/site-packages/pkg_resources/__init__.py:2348
  /home/ubuntu/anaconda3/envs/ts_sam_Dec1/lib/python3.10/site-packages/pkg_resources/__init__.py:2348: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
    declare_namespace(parent)

../../../anaconda3/envs/ts_sam_Dec1/lib/python3.10/site-packages/pkg_resources/__init__.py:2868
../../../anaconda3/envs/ts_sam_Dec1/lib/python3.10/site-packages/pkg_resources/__init__.py:2868
  /home/ubuntu/anaconda3/envs/ts_sam_Dec1/lib/python3.10/site-packages/pkg_resources/__init__.py:2868: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('ruamel')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
    declare_namespace(pkg)

../../../anaconda3/envs/ts_sam_Dec1/lib/python3.10/site-packages/transformers/utils/generic.py:441
  /home/ubuntu/anaconda3/envs/ts_sam_Dec1/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
    _torch_pytree._register_pytree_node(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================================= 1 passed, 7 warnings in 45.19s =============================================================================
(ts_sam_Dec1) ubuntu@ip-172-31-11-40:~/serve/test/pytest$ 

Checking Equivalance when loading multiple times

Used the following code

import pickle
import torch

FILE1 = "test_load.pkl"
FILE2 = "test_load1.pkl"
print("Reading output 1")
with open(FILE1,'rb') as f:
    output1 = pickle.load(f)

print("Reading output 2")
with open(FILE2,'rb') as f:
    output2 = pickle.load(f)


print("Checking if output1 == output2", torch.equal(output1, output2))

results in

python test_equivalance.py 
Reading output 1
Reading output 2
Checking if output1 == output2 True

  • Model Loading logs
2023-12-13T20:03:15,259 [INFO ] W-9000-res18-pt2_1.0-stdout MODEL_LOG - torch._export is an experimental feature! Succesfully loaded torch exported model.
2023-12-13T20:03:15,259 [INFO ] W-9000-res18-pt2_1.0-stdout MODEL_LOG - export is not a supported backend
2023-12-13T20:03:15,264 [INFO ] W-9000-res18-pt2_1.0 org.pytorch.serve.wlm.WorkerThread - Backend response time: 17378
2023-12-13T20:03:15,264 [DEBUG] W-9000-res18-pt2_1.0 org.pytorch.serve.wlm.WorkerThread - W-9000-res18-pt2_1.0 State change WORKER_STARTED -> WORKER_MODEL_LOADED
2023-12-13T20:03:15,265 [INFO ] W-9000-res18-pt2_1.0 TS_METRICS - WorkerLoadTime.Milliseconds:19913.0|#WorkerName:W-9000-res18-pt2_1.0,Level:Host|#hostname:ip-172-31-11-40,timestamp:1702497795
2023-12-13T20:03:15,265 [INFO ] W-9000-res18-pt2_1.0 TS_METRICS - WorkerThreadTime.Milliseconds:3.0|#Level:Host|#hostname:ip-172-31-11-40,timestamp:1702497795
  • Model loaded on 4 GPUs
    Screenshot 2023-12-18 at 3 29 31 PM

Checklist:

  • Did you have fun?
  • Have you added tests that prove your fix is effective or that this feature works?
  • Has code been commented, particularly in hard-to-understand areas?
  • Have you made corresponding changes to the documentation?

@agunapal agunapal changed the title (WIP)Changes to support torch._export.aot_compile Changes to support torch._export.aot_compile Dec 12, 2023
@chauhang
Copy link
Contributor

@ankithagunapal Thanks for adding this support. Why is the model load taking so longer with AOTCompile?

@agunapal
Copy link
Collaborator Author

@ankithagunapal Thanks for adding this support. Why is the model load taking so longer with AOTCompile?

@chauhang Yet to run the profiler on the loading of the .so file. Will check and follow-up

examples/pt2/README.md Outdated Show resolved Hide resolved
examples/pt2/README.md Outdated Show resolved Hide resolved

Install PyTorch 2.2 nightlies by running
```
chmod +x install_pytorch_nightlies.sh
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need this script, install_dependencies.py has a nightly flag

The model is saved with `.so` extension
Here we are torch exporting with AOT Inductor with `max_auotune` mode.
This is also making use of `dynamic_shapes` to support batch size from 1 to 32.
In the code, the min batch_size is mentioned as 2 instead of 1. You can find an explanation for this [here](https://pytorch.org/docs/main/export.html#expressing-dynamism)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you bring core idea? Batch size 1 is what people choose if they want low latency

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I read the doc , but I don't understand it tbh

@@ -151,6 +155,12 @@ def initialize(self, context):
self.map_location = "cpu"
self.device = torch.device(self.map_location)

TORCH_EXPORT_AVAILABLE = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set these global variables in packaging if condition instead

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at the logic again. This seems like the best place.

@@ -53,6 +53,10 @@
)
PT2_AVAILABLE = False

if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1"):
PT220_AVAILABLE = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean pt.2.2? Can we call this call PT2_2-0 should be easier to read

Copy link
Collaborator Author

@agunapal agunapal Dec 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python doesn't allow - So PT2_2-0 doesnt work

@@ -180,6 +190,13 @@ def initialize(self, context):
self.model = setup_ort_session(self.model_pt_path, self.map_location)
logger.info("Succesfully setup ort session")

elif self.model_pt_path.endswith(".so") and TORCH_EXPORT_AVAILABLE:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

archiver doc strings need an update too to make it clear it also supports .so file

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure which one you are specifically referring to?
This doesn't talk about the format of the file

usage: torch-model-archiver [-h] --model-name MODEL_NAME [--serialized-file SERIALIZED_FILE] [--model-file MODEL_FILE] --handler HANDLER [--extra-files EXTRA_FILES]
                            [--runtime {python,python3}] [--export-path EXPORT_PATH] [--archive-format {tgz,no-archive,zip-store,default}] [-f] -v VERSION [-r REQUIREMENTS_FILE]
                            [-c CONFIG_FILE]
torch-model-archiver: error: the following arguments are required: --model-name, --handler, -v/--version

Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks mostly good, a few minor things left to do - also still not super clear on how the timing measurements were obtained - I'd add some log.info statements

agunapal and others added 3 commits December 13, 2023 10:59
Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
@agunapal
Copy link
Collaborator Author

Looks mostly good, a few minor things left to do - also still not super clear on how the timing measurements were obtained - I'd add some log.info statements

@msaroufim Thanks. I have addressed most of the feedback. However, wondering how we should address support for CPU. Does aot_compile on CPU make sense?

@msaroufim
Copy link
Member

CPU support is a completely valid scenario, inductor codegens native c++ that the intel team has been optimizing

if hasattr(self, "model_yaml_config") and "pt2" in self.model_yaml_config:
pt2_value = self.model_yaml_config["pt2"]
if pt2_value == "export" and PT220_AVAILABLE:
USE_TORCH_EXPORT = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need this bool flag?

elif self.model_pt_path.endswith(".so") and USE_TORCH_EXPORT:
# Set cuda stream to the gpu_id of the backend worker
if torch.cuda.is_available() and properties.get("gpu_id") is not None:
torch.cuda.set_stream(torch.cuda.Stream(int(properties.get("gpu_id"))))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add some comment here? Why are we launching things on specific streams?

Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stamping to unblock but this needs a few more things before merge

  1. This should indeed work with cpu so add a test
  2. Get rid of the install nightlies script and instead use the good old install_dependencies.py script
  3. Add a comment on the cuda streams section
  4. You don't need this flag USE_TORCH_EXPORT

@agunapal
Copy link
Collaborator Author

stamping to unblock but this needs a few more things before merge

  1. This should indeed work with cpu so add a test
  2. Get rid of the install nightlies script and instead use the good old install_dependencies.py script
  3. Add a comment on the cuda streams section
  4. You don't need this flag USE_TORCH_EXPORT
  1. Tested on CPU. pytest runs on both CPU/GPU. There is a limitation for batch_size on CPU. Mentioned this in the script.
  2. I have updated instructions to have both. Most users using torchserve already have dependencies installed. the torch version doesn't get overwritten. It needs to be uninstalled and then installed again. Hence, script is needed
  3. Good catch. Turns out this was not needed. We just need torch.cuda.set_device
  4. Refactored the code.

Also, changed the config to the following. (in case they add other options in export in the future)

pt2 :
  export:
    aot_compile: true

@agunapal agunapal added this pull request to the merge queue Dec 21, 2023
Merged via the queue into master with commit 426b4f7 Dec 21, 2023
13 checks passed
@chauhang chauhang added this to the v0.10.0 milestone Feb 27, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants