Changes to support torch._export.aot_compile
#2832
Conversation
…/pytorch/serve into feature/torch_export_aot_compile
torch._export.aot_compile
torch._export.aot_compile
@ankithagunapal Thanks for adding this support. Why is the model load taking so much longer with AOT compile?
…/pytorch/serve into feature/torch_export_aot_compile
@chauhang Yet to run the profiler on the loading of the .so file. Will check and follow up.
Install PyTorch 2.2 nightlies by running
```
chmod +x install_pytorch_nightlies.sh
```
we don't need this script, install_dependencies.py has a nightly flag
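For reference, installing the nightlies through the dependency script might look like the following; the flag names here are my assumption, so check the script's --help output before relying on them.
```
# Hedged sketch: install nightly PyTorch builds via the repo's dependency
# script instead of a separate shell script (flag names assumed).
python ./ts_scripts/install_dependencies.py --cuda=cu121 --nightly_torch
```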
The model is saved with a `.so` extension.
Here we torch-export the model with AOT Inductor in `max_autotune` mode.
This also makes use of `dynamic_shapes` to support batch sizes from 1 to 32.
In the code, the min batch_size is specified as 2 instead of 1. You can find an explanation for this [here](https://pytorch.org/docs/main/export.html#expressing-dynamism).
Can you summarize the core idea here? Batch size 1 is what people choose if they want low latency.
I read the doc, but I don't understand it tbh.
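For readers following the batch-size discussion, here is a minimal sketch of the export call being described. The model, shapes, and output path are illustrative rather than the PR's exact example, and the keyword names assume the PyTorch 2.2 nightly signature of `torch._export.aot_compile`.
```python
import os
import torch
from torch.export import Dim

# Illustrative model and example input; the PR's own example differs.
model = torch.nn.Linear(10, 5).eval()
example_inputs = (torch.randn(2, 10),)

# Mark the batch dimension as dynamic. min=2 follows the linked torch.export
# guidance on dynamic dims, while max=32 bounds the accepted batch sizes.
batch = Dim("batch", min=2, max=32)

# AOT-compile with Inductor; max_autotune and the .so output path are passed
# through the options dict (option names assumed for the 2.2 nightlies).
so_path = torch._export.aot_compile(
    model,
    example_inputs,
    dynamic_shapes={"input": {0: batch}},
    options={
        "aot_inductor.output_path": os.path.join(os.getcwd(), "model.so"),
        "max_autotune": True,
    },
)
```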
ts/torch_handler/base_handler.py
Outdated
```
@@ -151,6 +155,12 @@ def initialize(self, context):
            self.map_location = "cpu"
            self.device = torch.device(self.map_location)

TORCH_EXPORT_AVAILABLE = False
```
Set these global variables inside the packaging version check `if` condition instead.
I looked at the logic again. This seems like the best place.
```
@@ -53,6 +53,10 @@
)
PT2_AVAILABLE = False

if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1"):
    PT220_AVAILABLE = True
```
Does this mean PyTorch 2.2? Can we call this PT2_2-0? That should be easier to read.
Python doesn't allow `-`, so `PT2_2-0` doesn't work.
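Since a hyphen cannot appear in a Python identifier, the usual workaround is an underscore-based name; a tiny sketch of the version gate with such a name (the name itself is only an illustration, not what the PR settled on):
```python
import torch
from packaging import version

# A hyphenated name like PT2_2-0 is a syntax error in Python, so an
# underscored alternative is used here instead.
TORCH_2_2_AVAILABLE = version.parse(torch.__version__) > version.parse("2.1.1")
```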
ts/torch_handler/base_handler.py
Outdated
```
@@ -180,6 +190,13 @@ def initialize(self, context):
            self.model = setup_ort_session(self.model_pt_path, self.map_location)
            logger.info("Succesfully setup ort session")

        elif self.model_pt_path.endswith(".so") and TORCH_EXPORT_AVAILABLE:
```
The archiver doc strings need an update too, to make it clear it also supports .so files.
Not sure which one you are specifically referring to? This doesn't talk about the format of the file:
```
usage: torch-model-archiver [-h] --model-name MODEL_NAME [--serialized-file SERIALIZED_FILE] [--model-file MODEL_FILE] --handler HANDLER [--extra-files EXTRA_FILES]
                            [--runtime {python,python3}] [--export-path EXPORT_PATH] [--archive-format {tgz,no-archive,zip-store,default}] [-f] -v VERSION [-r REQUIREMENTS_FILE]
                            [-c CONFIG_FILE]
torch-model-archiver: error: the following arguments are required: --model-name, --handler, -v/--version
```
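For the docs update being discussed, a hedged example of archiving an AOT-compiled model by passing the `.so` artifact as the serialized file; the model name, handler, and file names are illustrative:
```
# Sketch: package an AOT Inductor .so the same way as any serialized model
# (names and paths are placeholders).
torch-model-archiver --model-name res18-aot \
    --version 1.0 \
    --serialized-file model.so \
    --config-file model-config.yaml \
    --handler image_classifier \
    --export-path model_store
```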
Looks mostly good, a few minor things left to do - also still not super clear on how the timing measurements were obtained - I'd add some log.info statements
Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
@msaroufim Thanks. I have addressed most of the feedback. However, I'm wondering how we should address support for CPU. Does aot_compile on CPU make sense?
CPU support is a completely valid scenario; Inductor codegens native C++ that the Intel team has been optimizing.
ts/torch_handler/base_handler.py
Outdated
```
if hasattr(self, "model_yaml_config") and "pt2" in self.model_yaml_config:
    pt2_value = self.model_yaml_config["pt2"]
    if pt2_value == "export" and PT220_AVAILABLE:
        USE_TORCH_EXPORT = True
```
why do you need this bool flag?
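Based on the quoted handler check (`pt2_value == "export"`), the per-model YAML config would carry something like the sketch below; the exact key layout may have changed given the later note about supporting more export options.
```yaml
# model-config.yaml (sketch): the handler takes the torch.export / AOT Inductor
# path when the pt2 key is set to "export".
minWorkers: 1
maxWorkers: 2
pt2: "export"
```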
ts/torch_handler/base_handler.py
Outdated
```
elif self.model_pt_path.endswith(".so") and USE_TORCH_EXPORT:
    # Set cuda stream to the gpu_id of the backend worker
    if torch.cuda.is_available() and properties.get("gpu_id") is not None:
        torch.cuda.set_stream(torch.cuda.Stream(int(properties.get("gpu_id"))))
```
could you add some comment here? Why are we launching things on specific streams?
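To expand on the stream question (this is my reading, not the PR's final comment): each backend worker is assigned a gpu_id, and the call creates a stream on that device and makes it the worker's current stream. A commented sketch of the same logic:
```python
import torch

# In the handler, gpu_id comes from properties.get("gpu_id"); it is hardcoded
# here only for illustration.
gpu_id = 0

if torch.cuda.is_available() and gpu_id is not None:
    # torch.cuda.Stream(device) creates a new stream on the given device;
    # set_stream makes it current, so CUDA work this worker issues on that
    # device uses this stream instead of the default stream.
    torch.cuda.set_stream(torch.cuda.Stream(int(gpu_id)))
```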
Stamping to unblock, but this needs a few more things before merge:
- This should indeed work with CPU, so add a test
- Get rid of the install nightlies script and instead use the good old install_dependencies.py script
- Add a comment on the CUDA streams section
- You don't need this USE_TORCH_EXPORT flag
…/pytorch/serve into feature/torch_export_aot_compile
Also, changed the config to the following (in case they add other options to export in the future).
…/pytorch/serve into feature/torch_export_aot_compile
Description
This PR adds support for torch._export.aot_compile, using max-autotune and dynamic_shapes.
Comparison with torch.compile
Fixes #(issue)
Type of change
Please delete options that are not relevant.
Feature/Issue validation/testing
Checking equivalence when loading multiple times
Used the following code
results in
Checklist: